Browse Source

Client authentication with JWT assertion

Closes gh-59
Rafal Lewczuk 3 years ago
parent
commit
16e4f5130b
32 changed files with 1798 additions and 51 deletions
  1. 1 0
      oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle
  2. 24 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimAccessor.java
  3. 14 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java
  4. 24 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java
  5. 2 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java
  6. 192 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java
  7. 37 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java
  8. 50 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ClientSettings.java
  9. 11 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ConfigurationSettingNames.java
  10. 33 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/jackson2/MacAlgorithmMixin.java
  11. 2 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/jackson2/OAuth2AuthorizationServerJackson2Module.java
  12. 67 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java
  13. 2 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilter.java
  14. 2 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationServerMetadataEndpointFilter.java
  15. 2 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java
  16. 1 14
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretBasicAuthenticationConverter.java
  17. 4 14
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java
  18. 92 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java
  19. 12 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java
  20. 187 6
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationTests.java
  21. 65 0
      oauth2-authorization-server/src/test/java/org/springframework/security/config/util/ValueCaptureMatcher.java
  22. 9 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistrationTests.java
  23. 8 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java
  24. 164 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java
  25. 11 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationTokenTests.java
  26. 436 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/RegisteredClientJwtAssertionDecoderFactoryTests.java
  27. 12 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java
  28. 21 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ClientSettingsTests.java
  29. 161 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java
  30. 1 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilterTests.java
  31. 3 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationServerMetadataEndpointFilterTests.java
  32. 148 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverterTests.java

+ 1 - 0
oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle

@@ -19,6 +19,7 @@ dependencies {
 	testCompile 'org.assertj:assertj-core'
 	testCompile 'org.mockito:mockito-core'
 	testCompile 'com.jayway.jsonpath:json-path'
+	testCompile 'com.squareup.okhttp3:mockwebserver'
 
 	testRuntime 'org.hsqldb:hsqldb'
 

+ 24 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimAccessor.java

@@ -99,6 +99,20 @@ public interface OidcClientMetadataClaimAccessor extends ClaimAccessor {
 		return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD);
 	}
 
+	/**
+	 * Returns the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
+	 * the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication
+	 * methods {@code (token_endpoint_auth_signing_alg)}
+	 *
+	 * @return the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
+	 * 	       the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt}
+	 * 	       authentication methods {@code (token_endpoint_auth_signing_alg)}
+	 * @since 0.2.1
+	 */
+	default String getTokenEndpointAuthenticationSigningAlgorithm() {
+		return getClaimAsString(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG);
+	}
+
 	/**
 	 * Returns the OAuth 2.0 {@code grant_type} values that the Client will restrict itself to using {@code (grant_types)}.
 	 *
@@ -155,4 +169,14 @@ public interface OidcClientMetadataClaimAccessor extends ClaimAccessor {
 		return getClaimAsURL(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI);
 	}
 
+	/**
+	 * Returns {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
+	 *
+	 * @return {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
+	 * @since 0.2.1
+	 */
+	default URL getJwkSetUrl() {
+		return getClaimAsURL(OidcClientMetadataClaimNames.JWKS_URI);
+	}
+
 }

+ 14 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientMetadataClaimNames.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.core.oidc;
 
 import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 
 /**
  * The names of the "claims" defined by OpenID Connect Dynamic Client Registration 1.0
@@ -95,4 +96,17 @@ public interface OidcClientMetadataClaimNames {
 	 */
 	String REGISTRATION_CLIENT_URI = "registration_client_uri";
 
+	/**
+	 * {@code jwks_uri} - {@code URL} for the Client's JSON Web Key Set
+	 * @since 0.2.1
+	 */
+	String JWKS_URI = "jwks_uri";
+
+	/**
+	 * {@code token_endpoint_auth_signing_alg} - {@link SignatureAlgorithm JWS} algorithm that must be used for signing
+	 * the JWT used to authenticate the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt}
+	 * authentication methods
+	 * @since 0.2.1
+	 */
+	String TOKEN_ENDPOINT_AUTH_SIGNING_ALG = "token_endpoint_auth_signing_alg";
 }

+ 24 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistration.java

@@ -172,6 +172,20 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce
 			return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, tokenEndpointAuthenticationMethod);
 		}
 
