Просмотр исходного кода

Extract JwtDecoderFactory from JwtClientAssertionAuthenticationProvider

Closes gh-944
Joe Grandja 2 лет назад
Родитель
Сommit
8c2b095195

+ 18 - 138
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProvider.java

@@ -15,53 +15,27 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.function.Predicate;
-
-import javax.crypto.spec.SecretKeySpec;
-
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
-import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimNames;
-import org.springframework.security.oauth2.jwt.JwtClaimValidator;
 import org.springframework.security.oauth2.jwt.JwtDecoder;
 import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
 import org.springframework.security.oauth2.jwt.JwtException;
-import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
-import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
-import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext;
-import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
-import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
 import org.springframework.util.Assert;
-import org.springframework.util.CollectionUtils;
-import org.springframework.util.StringUtils;
-import org.springframework.web.util.UriComponentsBuilder;
 
 /**
  * An {@link AuthenticationProvider} implementation used for OAuth 2.0 Client Authentication,
- * which authenticates the (JWT) {@link OAuth2ParameterNames#CLIENT_ASSERTION client_assertion} parameter.
+ * which authenticates the {@link Jwt} {@link OAuth2ParameterNames#CLIENT_ASSERTION client_assertion} parameter.
  *
  * @author Rafal Lewczuk
  * @author Joe Grandja
@@ -70,6 +44,7 @@ import org.springframework.web.util.UriComponentsBuilder;
  * @see OAuth2ClientAuthenticationToken
  * @see RegisteredClientRepository
  * @see OAuth2AuthorizationService
+ * @see JwtClientAssertionDecoderFactory
  */
 public final class JwtClientAssertionAuthenticationProvider implements AuthenticationProvider {
 	private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-3.2.1";
@@ -77,7 +52,7 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
 			new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
 	private final RegisteredClientRepository registeredClientRepository;
 	private final CodeVerifierAuthenticator codeVerifierAuthenticator;
-	private final JwtClientAssertionDecoderFactory jwtClientAssertionDecoderFactory;
+	private JwtDecoderFactory<RegisteredClient> jwtDecoderFactory;
 
 	/**
 	 * Constructs a {@code JwtClientAssertionAuthenticationProvider} using the provided parameters.
@@ -91,7 +66,7 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
 		Assert.notNull(authorizationService, "authorizationService cannot be null");
 		this.registeredClientRepository = registeredClientRepository;
 		this.codeVerifierAuthenticator = new CodeVerifierAuthenticator(authorizationService);
-		this.jwtClientAssertionDecoderFactory = new JwtClientAssertionDecoderFactory();
+		this.jwtDecoderFactory = new JwtClientAssertionDecoderFactory();
 	}
 
 	@Override
@@ -119,7 +94,7 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
 		}
 
 		Jwt jwtAssertion = null;
-		JwtDecoder jwtDecoder = this.jwtClientAssertionDecoderFactory.createDecoder(registeredClient);
+		JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(registeredClient);
 		try {
 			jwtAssertion = jwtDecoder.decode(clientAuthentication.getCredentials().toString());
 		} catch (JwtException ex) {
@@ -142,6 +117,19 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
 		return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 
+	/**
+	 * Sets the {@link JwtDecoderFactory} that provides a {@link JwtDecoder} for the specified {@link RegisteredClient}
+	 * and is used for authenticating a {@link Jwt} Bearer Token during OAuth 2.0 Client Authentication.
+	 * The default factory is {@link JwtClientAssertionDecoderFactory}.
+	 *
+	 * @param jwtDecoderFactory the {@link JwtDecoderFactory} that provides a {@link JwtDecoder} for the specified {@link RegisteredClient}
+	 * @since 0.4.0
+	 */
+	public void setJwtDecoderFactory(JwtDecoderFactory<RegisteredClient> jwtDecoderFactory) {
+		Assert.notNull(jwtDecoderFactory, "jwtDecoderFactory cannot be null");
+		this.jwtDecoderFactory = jwtDecoderFactory;
+	}
+
 	private static void throwInvalidClient(String parameterName) {
 		throwInvalidClient(parameterName, null);
 	}
@@ -155,112 +143,4 @@ public final class JwtClientAssertionAuthenticationProvider implements Authentic
 		throw new OAuth2AuthenticationException(error, error.toString(), cause);
 	}
 