+		/**
+		 * Sets the {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate
+		 * the Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication
+		 * methods
+		 * @param signingAlgorithm the {@link SignatureAlgorithm JWS} algorithm that must be used for signing
+		 *        the JWT used to authenticate the Client at the Token Endpoint for the {@code private_key_jwt} and
+		 *        {@code client_secret_jwt} authentication methods
+		 * @return the {@link Builder} for further configuration
+		 * @since 0.2.1
+		 */
+		public Builder tokenEndpointAuthenticationSigningAlgorithm(String signingAlgorithm) {
+			return claim(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, signingAlgorithm);
+		}
+
 		/**
 		 * Add the OAuth 2.0 {@code grant_type} that the Client will restrict itself to using, OPTIONAL.
 		 *
@@ -273,6 +287,16 @@ public final class OidcClientRegistration implements OidcClientMetadataClaimAcce
 			return claim(OidcClientMetadataClaimNames.REGISTRATION_CLIENT_URI, registrationClientUrl);
 		}
 
+		/**
+		 * Sets {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
+		 * @param jwksSetUrl {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
+		 * @return the {@link Builder} for further configuration
+		 * @since 0.2.1
+		 */
+		public Builder jwkSetUrl(String jwksSetUrl) {
+			return claim(OidcClientMetadataClaimNames.JWKS_URI, jwksSetUrl);
+		}
+
 		/**
 		 * Sets the claim.
 		 *

+ 2 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverter.java

@@ -150,6 +150,8 @@ public class OidcClientRegistrationHttpMessageConverter extends AbstractHttpMess
 			claimConverters.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, collectionStringConverter);
 			claimConverters.put(OidcClientMetadataClaimNames.SCOPE, MapOidcClientRegistrationConverter::convertScope);
 			claimConverters.put(OidcClientMetadataClaimNames.ID_TOKEN_SIGNED_RESPONSE_ALG, stringConverter);
+			claimConverters.put(OidcClientMetadataClaimNames.JWKS_URI, stringConverter);
+			claimConverters.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, stringConverter);
 			this.claimTypeConverter = new ClaimTypeConverter(claimConverters);
 		}
 

+ 192 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProvider.java

@@ -19,34 +19,59 @@ import java.nio.charset.StandardCharsets;
 import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 import java.util.Base64;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
 
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.crypto.factory.PasswordEncoderFactories;
 import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.DelegatingOAuth2TokenValidator;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.OAuth2TokenValidator;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimValidator;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
+import org.springframework.security.oauth2.jwt.JwtException;
+import org.springframework.security.oauth2.jwt.JwtTimestampValidator;
+import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
+import javax.crypto.spec.SecretKeySpec;
+
 /**
  * An {@link AuthenticationProvider} implementation used for authenticating an OAuth 2.0 Client.
  *
  * @author Joe Grandja
  * @author Patryk Kostrzewa
  * @author Daniel Garnier-Moiroux
+ * @author Rafal Lewczuk
  * @since 0.0.1
  * @see AuthenticationProvider
  * @see OAuth2ClientAuthenticationToken
@@ -56,9 +81,15 @@ import org.springframework.util.StringUtils;
  */
 public final class OAuth2ClientAuthenticationProvider implements AuthenticationProvider {
 	private static final String CLIENT_AUTHENTICATION_ERROR_URI = "https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-01#section-3.2.1";
+
+	private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD =
+			new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
+
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private final RegisteredClientRepository registeredClientRepository;
 	private final OAuth2AuthorizationService authorizationService;
+	private JwtDecoderFactory<RegisteredClient> jwtDecoderFactory;
+	private ProviderSettings providerSettings;
 	private PasswordEncoder passwordEncoder;
 
 	/**
@@ -74,6 +105,7 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 		this.registeredClientRepository = registeredClientRepository;
 		this.authorizationService = authorizationService;
 		this.passwordEncoder = PasswordEncoderFactories.createDelegatingPasswordEncoder();
+		this.jwtDecoderFactory = new RegisteredClientJwtAssertionDecoderFactory();
 	}
 
 	/**
@@ -89,11 +121,25 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 		this.passwordEncoder = passwordEncoder;
 	}
 
+	@Autowired
+	protected void setProviderSettings(ProviderSettings providerSettings) {
+		this.providerSettings = providerSettings;
+	}
+
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
 		OAuth2ClientAuthenticationToken clientAuthentication =
 				(OAuth2ClientAuthenticationToken) authentication;
 
+		return JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.equals(clientAuthentication.getClientAuthenticationMethod()) ?
+				authenticateClientAssertion(authentication) :
+				authenticationClientCredentials(authentication);
+	}
+
+	private Authentication authenticationClientCredentials(Authentication authentication) throws AuthenticationException {
+		OAuth2ClientAuthenticationToken clientAuthentication =
+				(OAuth2ClientAuthenticationToken) authentication;
+
 		String clientId = clientAuthentication.getPrincipal().toString();
 		RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
 		if (registeredClient == null) {
@@ -125,6 +171,64 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 				clientAuthentication.getClientAuthenticationMethod(), clientAuthentication.getCredentials());
 	}
 
+	private Authentication authenticateClientAssertion(Authentication authentication) throws AuthenticationException {
+		OAuth2ClientAuthenticationToken clientAuthentication =
+				(OAuth2ClientAuthenticationToken) authentication;
+
+		String clientId = clientAuthentication.getPrincipal().toString();
+		RegisteredClient registeredClient = this.registeredClientRepository.findByClientId(clientId);
+		if (registeredClient == null) {
+			throwInvalidClient(OAuth2ParameterNames.CLIENT_ID);
+		}
+
+		Set<ClientAuthenticationMethod> allowedAuthenticationMethods = registeredClient.getClientAuthenticationMethods();
+
+		if (!allowedAuthenticationMethods.contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT) &&
+				!allowedAuthenticationMethods.contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT)) {
+			throwInvalidClient("authentication_method");
+		}
+
+		boolean credentialsAuthenticated = false;
+
+		try {
+			JwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(registeredClient);
+			Jwt jwt = jwtDecoder.decode(clientAuthentication.getCredentials().toString());
+			List<String> aud = jwt.getClaimAsStringList("aud");
+			String issuer = getIssuerUri(clientAuthentication.getRequestUri());
+			if (aud == null || !aud.contains(issuer)) {
+				throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
+			}
+			credentialsAuthenticated = true;
+		} catch (JwtException e) {
+			throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
+		}
+
+		boolean pkceAuthenticated = authenticatePkceIfAvailable(clientAuthentication, registeredClient);
+		credentialsAuthenticated = credentialsAuthenticated || pkceAuthenticated;
+		if (!credentialsAuthenticated) {
+			throwInvalidClient("credentials");
+		}
+
+		JwsAlgorithm tokenEndpointSigningAlgorithm = registeredClient.getClientSettings().getTokenEndpointSigningAlgorithm();
+		ClientAuthenticationMethod clientAuthentiationMethod = tokenEndpointSigningAlgorithm instanceof MacAlgorithm ?
+				ClientAuthenticationMethod.CLIENT_SECRET_JWT : ClientAuthenticationMethod.PRIVATE_KEY_JWT;
+
+		return new OAuth2ClientAuthenticationToken(registeredClient,
+				clientAuthentiationMethod, clientAuthentication.getCredentials());
+	}
+
+	private String getIssuerUri(String requestUri) throws AuthenticationException {
+		if (requestUri.endsWith(providerSettings.getTokenEndpoint())) {
+			return providerSettings.getIssuer() + providerSettings.getTokenEndpoint();
+		} else if (requestUri.endsWith(providerSettings.getTokenIntrospectionEndpoint())) {
+			return providerSettings.getIssuer() + providerSettings.getTokenIntrospectionEndpoint();
+		} else if (requestUri.endsWith(providerSettings.getTokenRevocationEndpoint())) {
+			return providerSettings.getIssuer() + providerSettings.getTokenRevocationEndpoint();
+		}
+		throwInvalidClient(OAuth2ParameterNames.CLIENT_ASSERTION);
+		return null;
+	}
+
 	@Override
 	public boolean supports(Class<?> authentication) {
 		return OAuth2ClientAuthenticationToken.class.isAssignableFrom(authentication);
@@ -201,4 +305,92 @@ public final class OAuth2ClientAuthenticationProvider implements AuthenticationP
 		throw new OAuth2AuthenticationException(error);
 	}
 
+	private static class CachedJwtDecoder {
+		private final NimbusJwtDecoder jwtDecoder;
+		private final RegisteredClient registeredClient;
+
+		CachedJwtDecoder(NimbusJwtDecoder jwtDecoder, RegisteredClient registeredClient) {
+			this.jwtDecoder = jwtDecoder;
+			this.registeredClient = registeredClient;
+		}
+	}
+
+	private static class RegisteredClientJwtAssertionDecoderFactory implements JwtDecoderFactory<RegisteredClient> {
+
+		private static final String CLIENT_ASSERTION_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7523#section-3";
+
+		private static final Map<JwsAlgorithm, String> JCA_ALGORITHM_MAPPINGS;
+
+		static {
+			Map<JwsAlgorithm, String> mappings = new HashMap<>();
+			mappings.put(MacAlgorithm.HS256, "HmacSHA256");
+			mappings.put(MacAlgorithm.HS384, "HmacSHA384");
+			mappings.put(MacAlgorithm.HS512, "HmacSHA512");
+			JCA_ALGORITHM_MAPPINGS = Collections.unmodifiableMap(mappings);
+		}
+
+		private final Function<RegisteredClient, JwsAlgorithm> jwsAlgorithmResolver =
+				rc -> rc.getClientSettings().getTokenEndpointSigningAlgorithm();
+
+		private final Map<String, CachedJwtDecoder> cachedDecoders = new ConcurrentHashMap<>();
+
+		@Override
+		public JwtDecoder createDecoder(RegisteredClient registeredClient) {
+			Assert.notNull(registeredClient, "registeredClient cannot be null");
+
+			CachedJwtDecoder cachedDecoder = this.cachedDecoders.get(registeredClient.getClientId());
+			if (cachedDecoder != null && registeredClient.equals(cachedDecoder.registeredClient)) {
+				return cachedDecoder.jwtDecoder;
+			}
+
+			cachedDecoder = new CachedJwtDecoder(buildDecoder(registeredClient), registeredClient);
+			cachedDecoder.jwtDecoder.setJwtValidator(createTokenValidator(registeredClient));
+			this.cachedDecoders.put(registeredClient.getClientId(), cachedDecoder);
+			return cachedDecoder.jwtDecoder;
+		}
+
+		private NimbusJwtDecoder buildDecoder(RegisteredClient registeredClient) {
+			JwsAlgorithm jwsAlgorithm = this.jwsAlgorithmResolver.apply(registeredClient);
+
+			if (jwsAlgorithm != null && SignatureAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
+				String jwkSetUrl = registeredClient.getClientSettings().getJwkSetUrl();
+				if (!StringUtils.hasText(jwkSetUrl)) {
+					OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
+							"misconfigured client", CLIENT_ASSERTION_ERROR_URI);
+					throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+				}
+				return NimbusJwtDecoder.withJwkSetUri(jwkSetUrl).jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm).build();
+			}
+
+			if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
+				String clientSecret = registeredClient.getClientSecret();
+				if (!StringUtils.hasText(clientSecret)) {
+					OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
+							"misconfigured client", CLIENT_ASSERTION_ERROR_URI);
+					throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+				}
+				SecretKeySpec secretKeySpec = new SecretKeySpec(clientSecret.getBytes(StandardCharsets.UTF_8),
+						JCA_ALGORITHM_MAPPINGS.get(jwsAlgorithm));
+				return NimbusJwtDecoder.withSecretKey(secretKeySpec).macAlgorithm((MacAlgorithm) jwsAlgorithm).build();
+			}
+
+			OAuth2Error oauth2Error = new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT,
+					"misconfigured client", CLIENT_ASSERTION_ERROR_URI);
+			throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
+		}
+
+		private OAuth2TokenValidator<Jwt> createTokenValidator(RegisteredClient registeredClient) {
+			String clientId = registeredClient.getClientId();
+			return new DelegatingOAuth2TokenValidator<>(
+					new JwtClaimValidator<String>("iss", clientId::equals),      // RFC 7523 section 3 (iss)
+					new JwtClaimValidator<String>("sub", clientId::equals),      // RFC 7523 section 3 (sub)
+					new JwtClaimValidator<>("exp", Objects::nonNull),            // RFC 7523 section 3 (exp != null)
+					new JwtTimestampValidator()                                  // RFC 7523 section 3 (exp, nbf)
+			);
+			// The `aud` claim is not verified here
+
+			// TODO RFC 7523 section 3 #7: JWT may contain "jti" claim that provides unique identified for the token (OPTIONAL)
+		}
+	}
+
 }

+ 37 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationToken.java

@@ -41,6 +41,7 @@ import org.springframework.util.Assert;
 @Transient
 public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+	private final String requestUri;
 	private final String clientId;
 	private final RegisteredClient registeredClient;
 	private final ClientAuthenticationMethod clientAuthenticationMethod;
@@ -55,11 +56,37 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
 	 * @param credentials the client credentials
 	 * @param additionalParameters the additional parameters
 	 */
+	@Deprecated
 	public OAuth2ClientAuthenticationToken(String clientId, ClientAuthenticationMethod clientAuthenticationMethod,
 			@Nullable Object credentials, @Nullable Map<String, Object> additionalParameters) {
 		super(Collections.emptyList());
 		Assert.hasText(clientId, "clientId cannot be empty");
 		Assert.notNull(clientAuthenticationMethod, "clientAuthenticationMethod cannot be null");
+		this.requestUri = null;
+		this.clientId = clientId;
+		this.registeredClient = null;
+		this.clientAuthenticationMethod = clientAuthenticationMethod;
+		this.credentials = credentials;
+		this.additionalParameters = Collections.unmodifiableMap(
+				additionalParameters != null ? additionalParameters : Collections.emptyMap());
+	}
+
+	/**
+	 * Constructs an {@code OAuth2ClientAuthenticationToken} using the provided parameters.
+	 *
+	 * @param requestUri the issuer identifier
+	 * @param clientId the client identifier
+	 * @param clientAuthenticationMethod the authentication method used by the client
+	 * @param credentials the client credentials
+	 * @param additionalParameters the additional parameters
+	 */
+	public OAuth2ClientAuthenticationToken(String requestUri, String clientId, ClientAuthenticationMethod clientAuthenticationMethod,
+			@Nullable Object credentials, @Nullable Map<String, Object> additionalParameters) {
+		super(Collections.emptyList());
+		Assert.hasText(requestUri, "requestUri cannot be empty");
+		Assert.hasText(clientId, "clientId cannot be empty");
+		Assert.notNull(clientAuthenticationMethod, "clientAuthenticationMethod cannot be null");
+		this.requestUri = requestUri;
 		this.clientId = clientId;
 		this.registeredClient = null;
 		this.clientAuthenticationMethod = clientAuthenticationMethod;
@@ -80,6 +107,7 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
 		super(Collections.emptyList());
 		Assert.notNull(registeredClient, "registeredClient cannot be null");
 		Assert.notNull(clientAuthenticationMethod, "clientAuthenticationMethod cannot be null");
+		this.requestUri = null;
 		this.clientId = registeredClient.getClientId();
 		this.registeredClient = registeredClient;
 		this.clientAuthenticationMethod = clientAuthenticationMethod;
@@ -118,6 +146,15 @@ public class OAuth2ClientAuthenticationToken extends AbstractAuthenticationToken
 		return this.clientAuthenticationMethod;
 	}
 
+	/**
+	 * Returns URI of authenticated request. This is used to validate aud claim of JWT client assertions
+	 *
+	 * @return URI of authenticated request
+	 */
+	public String getRequestUri() {
+		return requestUri;
+	}
+
 	/**
 	 * Returns the additional parameters.
 	 *

+ 50 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ClientSettings.java

@@ -17,6 +17,8 @@ package org.springframework.security.oauth2.server.authorization.config;
 
 import java.util.Map;
 
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.util.Assert;
 
 /**
@@ -53,6 +55,27 @@ public final class ClientSettings extends AbstractSettings {
 		return getSetting(ConfigurationSettingNames.Client.REQUIRE_AUTHORIZATION_CONSENT);
 	}
 
+	/**
+	 * Returns {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
+	 * @return {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
+	 * @since 0.2.1
+	 */
+	public String getJwkSetUrl() {
+		return getSetting(ConfigurationSettingNames.Client.JWK_SET_URL);
+	}
+
+	/**
+	 * Returns {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate the
+	 * Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication methods
+	 * @return {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate the
+	 * 	       Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication
+	 * 	       methods
+	 * @since 0.2.1
+	 */
+	public JwsAlgorithm getTokenEndpointSigningAlgorithm() {
+		return getSetting(ConfigurationSettingNames.Client.TOKEN_ENDPOINT_SIGNING_ALGORITHM);
+	}
+
 	/**
 	 * Constructs a new {@link Builder} with the default settings.
 	 *
@@ -61,7 +84,8 @@ public final class ClientSettings extends AbstractSettings {
 	public static Builder builder() {
 		return new Builder()
 				.requireProofKey(false)
-				.requireAuthorizationConsent(false);
+				.requireAuthorizationConsent(false)
+				.tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS256);
 	}
 
 	/**
@@ -106,6 +130,31 @@ public final class ClientSettings extends AbstractSettings {
 			return setting(ConfigurationSettingNames.Client.REQUIRE_AUTHORIZATION_CONSENT, requireAuthorizationConsent);
 		}
 
+		/**
+		 * Sets {@code URL} for the Client's JSON Web Key Set
+		 *
+		 * @param jwkSetUrl {@code URL} for the Client's JSON Web Key Set
+		 * @return the {@link Builder} for further configuration
+		 * @since 0.2.1
+		 */
+		public Builder jwkSetUrl(String jwkSetUrl) {
+			return setting(ConfigurationSettingNames.Client.JWK_SET_URL, jwkSetUrl);
+		}
+
+		/**
+		 * Sets {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate the
+		 * Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication methods
+		 *
+		 * @param signingAlgorithm {@link SignatureAlgorithm JWS} algorithm that must be used for signing
+		 *        the JWT used to authenticate the Client at the Token Endpoint for the {@code private_key_jwt} and
+		 *        {@code client_secret_jwt} authentication methods
+		 * @return the {@link Builder} for further configuration
+		 * @since 0.2.1
+		 */
+		public Builder tokenEndpointSigningAlgorithm(JwsAlgorithm signingAlgorithm) {
+			return setting(ConfigurationSettingNames.Client.TOKEN_ENDPOINT_SIGNING_ALGORITHM, signingAlgorithm);
+		}
+
 		/**
 		 * Builds the {@link ClientSettings}.
 		 *

+ 11 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/ConfigurationSettingNames.java

@@ -48,6 +48,17 @@ public final class ConfigurationSettingNames {
 		 */
 		public static final String REQUIRE_AUTHORIZATION_CONSENT = CLIENT_SETTINGS_NAMESPACE.concat("require-authorization-consent");
 
+		/**
+		 * {@code URL} for the Client's JSON Web Key Set {@code (jwks_uri)}
+		 */
+		public static final String JWK_SET_URL = CLIENT_SETTINGS_NAMESPACE.concat("jwk-set-url");
+
+		/**
+		 * {@link SignatureAlgorithm JWS} algorithm that must be used for signing the JWT used to authenticate the
+		 * Client at the Token Endpoint for the {@code private_key_jwt} and {@code client_secret_jwt} authentication methods
+		 */
+		public static final String TOKEN_ENDPOINT_SIGNING_ALGORITHM = CLIENT_SETTINGS_NAMESPACE.concat("token-endpoint-signing-algorithm");
+
 		private Client() {
 		}
 

+ 33 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/jackson2/MacAlgorithmMixin.java

@@ -0,0 +1,33 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.jackson2;
+
+import com.fasterxml.jackson.annotation.JsonAutoDetect;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
+
+/**
+ * This mixin class is used to serialize/deserialize {@link MacAlgorithm}.
+ *
+ * @author Rafal Lewczuk
+ * @since 0.2.1
+ * @see MacAlgorithm
+ */
+@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS)
+@JsonAutoDetect(fieldVisibility = JsonAutoDetect.Visibility.ANY, getterVisibility = JsonAutoDetect.Visibility.NONE,
+		isGetterVisibility = JsonAutoDetect.Visibility.NONE)
+public class MacAlgorithmMixin {
+}

+ 2 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/jackson2/OAuth2AuthorizationServerJackson2Module.java

@@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.module.SimpleModule;
 
 import org.springframework.security.jackson2.SecurityJackson2Modules;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 
 /**
@@ -76,6 +77,7 @@ public class OAuth2AuthorizationServerJackson2Module extends SimpleModule {
 		context.setMixInAnnotations(OAuth2AuthorizationRequest.class, OAuth2AuthorizationRequestMixin.class);
 		context.setMixInAnnotations(Duration.class, DurationMixin.class);
 		context.setMixInAnnotations(SignatureAlgorithm.class, SignatureAlgorithmMixin.class);
+		context.setMixInAnnotations(MacAlgorithm.class, MacAlgorithmMixin.class);
 	}
 
 }

+ 67 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java

@@ -41,6 +41,8 @@ import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcClientRegistration;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
@@ -63,6 +65,7 @@ import org.springframework.web.util.UriComponentsBuilder;
  *
  * @author Ovidiu Popa
  * @author Joe Grandja
+ * @author Rafal Lewczuk
  * @since 0.1.1
  * @see RegisteredClientRepository
  * @see OAuth2AuthorizationService
@@ -196,6 +199,12 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 			throw new OAuth2AuthenticationException("invalid_redirect_uri");
 		}
 
+		if (!isValidJwtClientAuthenticationMetadata(clientRegistrationAuthentication.getClientRegistration())) {
+			// TODO Add OAuth2ErrorCodes.INVALID_CLIENT_METADATA
+			// TODO populate "error_description"
+			throw new OAuth2AuthenticationException("invalid_client_metadata");
+		}
+
 		RegisteredClient registeredClient = createClient(clientRegistrationAuthentication.getClientRegistration());
 		this.registeredClientRepository.save(registeredClient);
 
@@ -283,6 +292,16 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 				.idTokenSignedResponseAlgorithm(registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm().getName())
 				.registrationClientUrl(registrationClientUri);
 
+		ClientSettings clientSettings = registeredClient.getClientSettings();
+
+		if (clientSettings.getJwkSetUrl() != null) {
+			builder.jwkSetUrl(clientSettings.getJwkSetUrl());
+		}
+
+		if (clientSettings.getTokenEndpointSigningAlgorithm() != null) {
+			builder.tokenEndpointAuthenticationSigningAlgorithm(clientSettings.getTokenEndpointSigningAlgorithm().getName());
+		}
+
 		return builder;
 		// @formatter:on
 	}
@@ -328,6 +347,31 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 		return true;
 	}
 
+	private static boolean isValidJwtClientAuthenticationMetadata(OidcClientRegistration clientRegistration) {
+		String authenticationMethod = clientRegistration.getTokenEndpointAuthenticationMethod();
+		String signingAlgorithm = clientRegistration.getTokenEndpointAuthenticationSigningAlgorithm();
+
+		if ("none".equals(signingAlgorithm)) {
+			return false;
+		}
+
+		if (ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue().equals(authenticationMethod)) {
+			return signingAlgorithm == null || MacAlgorithm.from(signingAlgorithm) != null;
+		}
+
+		if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue().equals(authenticationMethod)) {
+			try {
+				return clientRegistration.getJwkSetUrl() != null && (signingAlgorithm == null
+						|| SignatureAlgorithm.from(signingAlgorithm) != null);
+			} catch (IllegalArgumentException e) {
+				return false;
+			}
+		}
+
+		// TODO return false if token_endpoint_auth_signing_alg or jwks_uri exists but authentication method is not client_secret_jwt nor private_key_jwt ?
+		return true;
+	}
+
 	private static RegisteredClient createClient(OidcClientRegistration clientRegistration) {
 		// @formatter:off
 		RegisteredClient.Builder builder = RegisteredClient.withId(UUID.randomUUID().toString())
@@ -338,8 +382,12 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 
 		if (ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) {
 			builder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST);
+		} else if (ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) {
+			builder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT);
+		} else if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) {
+			builder.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT);
 		} else {
-			builder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC);
+				builder.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC);
 		}
 
 		builder.redirectUris(redirectUris ->
@@ -362,11 +410,25 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 					scopes.addAll(clientRegistration.getScopes()));
 		}
 
+		ClientSettings.Builder clientSettingsBuilder = ClientSettings.builder()
+				.requireProofKey(true)
+				.requireAuthorizationConsent(true);
+
+		String signatureAlgorithm = clientRegistration.getTokenEndpointAuthenticationSigningAlgorithm();
+
+		if (ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) {
+			JwsAlgorithm macAlgorithm = signatureAlgorithm != null ? MacAlgorithm.from(signatureAlgorithm) : MacAlgorithm.HS256;
+			clientSettingsBuilder.tokenEndpointSigningAlgorithm(macAlgorithm);
+		}
+
+		if (ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue().equals(clientRegistration.getTokenEndpointAuthenticationMethod())) {
+			JwsAlgorithm jwsAlgorithm = signatureAlgorithm != null ? SignatureAlgorithm.from(signatureAlgorithm) : SignatureAlgorithm.RS256;
+			clientSettingsBuilder.tokenEndpointSigningAlgorithm(jwsAlgorithm);
+			clientSettingsBuilder.jwkSetUrl(clientRegistration.getJwkSetUrl().toString());
+		}
+
 		builder
-				.clientSettings(ClientSettings.builder()
-						.requireProofKey(true)
-						.requireAuthorizationConsent(true)
-						.build())
+				.clientSettings(clientSettingsBuilder.build())
 				.tokenSettings(TokenSettings.builder()
 						.idTokenSignatureAlgorithm(SignatureAlgorithm.RS256)
 						.build());

+ 2 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilter.java

@@ -83,6 +83,8 @@ public final class OidcProviderConfigurationEndpointFilter extends OncePerReques
 				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue())
 				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue())
 				.jwkSetUrl(asUrl(this.providerSettings.getIssuer(), this.providerSettings.getJwkSetEndpoint()))
+				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue())
+				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue())
 				.responseType(OAuth2AuthorizationResponseType.CODE.getValue())
 				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
 				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())

+ 2 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationServerMetadataEndpointFilter.java

@@ -104,6 +104,8 @@ public final class OAuth2AuthorizationServerMetadataEndpointFilter extends OnceP
 		return (authenticationMethods) -> {
 			authenticationMethods.add(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue());
 			authenticationMethods.add(ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue());
+			authenticationMethods.add(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue());
+			authenticationMethods.add(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue());
 		};
 	}
 

+ 2 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2ClientAuthenticationFilter.java

@@ -43,6 +43,7 @@ import org.springframework.security.oauth2.server.authorization.web.authenticati
 import org.springframework.security.oauth2.server.authorization.web.authentication.ClientSecretPostAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.PublicClientAuthenticationConverter;
+import org.springframework.security.oauth2.server.authorization.web.authentication.JwtClientAssertionAuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
@@ -88,6 +89,7 @@ public final class OAuth2ClientAuthenticationFilter extends OncePerRequestFilter
 				Arrays.asList(
 						new ClientSecretBasicAuthenticationConverter(),
 						new ClientSecretPostAuthenticationConverter(),
+						new JwtClientAssertionAuthenticationConverter(),
 						new PublicClientAuthenticationConverter()));
 	}
 

+ 1 - 14
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretBasicAuthenticationConverter.java

@@ -18,9 +18,6 @@ package org.springframework.security.oauth2.server.authorization.web.authenticat
 import java.net.URLDecoder;
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
 
 import javax.servlet.http.HttpServletRequest;
 
@@ -92,16 +89,6 @@ public final class ClientSecretBasicAuthenticationConverter implements Authentic
 		}
 
 		return new OAuth2ClientAuthenticationToken(clientID, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, clientSecret,
-				extractAdditionalParameters(request));
+				OAuth2EndpointUtils.extractAdditionalParameters(request));
 	}
-
-	private static Map<String, Object> extractAdditionalParameters(HttpServletRequest request) {
-		Map<String, Object> additionalParameters = Collections.emptyMap();
-		if (OAuth2EndpointUtils.matchesAuthorizationCodeGrantRequest(request)) {
-			// Confidential clients can also leverage PKCE
-			additionalParameters = new HashMap<>(OAuth2EndpointUtils.getParameters(request).toSingleValueMap());
-		}
-		return additionalParameters;
-	}
-
 }

+ 4 - 14
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java

@@ -15,8 +15,6 @@
  */
 package org.springframework.security.oauth2.server.authorization.web.authentication;
 
-import java.util.Collections;
-import java.util.HashMap;
 import java.util.Map;
 
 import javax.servlet.http.HttpServletRequest;
@@ -71,18 +69,10 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
 			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
 		}
 
-		return new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_POST, clientSecret,
-				extractAdditionalParameters(request));
-	}
+		Map<String, Object> additionalParameters = OAuth2EndpointUtils.extractAdditionalParameters(request,
+				OAuth2ParameterNames.CLIENT_ID, OAuth2ParameterNames.CLIENT_SECRET);
 
-	private static Map<String, Object> extractAdditionalParameters(HttpServletRequest request) {
-		Map<String, Object> additionalParameters = Collections.emptyMap();
-		if (OAuth2EndpointUtils.matchesAuthorizationCodeGrantRequest(request)) {
-			// Confidential clients can also leverage PKCE
-			additionalParameters = new HashMap<>(OAuth2EndpointUtils.getParameters(request).toSingleValueMap());
-			additionalParameters.remove(OAuth2ParameterNames.CLIENT_ID);
-			additionalParameters.remove(OAuth2ParameterNames.CLIENT_SECRET);
-		}
-		return additionalParameters;
+		return new OAuth2ClientAuthenticationToken(clientId, ClientAuthenticationMethod.CLIENT_SECRET_POST, clientSecret,
+				additionalParameters);
 	}
 }

+ 92 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java

@@ -0,0 +1,92 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.server.authorization.web.authentication;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientAuthenticationFilter;
+import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.util.MultiValueMap;
+import org.springframework.util.StringUtils;
+
+import javax.servlet.http.HttpServletRequest;
+import java.util.Map;
+
+/**
+ * Attempts to extract client assertion credentials from {@link HttpServletRequest}
+ * and then converts to an {@link OAuth2ClientAuthenticationToken} used for authenticating the client.
+ *
+ * @author Rafal Lewczuk
+ * @since 0.2.1
+ * @see AuthenticationConverter
+ * @see OAuth2ClientAuthenticationToken
+ * @see OAuth2ClientAuthenticationFilter
+ */
+public final class JwtClientAssertionAuthenticationConverter implements AuthenticationConverter {
+
+	private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD
+			= new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
+
+	@Nullable
+	@Override
+	public Authentication convert(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+
+		// client_assertion_type (REQUIRED), client_assertion (REQUIRED)
+		String clientJwtAssertionType = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE);
+		String clientJwtAssertion = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
+
+		if (!StringUtils.hasText(clientJwtAssertionType) || !StringUtils.hasText(clientJwtAssertion)) {
+			return null;
+		}
+
+		if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).size() != 1) {
+			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
+		}
+
+		if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION).size() != 1) {
+			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
+		}
+
+		if (!JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD.getValue().equals(clientJwtAssertionType)) {
+			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
+		}
+
+		// client_id (OPTIONAL as per specification but REQUIRED by this implementation)
+		String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
+		if (!StringUtils.hasText(clientId)) {
+			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
+		}
+
+		if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
+			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
+		}
+
+		Map<String, Object> additionalParameters = OAuth2EndpointUtils.extractAdditionalParameters(request,
+				OAuth2ParameterNames.CLIENT_ID,
+				OAuth2ParameterNames.CLIENT_ASSERTION_TYPE,
+				OAuth2ParameterNames.CLIENT_ASSERTION);
+
+		return new OAuth2ClientAuthenticationToken(request.getRequestURI(), clientId,
+				JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, clientJwtAssertion, additionalParameters);
+	}
+}

+ 12 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java

@@ -15,6 +15,8 @@
  */
 package org.springframework.security.oauth2.server.authorization.web.authentication;
 