-	private static class JwtClientAssertionDecoderFactory implements JwtDecoderFactory<RegisteredClient> {
-		private static final String JWT_CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3";
-
-		private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS;
-
-		static {
-			Map<JwsAlgorithm, String> mappings = new HashMap<>();
-			mappings.put(MacAlgorithm.HS256, "HmacSHA256");
-			mappings.put(MacAlgorithm.HS384, "HmacSHA384");
-			mappings.put(MacAlgorithm.HS512, "HmacSHA512");
-			JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
-		}
-
-		private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
-
-		@Override
-		public JwtDecoder createDecoder(RegisteredClient registeredClient) {
-			Assert.notNull(registeredClient, "registeredClient cannot be null");
-			return this.jwtDecoders.computeIfAbsent(registeredClient.getId(), (key) -> {
-				NimbusJwtDecoder jwtDecoder = buildDecoder(registeredClient);
-				jwtDecoder.setJwtValidator(createJwtValidator(registeredClient));
-				return jwtDecoder;
-			});
-		}
-
-		private static NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) {
-			JwsAlgorithm jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm();
-			if (jwsAlgorithm instanceof SignatureAlgorithm) {
-				String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
-				if (!StringUtils.hasText(jwkSetUrl)) {
-					OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
-							"Failed to find a Signature Verifier for Client: '"
-									+ registeredClient.getId()
-									+ "'. Check to ensure you have configured the JWK Set URL.",
-							JWT_CLIENT_AUTHENTICATION_ERROR_URI);
-					throw new OAuth2AuthenticationException(oauth2Error);
-				}
-				return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
-			}
-			if (jwsAlgorithm instanceof MacAlgorithm) {
-				String clientSecret = registeredClient.getClientSecret();
-				if (!StringUtils.hasText(clientSecret)) {
-					OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
-							"Failed to find a Signature Verifier for Client: '"
-									+ registeredClient.getId()
-									+ "'. Check to ensure you have configured the client secret.",
-							JWT_CLIENT_AUTHENTICATION_ERROR_URI);
-					throw new OAuth2AuthenticationException(oauth2Error);
-				}
-				SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
-						JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
-				return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
-			}
-			OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
-					"Failed to find a Signature Verifier for Client: '"
-							+ registeredClient.getId()
-							+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'.",
-					JWT_CLIENT_AUTHENTICATION_ERROR_URI);
-			throw new OAuth2AuthenticationException(oauth2Error);
-		}
-
-		private static OAuth2TokenValidator<Jwt> createJwtValidator(RegisteredClient registeredClient) {
-			String clientId = registeredClient.getClientId();
-			return new DelegatingOAuth2TokenValidator<>(
-					new JwtClaimValidator<>(JwtClaimNames.ISS, clientId::equals),
-					new JwtClaimValidator<>(JwtClaimNames.SUB, clientId::equals),
-					new JwtClaimValidator<>(JwtClaimNames.AUD, containsAudience()),
-					new JwtClaimValidator<>(JwtClaimNames.EXP, Objects::nonNull),
-					new JwtTimestampValidator()
-			);
-		}
-
-		private static Predicate<List<String>> containsAudience() {
-			return (audienceClaim) -> {
-				if (CollectionUtils.isEmpty(audienceClaim)) {
-					return false;
-				}
-				List<String> audienceList = getAudience();
-				for (String audience : audienceClaim) {
-					if (audienceList.contains(audience)) {
-						return true;
-					}
-				}
-				return false;
-			};
-		}
-
-		private static List<String> getAudience() {
-			AuthorizationServerContext authorizationServerContext = AuthorizationServerContextHolder.getContext();
-			if (!StringUtils.hasText(authorizationServerContext.getIssuer())) {
-				return Collections.emptyList();
-			}
-
-			AuthorizationServerSettings authorizationServerSettings = authorizationServerContext.getAuthorizationServerSettings();
-			List<String> audience = new ArrayList<>();
-			audience.add(authorizationServerContext.getIssuer());
-			audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenEndpoint()));
-			audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenIntrospectionEndpoint()));
-			audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenRevocationEndpoint()));
-			return audience;
-		}
-
-		private static String asUrl(String issuer, String endpoint) {
-			return UriComponentsBuilder.fromUriString(issuer).path(endpoint).build().toUriString();
-		}
-
-	}
-
 }

+ 198 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionDecoderFactory.java