+import java.util.Collections;
+import java.util.HashMap;
 import java.util.Map;
 
 import javax.servlet.http.HttpServletRequest;
@@ -68,4 +70,14 @@ final class OAuth2EndpointUtils {
 		throw new OAuth2AuthenticationException(error);
 	}
 
+	static Map<String, Object> extractAdditionalParameters(HttpServletRequest request, String...exclusions) {
+		Map<String, Object> additionalParameters = Collections.emptyMap();
+		if (OAuth2EndpointUtils.matchesAuthorizationCodeGrantRequest(request)) {
+			additionalParameters = new HashMap<>(OAuth2EndpointUtils.getParameters(request).toSingleValueMap());
+			for (String exclusion : exclusions) {
+				additionalParameters.remove(exclusion);
+			}
+		}
+		return additionalParameters;
+	}
 }

+ 187 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationTests.java

@@ -15,19 +15,27 @@
  */
 package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
 
-import java.net.URLEncoder;
-import java.nio.charset.StandardCharsets;
-import java.util.Base64;
-
+import com.nimbusds.jose.JOSEException;
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.JWSHeader;
+import com.nimbusds.jose.JWSSigner;
+import com.nimbusds.jose.crypto.MACSigner;
+import com.nimbusds.jose.crypto.RSASSASigner;
+import com.nimbusds.jose.jwk.JWK;
 import com.nimbusds.jose.jwk.JWKSet;
+import com.nimbusds.jose.jwk.KeyUse;
+import com.nimbusds.jose.jwk.RSAKey;
 import com.nimbusds.jose.jwk.source.JWKSource;
 import com.nimbusds.jose.proc.SecurityContext;
+import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.SignedJWT;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
 import org.junit.After;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
-
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.http.HttpHeaders;
@@ -48,11 +56,13 @@ import org.springframework.security.config.annotation.web.configuration.EnableWe
 import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.config.annotation.web.configurers.oauth2.server.resource.OAuth2ResourceServerConfigurer;
 import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.config.util.ValueCaptureMatcher;
 import org.springframework.security.crypto.password.NoOpPasswordEncoder;
 import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -75,6 +85,15 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
 
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.interfaces.RSAPublicKey;
+import java.util.Base64;
+import java.util.Date;
+import java.util.UUID;
+
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
@@ -90,8 +109,12 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  * @author Joe Grandja
  */
 public class OidcClientRegistrationTests {
+	private static final String DEFAULT_ISSUER = "https://auth-server:9000";
 	private static final String DEFAULT_TOKEN_ENDPOINT_URI = "/oauth2/token";
+	private static final String DEFAULT_INTROSPECTION_ENDPOINT_URI = "/oauth2/introspect";
+	private static final String DEFAULT_REVOCATION_ENDPOINT_URI = "/oauth2/revoke";
 	private static final String DEFAULT_OIDC_CLIENT_REGISTRATION_ENDPOINT_URI = "/connect/register";
+	private static final String JWT_ASSERTION_TYPE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
 	private static final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
 			new OAuth2AccessTokenResponseHttpMessageConverter();
 	private static final HttpMessageConverter<OidcClientRegistration> clientRegistrationHttpMessageConverter =
@@ -224,6 +247,139 @@ public class OidcClientRegistrationTests {
 		assertThat(clientConfigurationResponse.getRegistrationAccessToken()).isNull();
 	}
 
+	@Test
+	public void whenClientRegisterationWithClientSecretJwtAuthenticationThenJwtClientAuthenticationSuccess() throws Exception {
+		this.spring.register(AuthorizationServerConfiguration.class).autowire();
+
+		// @formatter:off
+		OidcClientRegistration clientRegistration = OidcClientRegistration.builder()
+				.clientName("client-name")
+				.redirectUri("https://client.example.com")
+				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
+				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+				.tokenEndpointAuthenticationSigningAlgorithm("HS256")
+				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue())
+				.scope("scope1")
+				.scope("scope2")
+				.build();
+		// @formatter:on
+
+		OidcClientRegistration clientRegistrationResponse = registerClient(clientRegistration);
+		ValueCaptureMatcher<String> accessTokenCapture = new ValueCaptureMatcher<>();
+
+		// token creation with JWT assertion
+		String clientJwtAssertion = clientSecretJwtAssertion(clientRegistrationResponse, DEFAULT_TOKEN_ENDPOINT_URI);
+		this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+				.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+				.param(OAuth2ParameterNames.SCOPE, "scope1")
+				.param(OAuth2ParameterNames.CLIENT_ID, clientRegistrationResponse.getClientId())
+				.param(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_ASSERTION_TYPE)
+				.param(OAuth2ParameterNames.CLIENT_ASSERTION, clientJwtAssertion))
+				.andExpect(status().isOk())
+				.andExpect(jsonPath("$.access_token").isNotEmpty())
+				.andExpect(jsonPath("$.access_token").value(accessTokenCapture))
+				.andExpect(jsonPath("$.scope").value("scope1"));
+
+		String accessToken = accessTokenCapture.lastValue();
+
+		// token introspection with JWT assertion
+		clientJwtAssertion = clientSecretJwtAssertion(clientRegistrationResponse, DEFAULT_INTROSPECTION_ENDPOINT_URI);
+		this.mvc.perform(post(DEFAULT_INTROSPECTION_ENDPOINT_URI)
+				.param(OAuth2ParameterNames.CLIENT_ID, clientRegistrationResponse.getClientId())
+				.param(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_ASSERTION_TYPE)
+				.param(OAuth2ParameterNames.CLIENT_ASSERTION, clientJwtAssertion)
+				.param(OAuth2ParameterNames.TOKEN, accessToken)
+				.param(OAuth2ParameterNames.TOKEN_TYPE_HINT, OAuth2TokenType.ACCESS_TOKEN.getValue()))
+				.andExpect(status().isOk());
+
+		// token revocation with JWT assertion
+		clientJwtAssertion = clientSecretJwtAssertion(clientRegistrationResponse, DEFAULT_REVOCATION_ENDPOINT_URI);
+		this.mvc.perform(post(DEFAULT_REVOCATION_ENDPOINT_URI)
+				.param(OAuth2ParameterNames.CLIENT_ID, clientRegistrationResponse.getClientId())
+				.param(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_ASSERTION_TYPE)
+				.param(OAuth2ParameterNames.CLIENT_ASSERTION, clientJwtAssertion)
+				.param(OAuth2ParameterNames.TOKEN, accessToken)
+				.param(OAuth2ParameterNames.TOKEN_TYPE_HINT, OAuth2TokenType.ACCESS_TOKEN.getValue()))
+				.andExpect(status().isOk());
+	}
+
+	@Test
+	public void whenClientRegistrationWithPrivateKeyJwtAuthenticationThenJwtClientAuthenticationSuccess() throws Exception {
+		this.spring.register(AuthorizationServerConfiguration.class).autowire();
+
+		KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA");
+		gen.initialize(2048);
+		KeyPair keyPair = gen.generateKeyPair();
+
+		JWK jwk = new RSAKey.Builder((RSAPublicKey) keyPair.getPublic())
+				.keyUse(KeyUse.SIGNATURE)
+				.keyID(UUID.randomUUID().toString())
+				.build();
+
+		String jwks = "{\"keys\":[" + jwk.toJSONString() + "]}";
+
+		try (MockWebServer server = new MockWebServer()) {
+			String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
+
+			// @formatter:off
+			OidcClientRegistration clientRegistration = OidcClientRegistration.builder()
+					.clientName("client-name")
+					.redirectUri("https://client.example.com")
+					.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
+					.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+					.tokenEndpointAuthenticationSigningAlgorithm("RS256")
+					.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue())
+					.jwkSetUrl(jwkSetUrl)
+					.scope("scope1")
+					.scope("scope2")
+					.build();
+			// @formatter:on
+
+			OidcClientRegistration clientRegistrationResponse = registerClient(clientRegistration);
+			ValueCaptureMatcher<String> accessTokenCapture = new ValueCaptureMatcher<>();
+
+			// token creation with JWT assertion
+			String clientJwtAssertion = privateKeyJwtAssertion(keyPair, clientRegistrationResponse, DEFAULT_TOKEN_ENDPOINT_URI);
+			server.enqueue(new MockResponse().setBody(jwks));
+			this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+					.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+					.param(OAuth2ParameterNames.SCOPE, "scope1")
+					.param(OAuth2ParameterNames.CLIENT_ID, clientRegistrationResponse.getClientId())
+					.param(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_ASSERTION_TYPE)
+					.param(OAuth2ParameterNames.CLIENT_ASSERTION, clientJwtAssertion))
+					.andExpect(status().isOk())
+					.andExpect(jsonPath("$.access_token").isNotEmpty())
+					.andExpect(jsonPath("$.access_token").value(accessTokenCapture))
+					.andExpect(jsonPath("$.scope").value("scope1"));
+
+			String accessToken = accessTokenCapture.lastValue();
+
+				// token introspection with JWT assertion
+				clientJwtAssertion = privateKeyJwtAssertion(keyPair, clientRegistrationResponse, DEFAULT_INTROSPECTION_ENDPOINT_URI);
+				server.enqueue(new MockResponse().setBody(jwks));
+				this.mvc.perform(post(DEFAULT_INTROSPECTION_ENDPOINT_URI)
+						.param(OAuth2ParameterNames.CLIENT_ID, clientRegistrationResponse.getClientId())
+						.param(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_ASSERTION_TYPE)
+						.param(OAuth2ParameterNames.CLIENT_ASSERTION, clientJwtAssertion)
+						.param(OAuth2ParameterNames.TOKEN, accessToken)
+						.param(OAuth2ParameterNames.TOKEN_TYPE_HINT, OAuth2TokenType.ACCESS_TOKEN.getValue()))
+						.andExpect(status().isOk());
+
+			// token revocation with JWT assertion
+			clientJwtAssertion = privateKeyJwtAssertion(keyPair, clientRegistrationResponse, DEFAULT_REVOCATION_ENDPOINT_URI);
+			server.enqueue(new MockResponse().setBody(jwks));
+			this.mvc.perform(post(DEFAULT_REVOCATION_ENDPOINT_URI)
+					.param(OAuth2ParameterNames.CLIENT_ID, clientRegistrationResponse.getClientId())
+					.param(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_ASSERTION_TYPE)
+					.param(OAuth2ParameterNames.CLIENT_ASSERTION, clientJwtAssertion)
+					.param(OAuth2ParameterNames.TOKEN, accessToken)
+					.param(OAuth2ParameterNames.TOKEN_TYPE_HINT, OAuth2TokenType.ACCESS_TOKEN.getValue()))
+					.andExpect(status().isOk());
+
+			server.shutdown();
+		}
+	}
+
 	private OidcClientRegistration registerClient(OidcClientRegistration clientRegistration) throws Exception {
 		// ***** (1) Obtain the "initial" access token used for registering the client
 
@@ -289,6 +445,31 @@ public class OidcClientRegistrationTests {
 		return clientRegistrationHttpMessageConverter.read(OidcClientRegistration.class, httpResponse);
 	}
 
+	private JWTClaimsSet jwtClientAuthenticationClaims(OidcClientRegistration clientRegistration, String endpointUri) {
+		return  new JWTClaimsSet.Builder()
+				.subject(clientRegistration.getClientId())
+				.issuer(clientRegistration.getClientId())
+				.expirationTime(new Date(new Date().getTime() + 60000))
+				.audience(DEFAULT_ISSUER + endpointUri)
+				.build();
+	}
+
+	private String clientSecretJwtAssertion(OidcClientRegistration clientRegistration, String endpointUri) throws JOSEException {
+		JWTClaimsSet claimsSet = jwtClientAuthenticationClaims(clientRegistration, endpointUri);
+		SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.HS256), claimsSet);
+		JWSSigner signer = new MACSigner(clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8));
+		signedJWT.sign(signer);
+		return signedJWT.serialize();
+	}
+
+	private String privateKeyJwtAssertion(KeyPair keyPair, OidcClientRegistration clientRegistration, String endpointUri) throws JOSEException {
+		JWTClaimsSet claimsSet = jwtClientAuthenticationClaims(clientRegistration, endpointUri);
+		SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.RS256), claimsSet);
+		JWSSigner signer = new RSASSASigner(keyPair.getPrivate());
+		signedJWT.sign(signer);
+		return signedJWT.serialize();
+	}
+
 	@EnableWebSecurity
 	static class AuthorizationServerConfiguration {
 
@@ -348,7 +529,7 @@ public class OidcClientRegistrationTests {
 		@Bean
 		ProviderSettings providerSettings() {
 			return ProviderSettings.builder()
-					.issuer("https://auth-server:9000")
+					.issuer(DEFAULT_ISSUER)
 					.build();
 		}
 

+ 65 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/util/ValueCaptureMatcher.java

@@ -0,0 +1,65 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.util;
+
+import org.assertj.core.util.Throwables;
+import org.hamcrest.BaseMatcher;
+import org.hamcrest.Description;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Hamcrest matcher that records matched values
+ *
+ * @author Rafal Lewczuk
+ * @since 0.2.1
+ * @param <T>
+ */
+public class ValueCaptureMatcher<T> extends BaseMatcher<T> {
+
+	private ClassCastException castException;
+	private List<T> values = new ArrayList<>();
+
+	public T lastValue() {
+		return values.isEmpty() ? null : values.get(values.size()-1);
+	}
+
+	public List<T> getValues() {
+		return values;
+	}
+
+	@Override
+	public boolean matches(Object item) {
+		try {
+			values.add((T) item);
+		} catch (ClassCastException e) {
+			castException = e;
+			return false;
+		}
+		return true;
+	}
+
+	@Override
+	public void describeTo(Description description) {
+		if (castException != null) {
+			description.appendText("ClassCastException with message: ");
+			description.appendText(castException.getMessage());
+			description.appendText(String.format("%n%nStacktrace was: "));
+			description.appendText(Throwables.getStackTrace(castException));
+		}
+	}
+}

+ 9 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/OidcClientRegistrationTests.java

@@ -26,6 +26,7 @@ import org.junit.Test;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -46,6 +47,7 @@ public class OidcClientRegistrationTests {
 
 	@Test
 	public void buildWhenAllClaimsProvidedThenCreated() {
+
 		// @formatter:off
 		Instant clientIdIssuedAt = Instant.now();
 		Instant clientSecretExpiresAt = clientIdIssuedAt.plus(30, ChronoUnit.DAYS);
@@ -56,7 +58,9 @@ public class OidcClientRegistrationTests {
 				.clientSecretExpiresAt(clientSecretExpiresAt)
 				.clientName("client-name")
 				.redirectUri("https://client.example.com")
+				.jwkSetUrl("https://client.example.com/jwks")
 				.tokenEndpointAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue())
+				.tokenEndpointAuthenticationSigningAlgorithm(MacAlgorithm.HS256.getName())
 				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
 				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
 				.responseType(OAuth2AuthorizationResponseType.CODE.getValue())
@@ -74,6 +78,7 @@ public class OidcClientRegistrationTests {
 		assertThat(clientRegistration.getClientSecret()).isEqualTo("client-secret");
 		assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt);
 		assertThat(clientRegistration.getClientName()).isEqualTo("client-name");
+		assertThat(clientRegistration.getJwkSetUrl().toString()).isEqualTo("https://client.example.com/jwks");
 		assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com");
 		assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue());
 		assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials");
@@ -102,7 +107,9 @@ public class OidcClientRegistrationTests {
 		claims.put(OidcClientMetadataClaimNames.CLIENT_SECRET_EXPIRES_AT, clientSecretExpiresAt);
 		claims.put(OidcClientMetadataClaimNames.CLIENT_NAME, "client-name");
 		claims.put(OidcClientMetadataClaimNames.REDIRECT_URIS, Collections.singletonList("https://client.example.com"));
+		claims.put(OidcClientMetadataClaimNames.JWKS_URI, "https://client.example.com/jwks");
 		claims.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_METHOD, ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue());
+		claims.put(OidcClientMetadataClaimNames.TOKEN_ENDPOINT_AUTH_SIGNING_ALG, MacAlgorithm.HS256.getName());
 		claims.put(OidcClientMetadataClaimNames.GRANT_TYPES, Arrays.asList(
 				AuthorizationGrantType.AUTHORIZATION_CODE.getValue(), AuthorizationGrantType.CLIENT_CREDENTIALS.getValue()));
 		claims.put(OidcClientMetadataClaimNames.RESPONSE_TYPES, Collections.singletonList("code"));
@@ -120,7 +127,9 @@ public class OidcClientRegistrationTests {
 		assertThat(clientRegistration.getClientSecretExpiresAt()).isEqualTo(clientSecretExpiresAt);
 		assertThat(clientRegistration.getClientName()).isEqualTo("client-name");
 		assertThat(clientRegistration.getRedirectUris()).containsOnly("https://client.example.com");
+		assertThat(clientRegistration.getJwkSetUrl().toString()).isEqualTo("https://client.example.com/jwks");
 		assertThat(clientRegistration.getTokenEndpointAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.CLIENT_SECRET_BASIC.getValue());
+		assertThat(clientRegistration.getTokenEndpointAuthenticationSigningAlgorithm()).isEqualTo(MacAlgorithm.HS256.getName());
 		assertThat(clientRegistration.getGrantTypes()).containsExactlyInAnyOrder("authorization_code", "client_credentials");
 		assertThat(clientRegistration.getResponseTypes()).containsOnly("code");
 		assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2");

+ 8 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/core/oidc/http/converter/OidcClientRegistrationHttpMessageConverterTests.java

@@ -107,6 +107,8 @@ public class OidcClientRegistrationHttpMessageConverterTests {
 				+"		],\n"
 				+"		\"scope\": \"scope1 scope2\",\n"
 				+"		\"id_token_signed_response_alg\": \"RS256\",\n"
+				+"      \"jwks_uri\": \"https://client.example.com/jwks\",\n"
+				+"      \"token_endpoint_auth_signing_alg\": \"HS256\",\n"
 				+"		\"a-claim\": \"a-value\"\n"
 				+"}\n";
 		// @formatter:on
@@ -126,6 +128,8 @@ public class OidcClientRegistrationHttpMessageConverterTests {
 		assertThat(clientRegistration.getResponseTypes()).containsOnly("code");
 		assertThat(clientRegistration.getScopes()).containsExactlyInAnyOrder("scope1", "scope2");
 		assertThat(clientRegistration.getIdTokenSignedResponseAlgorithm()).isEqualTo("RS256");
+		assertThat(clientRegistration.getJwkSetUrl().toString()).isEqualTo("https://client.example.com/jwks");
+		assertThat(clientRegistration.getTokenEndpointAuthenticationSigningAlgorithm()).isEqualTo("HS256");
 		assertThat(clientRegistration.getClaimAsString("a-claim")).isEqualTo("a-value");
 	}
 
@@ -186,6 +190,8 @@ public class OidcClientRegistrationHttpMessageConverterTests {
 				.idTokenSignedResponseAlgorithm(SignatureAlgorithm.RS256.getName())
 				.registrationAccessToken("registration-access-token")
 				.registrationClientUrl("https://auth-server.com/connect/register?client_id=1")
+				.jwkSetUrl("https://client.example.com/jwks")
+				.tokenEndpointAuthenticationSigningAlgorithm("HS256")
 				.claim("a-claim", "a-value")
 				.build();
 		// @formatter:on
@@ -207,6 +213,8 @@ public class OidcClientRegistrationHttpMessageConverterTests {
 		assertThat(clientRegistrationResponse).contains("\"id_token_signed_response_alg\":\"RS256\"");
 		assertThat(clientRegistrationResponse).contains("\"registration_access_token\":\"registration-access-token\"");
 		assertThat(clientRegistrationResponse).contains("\"registration_client_uri\":\"https://auth-server.com/connect/register?client_id=1\"");
+		assertThat(clientRegistrationResponse).contains("\"jwks_uri\":\"https://client.example.com/jwks\"");
+		assertThat(clientRegistrationResponse).contains("\"token_endpoint_auth_signing_alg\":\"HS256\"");
 		assertThat(clientRegistrationResponse).contains("\"a-claim\":\"a-value\"");
 	}
 

+ 164 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationProviderTests.java

@@ -15,6 +15,7 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.time.Instant;
 import java.util.HashMap;
 import java.util.Map;
 
@@ -30,18 +31,23 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
+import org.springframework.test.util.ReflectionTestUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -53,6 +59,7 @@ import static org.mockito.Mockito.when;
  * @author Joe Grandja
  * @author Daniel Garnier-Moiroux
  * @author Anoop Garlapati
+ * @author Rafal Lewczuk
  */
 public class OAuth2ClientAuthenticationProviderTests {
 	private static final String PLAIN_CODE_VERIFIER = "pkce-key";
@@ -66,17 +73,20 @@ public class OAuth2ClientAuthenticationProviderTests {
 	private static final String AUTHORIZATION_CODE = "code";
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
 
+	private static final ClientAuthenticationMethod JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD =
+			new ClientAuthenticationMethod("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
+
 	private RegisteredClientRepository registeredClientRepository;
 	private OAuth2AuthorizationService authorizationService;
 	private OAuth2ClientAuthenticationProvider authenticationProvider;
 	private PasswordEncoder passwordEncoder;
+	private JwtDecoderFactory<RegisteredClient> jwtDecoderFactory;
 
 	@Before
 	public void setUp() {
 		this.registeredClientRepository = mock(RegisteredClientRepository.class);
 		this.authorizationService = mock(OAuth2AuthorizationService.class);
-		this.authenticationProvider = new OAuth2ClientAuthenticationProvider(
-				this.registeredClientRepository, this.authorizationService);
+		this.authenticationProvider = new OAuth2ClientAuthenticationProvider(this.registeredClientRepository, this.authorizationService);
 		this.passwordEncoder = spy(new PasswordEncoder() {
 			@Override
 			public String encode(CharSequence rawPassword) {
@@ -89,6 +99,9 @@ public class OAuth2ClientAuthenticationProviderTests {
 			}
 		});
 		this.authenticationProvider.setPasswordEncoder(this.passwordEncoder);
+		this.authenticationProvider.setProviderSettings(ProviderSettings.builder().issuer("https://auth-server.com").build());
+		this.jwtDecoderFactory = mock(JwtDecoderFactory.class);
+		ReflectionTestUtils.setField(this.authenticationProvider, "jwtDecoderFactory", this.jwtDecoderFactory);
 	}
 
 	@Test
@@ -207,6 +220,146 @@ public class OAuth2ClientAuthenticationProviderTests {
 		assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient);
 	}
 
+	@Test
+	public void authenticateWhenJwtBearerAndClientNotSupportingItThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
+				registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD,
+				registeredClient.getClientSecret(), null);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+
+		verify(this.jwtDecoderFactory, never()).createDecoder(any());
+	}
+
+	@Test
+	public void authenticateWhenClientJwtAssertionAndPrivateJwtAndFailedCreateDecoderThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		when(this.jwtDecoderFactory.createDecoder(any()))
+				.thenThrow(new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT));
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
+				"https://auth-server.com", registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD,
+				registeredClient.getClientSecret(), null);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+	}
+
+	@Test
+	public void authenticateWhenClientJwtAssertionAndPrivateKeyJwtAndFailedVerifyTokenThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		when(this.jwtDecoderFactory.createDecoder(any()))
+				.thenReturn(s -> { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT); });
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
+				"https://auth-server.com", registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD,
+				registeredClient.getClientSecret(), null);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+	}
+
+	@Test
+	public void authenticateWhenClientJwtAssertionAndBadAudienceThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT).build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		when(this.jwtDecoderFactory.createDecoder(any()))
+				.thenReturn(s -> createJwtToken("client-1", "https://bad-server.com/oauth2/token"));
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
+				"/oauth2/token", registeredClient.getClientId(),
+				JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, registeredClient.getClientSecret(), null);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+	}
+
+	@Test
+	public void authenticateWhenClientJwtAssertionAndPrivateJwtVerificationSuccessThenAuthenticate() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT).build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		when(this.jwtDecoderFactory.createDecoder(any()))
+				.thenReturn(s -> createJwtToken("client-1", "https://auth-server.com/oauth2/token"));
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
+				"https://auth-server.com/oauth2/token", registeredClient.getClientId(),
+				JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD, registeredClient.getClientSecret(), null);
+
+		OAuth2ClientAuthenticationToken authenticationResult =
+				(OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		assertThat(authenticationResult.isAuthenticated()).isTrue();
+		assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId());
+		assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret());
+		assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient);
+	}
+
+	@Test
+	public void authenticateWhenClientJwtAssertionAndClientSecretJwtAndFailedVerifyTokenThenThrowOAuth2Exception() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		when(this.jwtDecoderFactory.createDecoder(any()))
+				.thenReturn(s -> { throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT); });
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
+				"https://auth-server.com", registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD,
+				registeredClient.getClientSecret(), null);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+	}
+
+	@Test
+	public void authenticateWhenClientJwtAssertionAndClientSecretJwtVerificationSuccess() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT).build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+		when(this.jwtDecoderFactory.createDecoder(any()))
+				.thenReturn(s -> createJwtToken("client-1", "https://auth-server.com/oauth2/token"));
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken(
+				"https://auth-server.com/oauth2/token", registeredClient.getClientId(), JWT_CLIENT_ASSERTION_AUTHENTICATION_METHOD,
+				registeredClient.getClientSecret(), null);
+
+		OAuth2ClientAuthenticationToken authenticationResult =
+				(OAuth2ClientAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		assertThat(authenticationResult.isAuthenticated()).isTrue();
+		assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(registeredClient.getClientId());
+		assertThat(authenticationResult.getCredentials().toString()).isEqualTo(registeredClient.getClientSecret());
+		assertThat(authenticationResult.getRegisteredClient()).isEqualTo(registeredClient);
+	}
+
 	@Test
 	public void authenticateWhenPkceAndInvalidCodeThenThrowOAuth2AuthenticationException() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