@@ -0,0 +1,198 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
+import java.util.function.Predicate;
+
+import javax.crypto.spec.SecretKeySpec;
+
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimNames;
+import org.springframework.security.oauth2.jwt.JwtClaimValidator;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
+import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
+import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContext;
+import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
+import org.springframework.security.oauth2.server.authorization.settings.AuthorizationServerSettings;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponentsBuilder;
+
+/**
+ * A {@link JwtDecoderFactory factory} that provides a {@link JwtDecoder} for the specified {@link RegisteredClient}
+ * and is used for authenticating a {@link Jwt} Bearer Token during OAuth 2.0 Client Authentication.
+ *
+ * @author Rafal Lewczuk
+ * @author Joe Grandja
+ * @since 0.4.0
+ * @see JwtDecoderFactory
+ * @see RegisteredClient
+ * @see OAuth2TokenValidator
+ * @see JwtClientAssertionAuthenticationProvider
+ * @see ClientAuthenticationMethod#PRIVATE_KEY_JWT
+ * @see ClientAuthenticationMethod#CLIENT_SECRET_JWT
+ */
+public final class JwtClientAssertionDecoderFactory implements JwtDecoderFactory<RegisteredClient> {
+
+	/**
+	 * The default {@code OAuth2TokenValidator<Jwt>} factory that validates the {@link JwtClaimNames#ISS iss},
+	 * {@link JwtClaimNames#SUB sub}, {@link JwtClaimNames#AUD aud}, {@link JwtClaimNames#EXP exp} and
+	 * {@link JwtClaimNames#NBF nbf} claims of the {@link Jwt} for the specified {@link RegisteredClient}.
+	 */
+	public static final Function<RegisteredClient, OAuth2TokenValidator<Jwt>> DEFAULT_JWT_VALIDATOR_FACTORY = defaultJwtValidatorFactory();
+
+	private static final String JWT_CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3";
+	private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS;
+
+	static {
+		Map<JwsAlgorithm, String> mappings = new HashMap<>();
+		mappings.put(MacAlgorithm.HS256, "HmacSHA256");
+		mappings.put(MacAlgorithm.HS384, "HmacSHA384");
+		mappings.put(MacAlgorithm.HS512, "HmacSHA512");
+		JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
+	}
+
+	private final Map<String, JwtDecoder> jwtDecoders = new ConcurrentHashMap<>();
+	private Function<RegisteredClient, OAuth2TokenValidator<Jwt>> jwtValidatorFactory = DEFAULT_JWT_VALIDATOR_FACTORY;
+
+	@Override
+	public JwtDecoder createDecoder(RegisteredClient registeredClient) {
+		Assert.notNull(registeredClient, "registeredClient cannot be null");
+		return this.jwtDecoders.computeIfAbsent(registeredClient.getId(), (key) -> {
+			NimbusJwtDecoder jwtDecoder = buildDecoder(registeredClient);
+			jwtDecoder.setJwtValidator(this.jwtValidatorFactory.apply(registeredClient));
+			return jwtDecoder;
+		});
+	}
+
+	/**
+	 * Sets the factory that provides an {@link OAuth2TokenValidator}
+	 * for the specified {@link RegisteredClient} and is used by the {@link JwtDecoder}.
+	 * The default {@code OAuth2TokenValidator<Jwt>} factory is {@link #DEFAULT_JWT_VALIDATOR_FACTORY}.
+	 *
+	 * @param jwtValidatorFactory the factory that provides an {@link OAuth2TokenValidator} for the specified {@link RegisteredClient}
+	 */
+	public void setJwtValidatorFactory(Function<RegisteredClient, OAuth2TokenValidator<Jwt>> jwtValidatorFactory) {
+		Assert.notNull(jwtValidatorFactory, "jwtValidatorFactory cannot be null");
+		this.jwtValidatorFactory = jwtValidatorFactory;
+	}
+
+	private static NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) {
+		JwsAlgorithm jwsAlgorithm = registeredClient.getClientSettings().getTokenEndpointAuthenticationSigningAlgorithm();
+		if (jwsAlgorithm instanceof SignatureAlgorithm) {
+			String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
+			if (!StringUtils.hasText(jwkSetUrl)) {
+				OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
+						"Failed to find a Signature Verifier for Client: '"
+								+ registeredClient.getId()
+								+ "'. Check to ensure you have configured the JWK Set URL.",
+						JWT_CLIENT_AUTHENTICATION_ERROR_URI);
+				throw new OAuth2AuthenticationException(oauth2Error);
+			}
+			return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
+		}
+		if (jwsAlgorithm instanceof MacAlgorithm) {
+			String clientSecret = registeredClient.getClientSecret();
+			if (!StringUtils.hasText(clientSecret)) {
+				OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
+						"Failed to find a Signature Verifier for Client: '"
+								+ registeredClient.getId()
+								+ "'. Check to ensure you have configured the client secret.",
+						JWT_CLIENT_AUTHENTICATION_ERROR_URI);
+				throw new OAuth2AuthenticationException(oauth2Error);
+			}
+			SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
+					JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
+			return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
+		}
+		OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
+				"Failed to find a Signature Verifier for Client: '"
+						+ registeredClient.getId()
+						+ "'. Check to ensure you have configured a valid JWS Algorithm: '" + jwsAlgorithm + "'.",
+				JWT_CLIENT_AUTHENTICATION_ERROR_URI);
+		throw new OAuth2AuthenticationException(oauth2Error);
+	}
+
+	private static Function<RegisteredClient, OAuth2TokenValidator<Jwt>> defaultJwtValidatorFactory() {
+		return (registeredClient) -> {
+			String clientId = registeredClient.getClientId();
+			return new DelegatingOAuth2TokenValidator<>(
+					new JwtClaimValidator<>(JwtClaimNames.ISS, clientId::equals),
+					new JwtClaimValidator<>(JwtClaimNames.SUB, clientId::equals),
+					new JwtClaimValidator<>(JwtClaimNames.AUD, containsAudience()),
+					new JwtClaimValidator<>(JwtClaimNames.EXP, Objects::nonNull),
+					new JwtTimestampValidator()
+			);
+		};
+	}
+
+	private static Predicate<List<String>> containsAudience() {
+		return (audienceClaim) -> {
+			if (CollectionUtils.isEmpty(audienceClaim)) {
+				return false;
+			}
+			List<String> audienceList = getAudience();
+			for (String audience : audienceClaim) {
+				if (audienceList.contains(audience)) {
+					return true;
+				}
+			}
+			return false;
+		};
+	}
+
+	private static List<String> getAudience() {
+		AuthorizationServerContext authorizationServerContext = AuthorizationServerContextHolder.getContext();
+		if (!StringUtils.hasText(authorizationServerContext.getIssuer())) {
+			return Collections.emptyList();
+		}
+
+		AuthorizationServerSettings authorizationServerSettings = authorizationServerContext.getAuthorizationServerSettings();
+		List<String> audience = new ArrayList<>();
+		audience.add(authorizationServerContext.getIssuer());
+		audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenEndpoint()));
+		audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenIntrospectionEndpoint()));
+		audience.add(asUrl(authorizationServerContext.getIssuer(), authorizationServerSettings.getTokenRevocationEndpoint()));
+		return audience;
+	}
+
+	private static String asUrl(String issuer, String endpoint) {
+		return UriComponentsBuilder.fromUriString(issuer).path(endpoint).build().toUriString();
+	}
+
+}

+ 7 - 79
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionAuthenticationProviderTests.java

@@ -41,7 +41,6 @@ import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.jose.TestJwks;
 import org.springframework.security.oauth2.jose.TestKeys;
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
-import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.BadJwtException;
 import org.springframework.security.oauth2.jwt.JwsHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
@@ -117,6 +116,13 @@ public class JwtClientAssertionAuthenticationProviderTests {
 				.hasMessage("authorizationService cannot be null");
 	}
 
+	@Test
+	public void setJwtDecoderFactoryWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setJwtDecoderFactory(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtDecoderFactory cannot be null");
+	}
+
 	@Test
 	public void supportsWhenTypeOAuth2ClientAuthenticationTokenThenReturnTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2ClientAuthenticationToken.class)).isTrue();
@@ -181,84 +187,6 @@ public class JwtClientAssertionAuthenticationProviderTests {
 				});
 	}
 
-	@Test
-	public void authenticateWhenMissingJwkSetUrlThenThrowOAuth2AuthenticationException() {
-		// @formatter:off
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
-				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
-				.clientSettings(
-						ClientSettings.builder()
-								.tokenEndpointAuthenticationSigningAlgorithm(SignatureAlgorithm.RS256)
-								.build()
-				)
-				.build();
-		// @formatter:on
-		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
-				.thenReturn(registeredClient);
-
-		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
-				registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, "jwt-assertion", null);
-		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
-				.isInstanceOf(OAuth2AuthenticationException.class)
-				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.satisfies(error -> {
-					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
-					assertThat(error.getDescription()).isEqualTo("Failed to find a Signature Verifier for Client: '" +
-							registeredClient.getId() + "'. Check to ensure you have configured the JWK Set URL.");
-				});
-	}
-
-	@Test
-	public void authenticateWhenMissingClientSecretThenThrowOAuth2AuthenticationException() {
-		// @formatter:off
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
-				.clientSecret(null)
-				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
-				.clientSettings(
-						ClientSettings.builder()
-								.tokenEndpointAuthenticationSigningAlgorithm(MacAlgorithm.HS256)
-								.build()
-				)
-				.build();
-		// @formatter:on
-		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
-				.thenReturn(registeredClient);
-
-		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
-				registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, "jwt-assertion", null);
-		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
-				.isInstanceOf(OAuth2AuthenticationException.class)
-				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.satisfies(error -> {
-					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
-					assertThat(error.getDescription()).isEqualTo("Failed to find a Signature Verifier for Client: '" +
-							registeredClient.getId() + "'. Check to ensure you have configured the client secret.");
-				});
-	}
-
-	@Test
-	public void authenticateWhenMissingSigningAlgorithmThenThrowOAuth2AuthenticationException() {
-		// @formatter:off
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
-				.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
-				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
-				.build();
-		// @formatter:on
-		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
-				.thenReturn(registeredClient);
-
-		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
-				registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, "jwt-assertion", null);
-		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
-				.isInstanceOf(OAuth2AuthenticationException.class)
-				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.satisfies(error -> {
-					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
-					assertThat(error.getDescription()).isEqualTo("Failed to find a Signature Verifier for Client: '" +
-							registeredClient.getId() + "'. Check to ensure you have configured a valid JWS Algorithm: 'null'.");
-				});
-	}
-
 	@Test
 	public void authenticateWhenInvalidCredentialsThenThrowOAuth2AuthenticationException() {
 		// @formatter:off

+ 112 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/JwtClientAssertionDecoderFactoryTests.java

@@ -0,0 +1,112 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.junit.Test;
+
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.settings.ClientSettings;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link JwtClientAssertionDecoderFactory}.
+ *
+ * @author Joe Grandja
+ */
+public class JwtClientAssertionDecoderFactoryTests {
+	private JwtClientAssertionDecoderFactory jwtDecoderFactory = new JwtClientAssertionDecoderFactory();
+
+	@Test
+	public void setJwtValidatorFactoryWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.jwtDecoderFactory.setJwtValidatorFactory(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtValidatorFactory cannot be null");
+	}
+
+	@Test
+	public void createDecoderWhenMissingJwkSetUrlThenThrowOAuth2AuthenticationException() {
+		// @formatter:off
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.clientSettings(
+						ClientSettings.builder()
+								.tokenEndpointAuthenticationSigningAlgorithm(SignatureAlgorithm.RS256)
+								.build()
+				)
+				.build();
+		// @formatter:on
+
+		assertThatThrownBy(() -> this.jwtDecoderFactory.createDecoder(registeredClient))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).isEqualTo("Failed to find a Signature Verifier for Client: '" +
+							registeredClient.getId() + "'. Check to ensure you have configured the JWK Set URL.");
+				});
+	}
+
+	@Test
+	public void createDecoderWhenMissingClientSecretThenThrowOAuth2AuthenticationException() {
+		// @formatter:off
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientSecret(null)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.clientSettings(
+						ClientSettings.builder()
+								.tokenEndpointAuthenticationSigningAlgorithm(MacAlgorithm.HS256)
+								.build()
+				)
+				.build();
+		// @formatter:on
+
+		assertThatThrownBy(() -> this.jwtDecoderFactory.createDecoder(registeredClient))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).isEqualTo("Failed to find a Signature Verifier for Client: '" +
+							registeredClient.getId() + "'. Check to ensure you have configured the client secret.");
+				});
+	}
+
+	@Test
+	public void createDecoderWhenMissingSigningAlgorithmThenThrowOAuth2AuthenticationException() {
+		// @formatter:off
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.build();
+		// @formatter:on
+
+		assertThatThrownBy(() -> this.jwtDecoderFactory.createDecoder(registeredClient))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+					assertThat(error.getDescription()).isEqualTo("Failed to find a Signature Verifier for Client: '" +
+							registeredClient.getId() + "'. Check to ensure you have configured a valid JWS Algorithm: 'null'.");
+				});
+	}
+
+}