@@ -485,4 +638,13 @@ public class OAuth2ClientAuthenticationProviderTests {
 		parameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
 		return parameters;
 	}
+
+	private static Jwt createJwtToken(String subject, String audience) {
+		Map<String, Object> headers = new HashMap<>();
+		headers.put("kid", "123");
+		Map<String, Object> claims = new HashMap<>();
+		claims.put("sub", subject);
+		claims.put("aud", audience);
+		return new Jwt("123", Instant.now().minusSeconds(30), Instant.now().plusSeconds(30), headers, claims);
+	}
 }

+ 11 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientAuthenticationTokenTests.java

@@ -32,16 +32,23 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
  */
 public class OAuth2ClientAuthenticationTokenTests {
 
+	@Test
+	public void constructorWhenRequestUriNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken(null, "clientId", ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("requestUri cannot be empty");
+	}
+
 	@Test
 	public void constructorWhenClientIdNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken(null, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null))
+		assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken("issuer", null, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("clientId cannot be empty");
 	}
 
 	@Test
 	public void constructorWhenClientAuthenticationMethodNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken("clientId", null, "clientSecret", null))
+		assertThatThrownBy(() -> new OAuth2ClientAuthenticationToken("issuer", "clientId", null, "clientSecret", null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("clientAuthenticationMethod cannot be null");
 	}
@@ -55,9 +62,10 @@ public class OAuth2ClientAuthenticationTokenTests {
 
 	@Test
 	public void constructorWhenClientCredentialsProvidedThenCreated() {
-		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken("clientId",
+		OAuth2ClientAuthenticationToken authentication = new OAuth2ClientAuthenticationToken("issuer", "clientId",
 				ClientAuthenticationMethod.CLIENT_SECRET_BASIC, "secret", null);
 		assertThat(authentication.isAuthenticated()).isFalse();
+		assertThat(authentication.getRequestUri()).isEqualTo("issuer");
 		assertThat(authentication.getPrincipal().toString()).isEqualTo("clientId");
 		assertThat(authentication.getCredentials()).isEqualTo("secret");
 		assertThat(authentication.getRegisteredClient()).isNull();

+ 436 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/RegisteredClientJwtAssertionDecoderFactoryTests.java

@@ -0,0 +1,436 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import com.nimbusds.jose.JOSEException;
+import com.nimbusds.jose.JWSAlgorithm;
+import com.nimbusds.jose.JWSHeader;
+import com.nimbusds.jose.JWSSigner;
+import com.nimbusds.jose.crypto.MACSigner;
+import com.nimbusds.jose.crypto.RSASSASigner;
+import com.nimbusds.jose.jwk.JWK;
+import com.nimbusds.jose.jwk.KeyUse;
+import com.nimbusds.jose.jwk.RSAKey;
+import com.nimbusds.jwt.JWTClaimsSet;
+import com.nimbusds.jwt.SignedJWT;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
+import org.springframework.security.oauth2.jwt.JwtException;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
+import org.springframework.test.util.ReflectionTestUtils;
+
+import java.nio.charset.StandardCharsets;
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.interfaces.RSAPublicKey;
+import java.time.Instant;
+import java.util.Date;
+import java.util.UUID;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link OAuth2ClientAuthenticationProvider.RegisteredClientJwtAssertionDecoderFactory}
+ *
+ * @author Rafal Lewczuk
+ */
+public class RegisteredClientJwtAssertionDecoderFactoryTests {
+
+	private JwtDecoderFactory<RegisteredClient> registeredClientDecoderFactory;
+
+	@Before
+	public void setUp() {
+		OAuth2ClientAuthenticationProvider authenticationProvider = new OAuth2ClientAuthenticationProvider(
+				mock(RegisteredClientRepository.class), mock(OAuth2AuthorizationService.class));
+		this.registeredClientDecoderFactory = (JwtDecoderFactory<RegisteredClient>)
+				ReflectionTestUtils.getField(authenticationProvider, "jwtDecoderFactory");
+	}
+
+	@Test
+	public void createDecoderWhenRegisteredClientNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> registeredClientDecoderFactory.createDecoder(null))
+				.withMessage("registeredClient cannot be null");
+	}
+
+	@Test
+	public void createDecoderWhenClientAuthenticationMethodNotSupportedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+		assertThatThrownBy(() -> this.registeredClientDecoderFactory.createDecoder(registeredClient))
+				.isInstanceOf(OAuth2AuthenticationException.class);
+	}
+
+	@Test
+	public void createDecoderWithClientSecretJwtWhenClientSecretNullThenThrowOAuth2Exception() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientSecret(null)
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.clientSettings(ClientSettings.builder().tokenEndpointSigningAlgorithm(MacAlgorithm.HS256).build())
+				.build();
+
+		assertThatThrownBy(() -> this.registeredClientDecoderFactory.createDecoder(registeredClient))
+				.isInstanceOf(OAuth2AuthenticationException.class);
+	}
+
+	@Test
+	public void createDecoderWithClientSecretJwtClientThenReturnDecoder() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.clientSecret("0123456789abcdef0123456789ABCDEF")
+				.clientSettings(ClientSettings.builder().tokenEndpointSigningAlgorithm(MacAlgorithm.HS256).build())
+				.build();
+
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+
+		assertThat(jwtDecoder).isNotNull();
+	}
+
+	@Test
+	public void createDecoderWithClientSecretJwtTwiceThenReturnCachedDecoder() {
+		RegisteredClient.Builder registeredClientBuilder = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.clientSecret("0123456789abcdef0123456789ABCDEF")
+				.clientSettings(ClientSettings.builder().tokenEndpointSigningAlgorithm(MacAlgorithm.HS256).build());
+
+		JwtDecoder decoder1 = this.registeredClientDecoderFactory.createDecoder(registeredClientBuilder.build());
+		JwtDecoder decoder2 = this.registeredClientDecoderFactory.createDecoder(registeredClientBuilder.build());
+
+		assertThat(decoder1).isNotNull();
+		assertThat(decoder2).isSameAs(decoder1);
+	}
+
+	@Test
+	public void createDecoderWithClientSecretJwtAndSecondWithChangedAlgorithmThenReturnRecreatedDecoder() {
+		RegisteredClient.Builder registeredClientBuilder = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.clientSecret("0123456789abcdef0123456789ABCDEF");
+		RegisteredClient registeredClient1 = registeredClientBuilder.clientSettings(
+				ClientSettings.builder().tokenEndpointSigningAlgorithm(MacAlgorithm.HS256).build()).build();
+		RegisteredClient registeredClient = registeredClientBuilder.clientSettings(
+				ClientSettings.builder().tokenEndpointSigningAlgorithm(MacAlgorithm.HS512).build()).build();
+
+		JwtDecoder decoder1 = this.registeredClientDecoderFactory.createDecoder(registeredClient1);
+		JwtDecoder decoder2 = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+
+		assertThat(decoder1).isNotNull();
+		assertThat(decoder2).isNotNull();
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
+	@Test
+	public void createDecoderWithClientSecretJwtAndSecondWithChangedSecretThenReturnRecreatedDecoder() {
+		RegisteredClient.Builder registeredClientBuilder = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.clientSettings(ClientSettings.builder().tokenEndpointSigningAlgorithm(MacAlgorithm.HS256).build());
+		RegisteredClient registeredClient1 = registeredClientBuilder.clientSecret("0123456789abcdef0123456789ABCDEF").build();
+		RegisteredClient registeredClient2 = registeredClientBuilder.clientSecret("0123456789ABCDEF0123456789abcdef").build();
+
+		JwtDecoder decoder1 = this.registeredClientDecoderFactory.createDecoder(registeredClient1);
+		JwtDecoder decoder2 = this.registeredClientDecoderFactory.createDecoder(registeredClient2);
+
+		assertThat(decoder1).isNotNull();
+		assertThat(decoder2).isNotNull();
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
+	@Test
+	public void createDecoderWithPrivateKeyJwtMissingJwksUrlThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.clientSettings(ClientSettings.builder()
+						.tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS256).build())
+				.build();
+
+		assertThatThrownBy(() -> this.registeredClientDecoderFactory.createDecoder(registeredClient))
+				.isInstanceOf(OAuth2AuthenticationException.class);
+	}
+
+	@Test
+	public void createDecoderWithPrivateKeyJwtThenReturnDecoder() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.clientSettings(ClientSettings.builder()
+						.tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS256).jwkSetUrl("https://client.example.com/jwks").build())
+				.build();
+
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+
+		assertThat(jwtDecoder).isNotNull();
+	}
+
+	@Test
+	public void createDecoderWithPrivateKeyJwtAndSecondWithChangedAlgorithmThenReturnRecreatedDecoder() {
+		RegisteredClient.Builder registeredClientBuilder = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT);
+
+		RegisteredClient registeredClient1 = registeredClientBuilder.clientSettings(
+				ClientSettings.builder().tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS256)
+						.jwkSetUrl("https://keysite.com/jwks").build()).build();
+		RegisteredClient registeredClient2 = registeredClientBuilder.clientSettings(
+				ClientSettings.builder().tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS512)
+						.jwkSetUrl("https://keysite.com/jwks").build()).build();
+
+		JwtDecoder decoder1 = this.registeredClientDecoderFactory.createDecoder(registeredClient1);
+		JwtDecoder decoder2 = this.registeredClientDecoderFactory.createDecoder(registeredClient2);
+
+		assertThat(decoder1).isNotNull();
+		assertThat(decoder2).isNotNull();
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
+	@Test
+	public void createDecoderWithPrivateKeyJwtAndSecondWithChangedJwksUrlThenReturnRecreatedDecoder() {
+		RegisteredClient.Builder registeredClientBuilder = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT);
+
+		RegisteredClient client1 = registeredClientBuilder.clientSettings(
+				ClientSettings.builder().tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS256)
+						.jwkSetUrl("https://keysite1.com/jwks").build()).build();
+		RegisteredClient client2 = registeredClientBuilder.clientSettings(
+				ClientSettings.builder().tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS256)
+						.jwkSetUrl("https://keysite2.com/jwks").build()).build();
+
+		OAuth2ClientAuthenticationToken token = new OAuth2ClientAuthenticationToken(
+				"https://auth-server/oauth2/token", "client-1", ClientAuthenticationMethod.CLIENT_SECRET_JWT, "jwt", null);
+
+		JwtDecoder decoder1 = this.registeredClientDecoderFactory.createDecoder(client1);
+		JwtDecoder decoder2 = this.registeredClientDecoderFactory.createDecoder(client2);
+
+		assertThat(decoder1).isNotNull();
+		assertThat(decoder2).isNotNull();
+		assertThat(decoder1).isNotSameAs(decoder2);
+	}
+
+	@Test
+	public void createDecoderWithPrivateKeyJwtNullAlgorithmThenReturnDefaultRS256Decoder() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.clientSettings(
+						ClientSettings.builder()
+								.jwkSetUrl("https://keysite1.com/jwks")
+								.build())
+				.build();
+
+		JwtDecoder decoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		assertThat(decoder).isNotNull();
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenValidThenReturnJwtObject() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.issuer(registeredClient.getClientId())
+						.subject(registeredClient.getClientId())
+						.expirationTime(Date.from(Instant.now()))
+						.build());
+
+		assertThat(jwtDecoder.decode(clientJwtAssertion)).isNotNull();
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenBadIssuerThenThrowJwtException() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.issuer("bad-issuer")
+						.subject(registeredClient.getClientId())
+						.expirationTime(Date.from(Instant.now()))
+						.build());
+
+		assertThatThrownBy(() -> jwtDecoder.decode(clientJwtAssertion))
+				.isInstanceOf(JwtException.class)
+				.extracting("message")
+				.matches(s -> s.toString().contains("The iss claim is not valid"));
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenBadSubjectThenThrowJwtException() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.issuer(registeredClient.getClientId())
+						.subject("bad-client")
+						.expirationTime(Date.from(Instant.now()))
+						.build());
+
+		assertThatThrownBy(() -> jwtDecoder.decode(clientJwtAssertion))
+				.isInstanceOf(JwtException.class)
+				.extracting("message")
+				.matches(s -> s.toString().contains("The sub claim is not valid"));
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenNoExpClaimThenThrowJwtException() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.subject(registeredClient.getClientId())
+						.issuer(registeredClient.getClientId())
+						.build());
+
+		assertThatThrownBy(() -> jwtDecoder.decode(clientJwtAssertion))
+				.isInstanceOf(JwtException.class)
+				.extracting("message")
+				.matches(s -> s.toString().contains("The exp claim is not valid"));
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenExpiredThenThrowJwtException() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.subject(registeredClient.getClientId())
+						.issuer(registeredClient.getClientId())
+						.expirationTime(Date.from(Instant.now().minusSeconds(240)))
+						.build());
+
+		assertThatThrownBy(() -> jwtDecoder.decode(clientJwtAssertion))
+				.isInstanceOf(JwtException.class)
+				.extracting("message")
+				.matches(s -> s.toString().contains("Jwt expired at"));
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenExpiredWithinSkewThenReturnJwtObject() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.subject(registeredClient.getClientId())
+						.issuer(registeredClient.getClientId())
+						.expirationTime(Date.from(Instant.now().minusSeconds(30)))
+						.build());
+
+		assertThat(jwtDecoder.decode(clientJwtAssertion)).isNotNull();
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenInvalidNbfThenThrowJwtException() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.subject(registeredClient.getClientId())
+						.issuer(registeredClient.getClientId())
+						.expirationTime(Date.from(Instant.now()))
+						.notBeforeTime(Date.from(Instant.now().plusSeconds(90)))
+						.build());
+
+		assertThatThrownBy(() -> jwtDecoder.decode(clientJwtAssertion))
+				.isInstanceOf(JwtException.class)
+				.extracting("message")
+				.matches(s -> s.toString().contains("Jwt used before"));
+	}
+
+	@Test
+	public void validateClientSecretJwtTokenWhenInvalidIatThenThrowJwtException() throws Exception {
+		RegisteredClient registeredClient = defaultRegisteredClient();
+		JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+		String clientJwtAssertion = clientSecretJwtAssertion(registeredClient,
+				new JWTClaimsSet.Builder()
+						.subject(registeredClient.getClientId())
+						.issuer(registeredClient.getClientId())
+						.expirationTime(Date.from(Instant.now()))
+						.issueTime(Date.from(Instant.now().plusSeconds(90)))
+						.build());
+
+		assertThatThrownBy(() -> jwtDecoder.decode(clientJwtAssertion))
+				.isInstanceOf(JwtException.class)
+				.extracting("message")
+				.matches(s -> s.toString().contains("expiresAt must be after issuedAt"));
+	}
+
+	@Test
+	public void validatePrivateKeyJwtTokenWhenValidThenReturnJwtObject() throws Exception {
+		KeyPairGenerator gen = KeyPairGenerator.getInstance("RSA");
+		gen.initialize(2048);
+		KeyPair keyPair = gen.generateKeyPair();
+
+		JWK jwk = new RSAKey.Builder((RSAPublicKey) keyPair.getPublic())
+				.keyUse(KeyUse.SIGNATURE)
+				.keyID(UUID.randomUUID().toString())
+				.build();
+
+		String jwks = "{\"keys\":[" + jwk.toJSONString() + "]}";
+
+		try (MockWebServer server = new MockWebServer()) {
+			String jwkSetUrl = server.url("/.well-known/jwks.json").toString();
+			server.enqueue(new MockResponse().setBody(jwks));
+
+			RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+					.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+					.clientSettings(ClientSettings.builder()
+							.tokenEndpointSigningAlgorithm(SignatureAlgorithm.RS256)
+							.jwkSetUrl(jwkSetUrl).build())
+					.build();
+
+			JwtDecoder jwtDecoder = this.registeredClientDecoderFactory.createDecoder(registeredClient);
+
+			JWTClaimsSet claimsSet = new JWTClaimsSet.Builder()
+					.issuer(registeredClient.getClientId())
+					.subject(registeredClient.getClientId())
+					.expirationTime(Date.from(Instant.now()))
+					.build();
+			SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.RS256), claimsSet);
+			JWSSigner signer = new RSASSASigner(keyPair.getPrivate());
+			signedJWT.sign(signer);
+			String clientJwtAssertion = signedJWT.serialize();
+
+			assertThat(jwtDecoder.decode(clientJwtAssertion)).isNotNull();
+
+			server.shutdown();
+		}
+
+	}
+
+	private RegisteredClient defaultRegisteredClient() {
+		return TestRegisteredClients.registeredClient()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
+				.clientSecret("0123456789abcdef0123456789ABCDEF")
+				.clientSettings(ClientSettings.builder().tokenEndpointSigningAlgorithm(MacAlgorithm.HS256).build())
+				.build();
+	}
+
+	private String clientSecretJwtAssertion(RegisteredClient registeredClient, JWTClaimsSet claimsSet) throws JOSEException {
+		SignedJWT signedJWT = new SignedJWT(new JWSHeader(JWSAlgorithm.HS256), claimsSet);
+		JWSSigner signer = new MACSigner(registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8));
+		signedJWT.sign(signer);
+		String clientJwtAssertion = signedJWT.serialize();
+		return clientJwtAssertion;
+	}
+}

+ 12 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@@ -43,6 +43,7 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
 import org.springframework.security.jackson2.SecurityJackson2Modules;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientRowMapper;
 import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
@@ -148,6 +149,17 @@ public class JdbcRegisteredClientRepositoryTests {
 		assertThat(registeredClient).isEqualTo(expectedRegisteredClient);
 	}
 
+	@Test
+	public void saveWhenCustomTokenEndpointSigningAlgorithmsThenSaved() {
+		RegisteredClient expectedRegisteredClient = TestRegisteredClients.registeredClient()
+				.clientSettings(ClientSettings.builder()
+						.tokenEndpointSigningAlgorithm(MacAlgorithm.HS256).build())
+				.build();
+		this.registeredClientRepository.save(expectedRegisteredClient);
+		RegisteredClient registeredClient = this.registeredClientRepository.findById(expectedRegisteredClient.getId());
+		assertThat(registeredClient).isEqualTo(expectedRegisteredClient);
+	}
+
 	@Test
 	public void saveWhenClientSecretNullThenSaved() {
 		RegisteredClient expectedRegisteredClient = TestRegisteredClients.registeredClient()

+ 21 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/ClientSettingsTests.java

@@ -16,6 +16,8 @@
 package org.springframework.security.oauth2.server.authorization.config;
 
 import org.junit.Test;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -29,9 +31,10 @@ public class ClientSettingsTests {
 	@Test
 	public void buildWhenDefaultThenDefaultsAreSet() {
 		ClientSettings clientSettings = ClientSettings.builder().build();
-		assertThat(clientSettings.getSettings()).hasSize(2);
+		assertThat(clientSettings.getSettings()).hasSize(3);
 		assertThat(clientSettings.isRequireProofKey()).isFalse();
 		assertThat(clientSettings.isRequireAuthorizationConsent()).isFalse();
+		assertThat(clientSettings.getTokenEndpointSigningAlgorithm()).isEqualTo(SignatureAlgorithm.RS256);
 	}
 
 	@Test
@@ -50,13 +53,29 @@ public class ClientSettingsTests {
 		assertThat(clientSettings.isRequireAuthorizationConsent()).isTrue();
 	}
 
+	@Test
+	public void tokenEndpointAlgorithmWhenHS256ThenSet() {
+		ClientSettings clientSettings = ClientSettings.builder()
+				.tokenEndpointSigningAlgorithm(MacAlgorithm.HS256)
+				.build();
+		assertThat(clientSettings.getTokenEndpointSigningAlgorithm()).isEqualTo(MacAlgorithm.HS256);
+	}
+
+	@Test
+	public void whenJwkSetUrlSetThenSet() {
+		ClientSettings clientSettings = ClientSettings.builder()
+				.jwkSetUrl("https://auth-server:9000/jwks")
+				.build();
+		assertThat(clientSettings.getJwkSetUrl()).isEqualTo("https://auth-server:9000/jwks");
+	}
+
 	@Test
 	public void settingWhenCustomThenSet() {
 		ClientSettings clientSettings = ClientSettings.builder()
 				.setting("name1", "value1")
 				.settings(settings -> settings.put("name2", "value2"))
 				.build();
-		assertThat(clientSettings.getSettings()).hasSize(4);
+		assertThat(clientSettings.getSettings()).hasSize(5);
 		assertThat(clientSettings.<String>getSetting("name1")).isEqualTo("value1");
 		assertThat(clientSettings.<String>getSetting("name2")).isEqualTo("value2");
 	}

+ 161 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java

@@ -27,6 +27,7 @@ import org.junit.Test;
 import org.mockito.ArgumentCaptor;
 
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@@ -37,6 +38,7 @@ import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcClientRegistration;
+import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
@@ -423,6 +425,165 @@ public class OidcClientRegistrationAuthenticationProviderTests {
 		assertThat(clientRegistrationResult.getRegistrationAccessToken()).isEqualTo(jwt.getTokenValue());
 	}
 
+	private OidcClientRegistrationAuthenticationToken jwtClientAuthenticationRegistration(
+			String tokenAuthenticationMethod, String tokenSigningAlgorithm, String jwkSetUrl) {
+		Jwt jwt = createJwtClientRegistration();
+		OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				jwt.getTokenValue(), jwt.getIssuedAt(),
+				jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
+				registeredClient, jwtAccessToken, jwt.getClaims()).build();
+		when(this.authorizationService.findByToken(
+				eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)))
+				.thenReturn(authorization);
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwtClientConfiguration());
+
+		JwtAuthenticationToken principal = new JwtAuthenticationToken(
+				jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create"));
+		// @formatter:off
+		OidcClientRegistration.Builder clientRegistrationBuilder = OidcClientRegistration.builder()
+				.clientId("client-id")
+				.clientName("client-name")
+				.redirectUri("https://client.example.com")
+				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+				.scope("scope1")
+				.tokenEndpointAuthenticationMethod(tokenAuthenticationMethod);
+		// @formatter:on
+
+		if (tokenSigningAlgorithm != null) {
+			clientRegistrationBuilder = clientRegistrationBuilder.tokenEndpointAuthenticationSigningAlgorithm(tokenSigningAlgorithm);
+		}
+
+		if (jwkSetUrl != null) {
+			clientRegistrationBuilder = clientRegistrationBuilder.jwkSetUrl(jwkSetUrl);
+		}
+
+		return new OidcClientRegistrationAuthenticationToken(principal, clientRegistrationBuilder.build());
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationRequestAndPrivateKeyJwtAndAlgorithmNoneThenThrowOAuth2AuthenticationException() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue(), "none", "https://client.example.com/jwks");
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode")
+				.isEqualTo("invalid_client_metadata");
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationRequestAndPrivateKeyJwtAndMacAlgorithmThenThrowOAuth2AuthenticationException() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue(), "HS256", "https://client.example.com/jwks");
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode")
+				.isEqualTo("invalid_client_metadata");
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationRequestAndPrivateKeyJwtAndNoJwkSetUrlThenThrowOAuth2AuthenticationException() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue(), "RS256", null);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode")
+				.isEqualTo("invalid_client_metadata");
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationRequestAndClientSecretJwtAndPkiAlgorithmThenThrowOAuth2AuthenticationException() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue(), "RS256", "https://client.example.com/jwks");
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError()).extracting("errorCode")
+				.isEqualTo("invalid_client_metadata");
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationAndProperClientSecretJwtRegistrationThenRegistered() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue(), "HS512", null);
+
+		Authentication authenticationResult = this.authenticationProvider.authenticate(authentication);
+
+		assertThat(authenticationResult).isNotNull();
+
+		ArgumentCaptor<RegisteredClient> registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class);
+		verify(this.registeredClientRepository).save(registeredClientCaptor.capture());
+		RegisteredClient registeredClientResult = registeredClientCaptor.getValue();
+
+		assertThat(registeredClientResult).isNotNull();
+		assertThat(registeredClientResult.getClientSecret()).hasSizeGreaterThan(32);
+		assertThat(registeredClientResult.getClientSettings().getTokenEndpointSigningAlgorithm()).isEqualTo(MacAlgorithm.HS512);
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationAndClientSecretJwtAndNullAlgorithmThenDefaultAlgorithmHS256() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue(), null, null);
+
+		Authentication authenticationResult = this.authenticationProvider.authenticate(authentication);
+
+		assertThat(authenticationResult).isNotNull();
+
+		ArgumentCaptor<RegisteredClient> registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class);
+		verify(this.registeredClientRepository).save(registeredClientCaptor.capture());
+		RegisteredClient registeredClientResult = registeredClientCaptor.getValue();
+
+		assertThat(registeredClientResult).isNotNull();
+		assertThat(registeredClientResult.getClientSecret()).hasSizeGreaterThan(32);
+		assertThat(registeredClientResult.getClientSettings().getTokenEndpointSigningAlgorithm()).isEqualTo(MacAlgorithm.HS256);
+		assertThat(registeredClientResult.getClientAuthenticationMethods()).contains(ClientAuthenticationMethod.CLIENT_SECRET_JWT);
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationAndProperPrivateKeyJwtRegistrationThenRegistered() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue(), "RS512", "https://client.example.com/jwks");
+
+		OidcClientRegistrationAuthenticationToken authenticationResult =
+				(OidcClientRegistrationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		assertThat(authenticationResult).isNotNull();
+
+		ArgumentCaptor<RegisteredClient> registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class);
+		verify(this.registeredClientRepository).save(registeredClientCaptor.capture());
+		RegisteredClient registeredClientResult = registeredClientCaptor.getValue();
+
+		assertThat(registeredClientResult).isNotNull();
+		assertThat(registeredClientResult.getClientSettings().getJwkSetUrl()).isEqualTo("https://client.example.com/jwks");
+		assertThat(registeredClientResult.getClientSettings().getTokenEndpointSigningAlgorithm()).isEqualTo(SignatureAlgorithm.RS512);
+
+		assertThat(authenticationResult.getClientRegistration().getTokenEndpointAuthenticationSigningAlgorithm()).isEqualTo("RS512");
+		assertThat(authenticationResult.getClientRegistration().getJwkSetUrl().toString()).isEqualTo("https://client.example.com/jwks");
+		assertThat(registeredClientResult.getClientAuthenticationMethods()).contains(ClientAuthenticationMethod.PRIVATE_KEY_JWT);
+	}
+
+	@Test
+	public void authenticateWhenClientRegistrationAndPrivateKeyJwtAndNullAlgorithmThenDefaultAlgorithmRS256() {
+		OidcClientRegistrationAuthenticationToken authentication = jwtClientAuthenticationRegistration(
+				ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue(), null, "https://client.example.com/jwks");
+
+		Authentication authenticationResult = this.authenticationProvider.authenticate(authentication);
+
+		assertThat(authenticationResult).isNotNull();
+
+		ArgumentCaptor<RegisteredClient> registeredClientCaptor = ArgumentCaptor.forClass(RegisteredClient.class);
+		verify(this.registeredClientRepository).save(registeredClientCaptor.capture());
+		RegisteredClient registeredClientResult = registeredClientCaptor.getValue();
+
+		assertThat(registeredClientResult).isNotNull();
+		assertThat(registeredClientResult.getClientSettings().getJwkSetUrl()).isEqualTo("https://client.example.com/jwks");
+		assertThat(registeredClientResult.getClientSettings().getTokenEndpointSigningAlgorithm()).isEqualTo(SignatureAlgorithm.RS256);
+	}
+
 	@Test
 	public void authenticateWhenClientConfigurationRequestAndAccessTokenNotAuthorizedThenThrowOAuth2AuthenticationException() {
 		Jwt jwt = createJwt(Collections.singleton("unauthorized.scope"));

+ 1 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcProviderConfigurationEndpointFilterTests.java

@@ -116,7 +116,7 @@ public class OidcProviderConfigurationEndpointFilterTests {
 		assertThat(providerConfigurationResponse).contains("\"grant_types_supported\":[\"authorization_code\",\"client_credentials\",\"refresh_token\"]");
 		assertThat(providerConfigurationResponse).contains("\"subject_types_supported\":[\"public\"]");
 		assertThat(providerConfigurationResponse).contains("\"id_token_signing_alg_values_supported\":[\"RS256\"]");
-		assertThat(providerConfigurationResponse).contains("\"token_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\"]");
+		assertThat(providerConfigurationResponse).contains("\"token_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\",\"client_secret_jwt\",\"private_key_jwt\"]");
 	}
 
 	@Test

+ 3 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationServerMetadataEndpointFilterTests.java

@@ -115,14 +115,14 @@ public class OAuth2AuthorizationServerMetadataEndpointFilterTests {
 		assertThat(authorizationServerMetadataResponse).contains("\"issuer\":\"https://example.com/issuer1\"");
 		assertThat(authorizationServerMetadataResponse).contains("\"authorization_endpoint\":\"https://example.com/issuer1/oauth2/v1/authorize\"");
 		assertThat(authorizationServerMetadataResponse).contains("\"token_endpoint\":\"https://example.com/issuer1/oauth2/v1/token\"");
-		assertThat(authorizationServerMetadataResponse).contains("\"token_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\"]");
+		assertThat(authorizationServerMetadataResponse).contains("\"token_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\",\"client_secret_jwt\",\"private_key_jwt\"]");
 		assertThat(authorizationServerMetadataResponse).contains("\"jwks_uri\":\"https://example.com/issuer1/oauth2/v1/jwks\"");
 		assertThat(authorizationServerMetadataResponse).contains("\"response_types_supported\":[\"code\"]");
 		assertThat(authorizationServerMetadataResponse).contains("\"grant_types_supported\":[\"authorization_code\",\"client_credentials\",\"refresh_token\"]");
 		assertThat(authorizationServerMetadataResponse).contains("\"revocation_endpoint\":\"https://example.com/issuer1/oauth2/v1/revoke\"");
-		assertThat(authorizationServerMetadataResponse).contains("\"revocation_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\"]");
+		assertThat(authorizationServerMetadataResponse).contains("\"revocation_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\",\"client_secret_jwt\",\"private_key_jwt\"]");
 		assertThat(authorizationServerMetadataResponse).contains("\"introspection_endpoint\":\"https://example.com/issuer1/oauth2/v1/introspect\"");
-		assertThat(authorizationServerMetadataResponse).contains("\"introspection_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\"]");
+		assertThat(authorizationServerMetadataResponse).contains("\"introspection_endpoint_auth_methods_supported\":[\"client_secret_basic\",\"client_secret_post\",\"client_secret_jwt\",\"private_key_jwt\"]");
 		assertThat(authorizationServerMetadataResponse).contains("\"code_challenge_methods_supported\":[\"plain\",\"S256\"]");
 	}
 

+ 148 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverterTests.java

@@ -0,0 +1,148 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.server.authorization.web.authentication;
+
+import org.junit.Test;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
+
+import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.entry;
+
+/**
+ * Tests for {@link JwtClientAssertionAuthenticationConverter}
+ *
+ * @author Rafal Lewczuk
+ */
+public class JwtClientAssertionAuthenticationConverterTests {
+
+	private JwtClientAssertionAuthenticationConverter converter = new JwtClientAssertionAuthenticationConverter();
+
+	private static final String JWT_BEARER_TYPE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
+
+	private void shouldThrow(MockHttpServletRequest request, String errorCode) {
+		assertThatThrownBy(() -> this.converter.convert(request))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(errorCode);
+	}
+
+	@Test
+	public void convertWhenClientAssertionTypeNullThenReturnNull() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "some_jwt_assertion");
+		Authentication authentication = this.converter.convert(request);
+		assertThat(authentication).isNull();
+	}
+
+	@Test
+	public void convertWhenMissingClientAssertionThenReturnNull() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_BEARER_TYPE);
+		Authentication authentication = this.converter.convert(request);
+		assertThat(authentication).isNull();
+	}
+
+	@Test
+	public void convertWhenMissingClientIdThenInvalidRequestError() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_BEARER_TYPE);
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "some_jwt_assertion");
+		shouldThrow(request, OAuth2ErrorCodes.INVALID_REQUEST);
+	}
+
+	@Test
+	public void convertWhenMultipleClientIdThenInvalidRequestError() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "some_client");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "other_client");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_BEARER_TYPE);
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "some_jwt_assertion");
+		shouldThrow(request, OAuth2ErrorCodes.INVALID_REQUEST);
+	}
+
+	@Test
+	public void convertWhenBadAssertionTypeThenInvalidRequestError() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "some_client");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, "borken");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "some_jwt_assertion");
+		shouldThrow(request, OAuth2ErrorCodes.INVALID_REQUEST);
+	}
+
+	@Test
+	public void convertWhenMissingClientJwtAssertionTypeThenDoNotProcessClientIdAndReturnNull() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "some_client");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "throw_something_when_client_id_is_processed");
+		Authentication authentication = this.converter.convert(request);
+		assertThat(authentication).isNull();
+	}
+
+	@Test
+	public void convertWhenMultipleAssertionsThenInvalidRequestError() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "some_client");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_BEARER_TYPE);
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "some_jwt_assertion");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "other_jwt_assertion");
+		shouldThrow(request, OAuth2ErrorCodes.INVALID_REQUEST);
+	}
+
+	@Test
+	public void convertWhenValidAssertionJwt() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "some_client");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_BEARER_TYPE);
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "some_jwt_assertion");
+		request.setRequestURI("/oauth2/token");
+		OAuth2ClientAuthenticationToken authentication = (OAuth2ClientAuthenticationToken) this.converter.convert(request);
+		assertThat(authentication).isNotNull();
+		assertThat(authentication.getRequestUri()).isEqualTo("/oauth2/token");
+		assertThat(authentication.getPrincipal()).isEqualTo("some_client");
+		assertThat(authentication.getCredentials()).isEqualTo("some_jwt_assertion");
+	}
+
+	@Test
+	public void convertWhenConfidentialClientWithPkceParametersThenAdditionalParametersIncluded() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
+		request.addParameter(OAuth2ParameterNames.CODE, "code");
+		request.addParameter(PkceParameterNames.CODE_VERIFIER, "code-verifier-1");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "some_client");
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, JWT_BEARER_TYPE);
+		request.addParameter(OAuth2ParameterNames.CLIENT_ASSERTION, "some_jwt_assertion");
+		request.setRequestURI("/oauth2/token");
+		OAuth2ClientAuthenticationToken authentication = (OAuth2ClientAuthenticationToken) this.converter.convert(request);
+		assertThat(authentication).isNotNull();
+		assertThat(authentication.getRequestUri()).isEqualTo("/oauth2/token");
+		assertThat(authentication.getPrincipal()).isEqualTo("some_client");
+		assertThat(authentication.getCredentials()).isEqualTo("some_jwt_assertion");
+		assertThat(authentication.getAdditionalParameters())
+				.containsOnly(
+						entry(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue()),
+						entry(OAuth2ParameterNames.CODE, "code"),
+						entry(PkceParameterNames.CODE_VERIFIER, "code-verifier-1"));
+	}
+}