Pārlūkot izejas kodu

Implement Proof Key for Code Exchange (PKCE) RFC 7636

See https://tools.ietf.org/html/rfc7636

Closes gh-45
Daniel Garnier-Moiroux 5 gadi atpakaļ
vecāks
revīzija
ab090445b3
13 mainītis faili ar 1008 papildinājumiem un 55 dzēšanām
  1. 2 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java
  2. 73 15
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  3. 37 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java
  4. 0 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java
  5. 35 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
  6. 30 9
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java
  7. 2 0
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java
  8. 178 0
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java
  9. 7 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
  10. 322 15
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
  11. 41 6
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java
  12. 232 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java
  13. 49 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java

+ 2 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AccessTokenAuthenticationToken.java

@@ -38,8 +38,8 @@ import java.util.Collections;
  */
 public class OAuth2AccessTokenAuthenticationToken extends AbstractAuthenticationToken {
 	private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
-	private final RegisteredClient registeredClient;
-	private final Authentication clientPrincipal;
+	private RegisteredClient registeredClient;
+	private Authentication clientPrincipal;
 	private final OAuth2AccessToken accessToken;
 
 	/**

+ 73 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -23,6 +23,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.jose.JoseHeader;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -33,15 +34,20 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
 import java.net.MalformedURLException;
 import java.net.URI;
 import java.net.URL;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
+import java.util.Base64;
 import java.util.Collections;
 
 /**
@@ -85,29 +91,30 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 				(OAuth2AuthorizationCodeAuthenticationToken) authentication;
 
 		OAuth2ClientAuthenticationToken clientPrincipal = null;
+		RegisteredClient registeredClient = null;
 		if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(authorizationCodeAuthentication.getPrincipal().getClass())) {
 			clientPrincipal = (OAuth2ClientAuthenticationToken) authorizationCodeAuthentication.getPrincipal();
-		}
-		if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
+			registeredClient = clientPrincipal.getRegisteredClient();
+		} else if (StringUtils.hasText(authorizationCodeAuthentication.getClientId())) {
+			// When the principal is a string, it is the clientId, REQUIRED for public clients
+			String clientId = authorizationCodeAuthentication.getClientId();
+			registeredClient = this.registeredClientRepository.findByClientId(clientId);
+			if (registeredClient == null) {
+				throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
+			}
+		} else {
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
 		}
 
-		// TODO Authenticate public client
-		// A client MAY use the "client_id" request parameter to identify itself
-		// when sending requests to the token endpoint.
-		// In the "authorization_code" "grant_type" request to the token endpoint,
-		// an unauthenticated client MUST send its "client_id" to prevent itself
-		// from inadvertently accepting a code intended for a client with a different "client_id".
-		// This protects the client from substitution of the authentication code.
+		if (clientPrincipal != null && !clientPrincipal.isAuthenticated()) {
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
+		}
 
 		OAuth2Authorization authorization = this.authorizationService.findByToken(
 				authorizationCodeAuthentication.getCode(), TokenType.AUTHORIZATION_CODE);
 		if (authorization == null) {
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
-		if (!clientPrincipal.getRegisteredClient().getId().equals(authorization.getRegisteredClientId())) {
-			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
-		}
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
 				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@@ -116,6 +123,35 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
+		if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+		}
+
+
+		String codeChallenge;
+		Object codeChallengeParameter = authorizationRequest
+				.getAdditionalParameters()
+				.get(PkceParameterNames.CODE_CHALLENGE);
+
+		if (codeChallengeParameter != null) {
+			codeChallenge = (String) codeChallengeParameter;
+
+			String codeChallengeMethod = (String) authorizationRequest
+					.getAdditionalParameters()
+					.get(PkceParameterNames.CODE_CHALLENGE_METHOD);
+
+			String codeVerifier = (String) authorizationCodeAuthentication
+					.getAdditionalParameters()
+					.get(PkceParameterNames.CODE_VERIFIER);
+
+			if (!pkceCodeVerifierValid(codeVerifier, codeChallenge, codeChallengeMethod)) {
+				throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+			}
+		} else if (registeredClient.getClientSettings().requireProofKey()){
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+		}
+
+
 		JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
 
 		// TODO Allow configuration for issuer claim
@@ -130,7 +166,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 		JwtClaimsSet jwtClaimsSet = JwtClaimsSet.withClaims()
 				.issuer(issuer)
 				.subject(authorization.getPrincipalName())
-				.audience(Collections.singletonList(clientPrincipal.getRegisteredClient().getClientId()))
+				.audience(Collections.singletonList(registeredClient.getClientId()))
 				.issuedAt(issuedAt)
 				.expiresAt(expiresAt)
 				.notBefore(issuedAt)
@@ -148,8 +184,30 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 				.build();
 		this.authorizationService.save(authorization);
 
-		return new OAuth2AccessTokenAuthenticationToken(
-				clientPrincipal.getRegisteredClient(), clientPrincipal, accessToken);
+		return clientPrincipal != null ?
+				new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken) :
+				new OAuth2AccessTokenAuthenticationToken(registeredClient, new OAuth2ClientAuthenticationToken(registeredClient), accessToken);
+	}
+
+	private boolean pkceCodeVerifierValid(String codeVerifier, String codeChallenge, String codeChallengeMethod) {
+		if (codeVerifier == null) {
+			return false;
+		} else if (codeChallengeMethod == null || codeChallengeMethod.equals("plain")) {
+			return  codeVerifier.equals(codeChallenge);
+		} else if ("S256".equals(codeChallengeMethod)) {
+			try {
+				MessageDigest md = MessageDigest.getInstance("SHA-256");
+				byte[] digest = md.digest(codeVerifier.getBytes(StandardCharsets.US_ASCII));
+				String encodedVerifier = Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
+				return codeChallenge.equals(encodedVerifier);
+			} catch (NoSuchAlgorithmException e) {
+				// It is unlikely that SHA-256 is not available on the server. If it is not available,
+				// there will likely be bigger issues as well. We default to SERVER_ERROR.
+			}
+		}
+
+		// Unsupported algorithm should be caught in OAuth2AuthorizationEndpointFilter
+		throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR));
 	}
 
 	@Override

+ 37 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationToken.java

@@ -22,6 +22,7 @@ import org.springframework.security.core.SpringSecurityCoreVersion2;
 import org.springframework.util.Assert;
 
 import java.util.Collections;
+import java.util.Map;
 
 /**
  * An {@link Authentication} implementation used for the OAuth 2.0 Authorization Code Grant.
@@ -35,10 +36,11 @@ import java.util.Collections;
  */
 public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenticationToken {
 	private static final long serialVersionUID = SpringSecurityCoreVersion2.SERIAL_VERSION_UID;
-	private String code;
+	private final String code;
 	private Authentication clientPrincipal;
-	private String clientId;
-	private String redirectUri;
+	private final String clientId;
+	private final String redirectUri;
+	private final Map<String, Object> additionalParameters;
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationCodeAuthenticationToken} using the provided parameters.
@@ -46,15 +48,24 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
 	 * @param code the authorization code
 	 * @param clientPrincipal the authenticated client principal
 	 * @param redirectUri the redirect uri
+	 * @param additionalParameters the additional parameters
 	 */
 	public OAuth2AuthorizationCodeAuthenticationToken(String code,
-			Authentication clientPrincipal, @Nullable String redirectUri) {
+			Authentication clientPrincipal, @Nullable String redirectUri,
+			Map<String, Object> additionalParameters) {
 		super(Collections.emptyList());
 		Assert.hasText(code, "code cannot be empty");
 		Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
 		this.code = code;
 		this.clientPrincipal = clientPrincipal;
 		this.redirectUri = redirectUri;
+		this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
+
+		if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(this.clientPrincipal.getClass())) {
+			this.clientId = (String) this.clientPrincipal.getPrincipal();
+		} else {
+			this.clientId = null;
+		}
 	}
 
 	/**
@@ -63,15 +74,18 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
 	 * @param code the authorization code
 	 * @param clientId the client identifier
 	 * @param redirectUri the redirect uri
+	 * @param additionalParameters the additional parameters
 	 */
 	public OAuth2AuthorizationCodeAuthenticationToken(String code,
-			String clientId, @Nullable String redirectUri) {
+			String clientId, @Nullable String redirectUri,
+			Map<String, Object> additionalParameters) {
 		super(Collections.emptyList());
 		Assert.hasText(code, "code cannot be empty");
 		Assert.hasText(clientId, "clientId cannot be empty");
 		this.code = code;
 		this.clientId = clientId;
 		this.redirectUri = redirectUri;
+		this.additionalParameters = Collections.unmodifiableMap(additionalParameters != null ? additionalParameters : Collections.emptyMap());
 	}
 
 	@Override
@@ -101,4 +115,22 @@ public class OAuth2AuthorizationCodeAuthenticationToken extends AbstractAuthenti
 	public @Nullable String getRedirectUri() {
 		return this.redirectUri;
 	}
+
+	/**
+	 * Returns the additional parameters
+	 *
+	 * @return the additional parameters
+	 */
+	public Map<String, Object> getAdditionalParameters() {
+		return this.additionalParameters;
+	}
+
+	/**
+	 * Returns the client id
+	 *
+	 * @return the client id
+	 */
+	public @Nullable String getClientId() {
+		return this.clientId;
+	}
 }

+ 0 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java

@@ -367,7 +367,6 @@ public class RegisteredClient implements Serializable {
 			Assert.hasText(this.clientId, "clientId cannot be empty");
 			Assert.notEmpty(this.authorizationGrantTypes, "authorizationGrantTypes cannot be empty");
 			if (this.authorizationGrantTypes.contains(AuthorizationGrantType.AUTHORIZATION_CODE)) {
-				Assert.hasText(this.clientSecret, "clientSecret cannot be empty");
 				Assert.notEmpty(this.redirectUris, "redirectUris cannot be empty");
 			}
 			if (CollectionUtils.isEmpty(this.clientAuthenticationMethods)) {

+ 35 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -28,6 +28,7 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -78,6 +79,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 	private final RequestMatcher authorizationEndpointMatcher;
 	private final StringKeyGenerator codeGenerator = new Base64StringKeyGenerator(Base64.getUrlEncoder());
 	private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
+	private final String PKCE_ERROR_URI = "https://tools.ietf.org/html/rfc7636#section-4.4.1";
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationEndpointFilter} using the provided parameters.
@@ -174,6 +176,34 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 			return;
 		}
 
+		// code_challenge (REQUIRED for public clients) - RFC 7636 (PKCE)
+		String codeChallenge = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE);
+		if (StringUtils.hasText(codeChallenge)) {
+			if (parameters.get(PkceParameterNames.CODE_CHALLENGE).size() != 1) {
+				OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
+				sendErrorResponse(request, response, error, stateParameter, redirectUri);
+				return;
+			}
+
+			if (parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD) != null &&
+					parameters.get(PkceParameterNames.CODE_CHALLENGE_METHOD).size() > 1) {
+				OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
+				sendErrorResponse(request, response, error, stateParameter, redirectUri);
+				return;
+			}
+
+			String codeChallengeMethod = parameters.getFirst(PkceParameterNames.CODE_CHALLENGE_METHOD);
+			if (codeChallengeMethod != null && !Arrays.asList("plain", "S256").contains(codeChallengeMethod)) {
+				OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE_METHOD, PKCE_ERROR_URI);
+				sendErrorResponse(request, response, error, stateParameter, redirectUri);
+				return;
+			}
+		} else if (registeredClient.getClientSettings().requireProofKey()) {
+			OAuth2Error error = createError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_CHALLENGE, PKCE_ERROR_URI);
+			sendErrorResponse(request, response, error, stateParameter, redirectUri);
+			return;
+		}
+
 		// ---------------
 		// The request is valid - ensure the resource owner is authenticated
 		// ---------------
@@ -245,8 +275,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 	}
 
 	private static OAuth2Error createError(String errorCode, String parameterName) {
-		return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName,
-				"https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
+		return createError(errorCode, parameterName, "https://tools.ietf.org/html/rfc6749#section-4.1.2.1");
+	}
+
+	private static OAuth2Error createError(String errorCode, String parameterName, String errorUri) {
+		return new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, errorUri);
 	}
 
 	private static boolean isPrincipalAuthenticated(Authentication principal) {

+ 30 - 9
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

@@ -30,11 +30,13 @@ import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -54,6 +56,7 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
@@ -198,14 +201,22 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 			MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
 
 			// client_id (REQUIRED)
-			String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
-			Authentication clientPrincipal = null;
-			if (StringUtils.hasText(clientId)) {
-				if (parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
+			Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
+			String clientId = null;
+			if (clientPrincipal == null ||
+					!OAuth2ClientAuthenticationToken.class.isAssignableFrom(clientPrincipal.getClass())) {
+				clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
+				if (!StringUtils.hasText(clientId) ||
+						parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
 					throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.CLIENT_ID);
 				}
-			} else {
-				clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
+
+				// code_verifier (REQUIRED for public clients)
+				String codeVerifier = parameters.getFirst(PkceParameterNames.CODE_VERIFIER);
+				if (!StringUtils.hasText(codeVerifier) ||
+						parameters.get(PkceParameterNames.CODE_VERIFIER).size() != 1) {
+					throwError(OAuth2ErrorCodes.INVALID_REQUEST, PkceParameterNames.CODE_VERIFIER);
+				}
 			}
 
 			// code (REQUIRED)
@@ -223,9 +234,19 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 				throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.REDIRECT_URI);
 			}
 
-			return clientPrincipal != null ?
-					new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri) :
-					new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri);
+			Map<String, Object> additionalParameters = parameters
+					.entrySet()
+					.stream()
+					.filter(e -> !e.getKey().equals(OAuth2ParameterNames.GRANT_TYPE) &&
+							!e.getKey().equals(OAuth2ParameterNames.CLIENT_ID) &&
+							!e.getKey().equals(OAuth2ParameterNames.CODE) &&
+							!e.getKey().equals(OAuth2ParameterNames.REDIRECT_URI))
+					.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().get(0)));
+
+
+			return clientId != null ?
+					new OAuth2AuthorizationCodeAuthenticationToken(code, clientId, redirectUri, additionalParameters) :
+					new OAuth2AuthorizationCodeAuthenticationToken(code, clientPrincipal, redirectUri, additionalParameters);
 		}
 	}
 

+ 2 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -31,6 +31,7 @@ import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -169,6 +170,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		parameters.set(OAuth2ParameterNames.SCOPE,
 				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
 		parameters.set(OAuth2ParameterNames.STATE, "state");
+		parameters.set(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
 		return parameters;
 	}
 

+ 178 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2PkceTests.java

@@ -0,0 +1,178 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
+
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.security.config.annotation.web.WebSecurityConfigurer;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.config.annotation.web.builders.WebSecurity;
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.crypto.keys.KeyManager;
+import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
+import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
+import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
+import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.MvcResult;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponentsBuilder;
+import org.springframework.web.util.UriUtils;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.when;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
+
+public class OAuth2PkceTests {
+	private static RegisteredClientRepository registeredClientRepository;
+
+	@Rule
+	public final SpringTestRule spring = new SpringTestRule();
+
+	@Autowired
+	private MockMvc mvc;
+
+	@BeforeClass
+	public static void init() {
+		registeredClientRepository = mock(RegisteredClientRepository.class);
+	}
+
+	@Before
+	public void setup() {
+		reset(registeredClientRepository);
+	}
+
+	@Test
+	public void requestWhenTokenRequestNotAuthenticatedAndPkceParamatersProvidedThenRedirectToClient() throws Exception {
+		// See RFC 7636: Appendix B.  Example for the S256 code_challenge_method
+		// https://tools.ietf.org/html/rfc7636#appendix-B
+		final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
+		final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+
+		this.spring.register(AuthorizationServerConfiguration.class).autowire();
+
+		ClientSettings settings = new ClientSettings();
+		RegisteredClient registeredClient = TestRegisteredClients
+				.registeredClient()
+				.clientSettings(settings.requireProofKey(true))
+				.build();
+		when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+
+		MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI)
+				.params(getAuthorizationRequestParameters(registeredClient))
+				.param("code_challenge", S256_CODE_CHALLENGE)
+				.param("code_challenge_method", "S256")
+				.with(user("user")))
+				.andExpect(status().is3xxRedirection())
+				.andReturn();
+
+		assertThat(mvcResult.getResponse().getRedirectedUrl())
+				.doesNotContain("error=")
+				.contains("code=");
+
+		String authorizationCode = UriUtils.decode(UriComponentsBuilder.fromHttpUrl(mvcResult.getResponse().getRedirectedUrl())
+				.build()
+				.getQueryParams()
+				.getFirst("code"), "utf-8");
+
+		this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
+				.params(getTokenRequestParameters(registeredClient, authorizationCode))
+				.param("code_verifier", S256_CODE_VERIFIER))
+				.andExpect(status().is2xxSuccessful())
+				.andExpect(jsonPath("$.access_token").isNotEmpty());
+	}
+
+	private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
+		parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
+		parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+		parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
+		parameters.set(OAuth2ParameterNames.SCOPE,
+				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
+		parameters.set(OAuth2ParameterNames.STATE, "state");
+		return parameters;
+	}
+
+	private static MultiValueMap<String, String> getTokenRequestParameters(RegisteredClient registeredClient,
+			String authorizationCode) {
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
+		parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
+		parameters.set(OAuth2ParameterNames.CODE, authorizationCode);
+		parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
+		parameters.set(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+		return parameters;
+	}
+
+	@EnableWebSecurity
+	static class AuthorizationServerConfiguration {
+
+		@Bean
+		RegisteredClientRepository registeredClientRepository() {
+			return registeredClientRepository;
+		}
+
+		@Bean
+		OAuth2AuthorizationService authorizationService() {
+			return new InMemoryOAuth2AuthorizationService();
+		}
+
+		@Bean
+		KeyManager keyManager() {
+			return new StaticKeyGeneratingKeyManager();
+		}
+
+		@Bean
+		WebSecurityConfigurer<WebSecurity> defaultOAuth2AuthorizationServerSecurity() {
+			return new WebSecurityConfigurerAdapter() {
+				@Override
+				public void configure(HttpSecurity http) throws Exception {
+					http
+							.authorizeRequests()
+							.anyRequest()
+							.permitAll()
+							.and()
+							.csrf()
+							.disable()
+							.apply(new OAuth2AuthorizationServerConfigurer<>());
+				}
+			};
+		}
+	}
+}

+ 7 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -21,6 +21,8 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
 import java.time.Instant;
+import java.util.Collections;
+import java.util.Map;
 
 /**
  * @author Joe Grandja
@@ -32,12 +34,17 @@ public class TestOAuth2Authorizations {
 	}
 
 	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient) {
+		return authorization(registeredClient, Collections.emptyMap());
+	}
+
+	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, Map<String, Object> additionalParameters) {
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(
 				OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
 				.authorizationUri("https://provider.com/oauth2/authorize")
 				.clientId(registeredClient.getClientId())
 				.redirectUri(registeredClient.getRedirectUris().iterator().next())
+				.additionalParameters(additionalParameters)
 				.state("state")
 				.build();
 		return OAuth2Authorization.withRegisteredClient(registeredClient)

+ 322 - 15
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -22,6 +22,7 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.jose.JoseHeaderNames;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.Jwt;
@@ -35,6 +36,11 @@ import org.springframework.security.oauth2.server.authorization.client.InMemoryR
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
@@ -53,7 +59,19 @@ import static org.mockito.Mockito.when;
  * @author Joe Grandja
  */
 public class OAuth2AuthorizationCodeAuthenticationProviderTests {
+	private final String PLAIN_CODE_CHALLENGE = "pkce-key";
+	private final String PLAIN_CODE_VERIFIER = PLAIN_CODE_CHALLENGE;
+
+	// See RFC 7636: Appendix B.  Example for the S256 code_challenge_method
+	// https://tools.ietf.org/html/rfc7636#appendix-B
+	private final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
+	private final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
+
+	private final String AUTHORIZATION_CODE = "code";
+
 	private RegisteredClient registeredClient;
+	private RegisteredClient otherRegisteredClient;
+	private RegisteredClient registeredClientRequiresProofKey;
 	private RegisteredClientRepository registeredClientRepository;
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
@@ -62,7 +80,17 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 	@Before
 	public void setUp() {
 		this.registeredClient = TestRegisteredClients.registeredClient().build();
-		this.registeredClientRepository = new InMemoryRegisteredClientRepository(this.registeredClient);
+		this.otherRegisteredClient = TestRegisteredClients.registeredClient2().build();
+		this.registeredClientRequiresProofKey = TestRegisteredClients.registeredClient()
+				.id("registration-3")
+				.clientId("client-3")
+				.clientSettings(new ClientSettings().requireProofKey(true))
+				.build();
+		this.registeredClientRepository = new InMemoryRegisteredClientRepository(
+				this.registeredClient,
+				this.otherRegisteredClient,
+				this.registeredClientRequiresProofKey
+		);
 		this.authorizationService = mock(OAuth2AuthorizationService.class);
 		this.jwtEncoder = mock(JwtEncoder.class);
 		this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
@@ -100,7 +128,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
 				this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
 		OAuth2AuthorizationCodeAuthenticationToken authentication =
-				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@@ -113,7 +141,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
 				this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
 		OAuth2AuthorizationCodeAuthenticationToken authentication =
-				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@@ -125,7 +153,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 	public void authenticateWhenInvalidCodeThenThrowOAuth2AuthenticationException() {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
 		OAuth2AuthorizationCodeAuthenticationToken authentication =
-				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@@ -142,7 +170,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
 				TestRegisteredClients.registeredClient2().build());
 		OAuth2AuthorizationCodeAuthenticationToken authentication =
-				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null);
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, null, null);
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@@ -160,7 +188,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
 				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 		OAuth2AuthorizationCodeAuthenticationToken authentication =
-				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid");
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri() + "-invalid", null);
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@@ -178,16 +206,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
 				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 		OAuth2AuthorizationCodeAuthenticationToken authentication =
-				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri());
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null);
 
-		Instant issuedAt = Instant.now();
-		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
-		Jwt jwt = Jwt.withTokenValue("token")
-				.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
-				.issuedAt(issuedAt)
-				.expiresAt(expiresAt)
-				.build();
-		when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt);
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -201,4 +222,290 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		assertThat(updatedAuthorization.getAccessToken()).isNotNull();
 		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
 	}
+
+	@Test
+	public void authenticateWhenRequireProofKeyAndMissingPkceCodeChallengeInAuthorizationRequestThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClientRequiresProofKey).build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(
+						AUTHORIZATION_CODE,
+						registeredClientRequiresProofKey.getClientId(),
+						authorizationRequest.getRedirectUri(),
+						Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER)
+				);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenRequireProofKeyAndUnsupportedCodeChallengeMethodInAuthorizationRequestThenThrowOAuth2AuthenticationException() {
+		Map<String, Object> pkceParameters = new HashMap<>();
+		pkceParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE);
+		// This should never happen: the Authorization endpoint should not allow it
+		pkceParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-challenge-method");
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClientRequiresProofKey, pkceParameters)
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(
+						AUTHORIZATION_CODE,
+						registeredClientRequiresProofKey.getClientId(),
+						authorizationRequest.getRedirectUri(),
+						Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER)
+				);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+	}
+
+	@Test
+	public void authenticateWhenPublicClientAndClientIdNotMatchingThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersPlain())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(
+						AUTHORIZATION_CODE,
+						otherRegisteredClient.getClientId(),
+						authorizationRequest.getRedirectUri(),
+						Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_VERIFIER)
+				);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenPublicClientAndUnknownClientIdThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersPlain())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(
+						AUTHORIZATION_CODE,
+						"invalid-client-id",
+						authorizationRequest.getRedirectUri(),
+						Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, PLAIN_CODE_CHALLENGE)
+				);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
+	}
+
+	@Test
+	public void authenticateWhenPublicClientAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersPlain())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(
+						AUTHORIZATION_CODE,
+						authorizationRequest.getClientId(),
+						authorizationRequest.getRedirectUri(),
+						null
+				);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenPrivateClientAndRequireProofKeyAndMissingCodeVerifierThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersPlain())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(
+						AUTHORIZATION_CODE,
+						clientPrincipal,
+						authorizationRequest.getRedirectUri(),
+						null
+				);
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenPublicClientAndPlainMethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersPlain())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier");
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenPublicClientAndS256MethodAndInvalidCodeVerifierThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersS256())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken("invalid-code-verifier");
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
+	@Test
+	public void authenticateWhenPublicClientAndPlainMethodAndValidCodeVerifierThenReturnAccessToken() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersPlain())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+
+		OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER);
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+		OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal();
+		assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId());
+		assertThat(updatedAuthorization.getAccessToken()).isNotNull();
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
+	}
+
+	@Test
+	public void authenticateWhenPublicClientAndNoMethodThenDefaultToPlainAndReturnAccessToken() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, Collections.singletonMap(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE))
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+
+		OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(PLAIN_CODE_VERIFIER);
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+		OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal();
+		assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId());
+		assertThat(updatedAuthorization.getAccessToken()).isNotNull();
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
+	}
+
+
+	@Test
+	public void authenticateWhenPublicClientAndS256MethodAndValidCodeVerifierThenReturnAccessToken() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+				.authorization(registeredClient, getPkceAuthorizationParametersS256())
+				.build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+
+		OAuth2AuthorizationCodeAuthenticationToken authentication = makeAuthorizationCodeAuthenticationToken(S256_CODE_VERIFIER);
+
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+		OAuth2ClientAuthenticationToken clientAuthenticationToken = (OAuth2ClientAuthenticationToken) accessTokenAuthentication.getPrincipal();
+		assertThat(clientAuthenticationToken.getPrincipal()).isEqualTo(this.registeredClient.getClientId());
+		assertThat(updatedAuthorization.getAccessToken()).isNotNull();
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
+	}
+
+	private Map<String, Object> getPkceAuthorizationParametersPlain() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain");
+		additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, PLAIN_CODE_CHALLENGE);
+		return additionalParameters;
+	}
+
+	private Map<String, Object> getPkceAuthorizationParametersS256() {
+		Map<String, Object> additionalParameters = new HashMap<>();
+		additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
+		additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE);
+		return additionalParameters;
+	}
+
+	private OAuth2AuthorizationCodeAuthenticationToken makeAuthorizationCodeAuthenticationToken(String codeVerifier) {
+		return new OAuth2AuthorizationCodeAuthenticationToken(
+				AUTHORIZATION_CODE,
+				registeredClient.getClientId(),
+				registeredClient.getRedirectUris().iterator().next(),
+				Collections.singletonMap(PkceParameterNames.CODE_VERIFIER, codeVerifier)
+		);
+	}
+
+	private Jwt createJwt() {
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
+		return Jwt.withTokenValue("token")
+				.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
+				.issuedAt(issuedAt)
+				.expiresAt(expiresAt)
+				.build();
+	}
 }

+ 41 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationTokenTests.java

@@ -19,6 +19,9 @@ import org.junit.Test;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
+import java.util.Collections;
+import java.util.Map;
+
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
@@ -29,28 +32,30 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
  */
 public class OAuth2AuthorizationCodeAuthenticationTokenTests {
 	private String code = "code";
+	private String clientPrincipalClientId = "clientPrincipal.clientId";
 	private OAuth2ClientAuthenticationToken clientPrincipal =
-			new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().build());
+			new OAuth2ClientAuthenticationToken(TestRegisteredClients.registeredClient().clientId(clientPrincipalClientId).build());
 	private String clientId = "clientId";
 	private String redirectUri = "redirectUri";
+	private Map<String, Object> additonalParams = Collections.singletonMap("some_key", "some_value");
 
 	@Test
 	public void constructorWhenCodeNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri))
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(null, this.clientPrincipal, this.redirectUri, null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("code cannot be empty");
 	}
 
 	@Test
 	public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri))
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (Authentication) null, this.redirectUri, null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("clientPrincipal cannot be null");
 	}
 
 	@Test
 	public void constructorWhenClientIdNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri))
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationToken(this.code, (String) null, this.redirectUri, null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("clientId cannot be empty");
 	}
@@ -58,20 +63,50 @@ public class OAuth2AuthorizationCodeAuthenticationTokenTests {
 	@Test
 	public void constructorWhenClientPrincipalProvidedThenCreated() {
 		OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
-				this.code, this.clientPrincipal, this.redirectUri);
+				this.code, this.clientPrincipal, this.redirectUri, this.additonalParams);
 		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
 		assertThat(authentication.getCredentials().toString()).isEmpty();
 		assertThat(authentication.getCode()).isEqualTo(this.code);
 		assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
+		assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams);
 	}
 
 	@Test
 	public void constructorWhenClientIdProvidedThenCreated() {
 		OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
-				this.code, this.clientId, this.redirectUri);
+				this.code, this.clientId, this.redirectUri, this.additonalParams);
 		assertThat(authentication.getPrincipal()).isEqualTo(this.clientId);
 		assertThat(authentication.getCredentials().toString()).isEmpty();
 		assertThat(authentication.getCode()).isEqualTo(this.code);
 		assertThat(authentication.getRedirectUri()).isEqualTo(this.redirectUri);
+		assertThat(authentication.getAdditionalParameters()).isEqualTo(this.additonalParams);
+	}
+
+	@Test
+	public void getAdditionalParamsIsImmutableMap() {
+		OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
+				this.code, this.clientId, this.redirectUri, this.additonalParams);
+		assertThatThrownBy(() -> authentication.getAdditionalParameters().put("another_key", 1))
+				.isInstanceOf(UnsupportedOperationException.class);
+		assertThatThrownBy(() -> authentication.getAdditionalParameters().remove("some_key"))
+				.isInstanceOf(UnsupportedOperationException.class);
+		assertThatThrownBy(() -> authentication.getAdditionalParameters().clear())
+				.isInstanceOf(UnsupportedOperationException.class);
+	}
+
+	@Test
+	public void getClientIdFromClientId() {
+		OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
+				this.code, this.clientId, this.redirectUri, this.additonalParams);
+
+		assertThat(authentication.getClientId()).isEqualTo(this.clientId);
+	}
+
+	@Test
+	public void getClientIdFromOAuth2ClientAuthenticationTokenPrincipal() {
+		OAuth2AuthorizationCodeAuthenticationToken authentication = new OAuth2AuthorizationCodeAuthenticationToken(
+				this.code, this.clientPrincipal, this.redirectUri, this.additonalParams);
+
+		assertThat(authentication.getClientId()).isEqualTo(this.clientPrincipalClientId);
 	}
 }

+ 232 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -30,12 +30,14 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
 import org.springframework.util.StringUtils;
 
 import javax.servlet.FilterChain;
@@ -280,6 +282,176 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				"state=state");
 	}
 
+	@Test
+	public void doFilterWhenProofKeyRequiredAndMissingPkceCodeChallengeThenThrowError() throws Exception {
+		RegisteredClient registeredClient = createClientRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		request.removeParameter(PkceParameterNames.CODE_CHALLENGE);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20code_challenge&" +
+				"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeThenThrowError() throws Exception {
+		RegisteredClient registeredClient = createClientRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20code_challenge&" +
+				"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenProofKeyNotRequiredClientAndMultiplePkceCodeChallengeThenThrowError() throws Exception {
+		RegisteredClient registeredClient = createClientDoNotRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenger");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20code_challenge&" +
+				"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
+				"state=state");
+
+	}
+
+	@Test
+	public void doFilterWhenProofKeyRequiredAndMultiplePkceCodeChallengeMethodThenThrowError() throws Exception {
+		RegisteredClient registeredClient = createClientRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
+				"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAnMultiplePkceCodeChallengeMethodThenThrowError() throws Exception {
+		RegisteredClient registeredClient = createClientDoNotRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "plain");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
+				"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenProofKeyRequiredAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception {
+		RegisteredClient registeredClient = createClientRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
+				"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
+				"state=state");
+	}
+
+	@Test
+	public void doFilterWhenProofKeyNotRequiredClientAndPkceCodeChallengeAndUnsupportedPkceCodeChallengeMethodThenThrowError() throws Exception {
+		RegisteredClient registeredClient = createClientDoNotRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		request.setParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "unsupported-transform");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?" +
+				"error=invalid_request&" +
+				"error_description=OAuth%202.0%20Parameter:%20code_challenge_method&" +
+				"error_uri=https://tools.ietf.org/html/rfc7636%23section-4.4.1&" +
+				"state=state");
+
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationRequestValidNotAuthenticatedThenContinueChainToCommenceAuthentication() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -337,6 +509,40 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(authorizationRequest.getAdditionalParameters()).isEmpty();
 	}
 
+	@Test
+	public void doFilterWhenProofKeyRequiredAndAuthorizationRequestValidThenAuthorizationResponse() throws Exception {
+		RegisteredClient registeredClient = createClientRequireProofKey();
+		when(this.registeredClientRepository.findByClientId((eq(registeredClient.getClientId()))))
+				.thenReturn(registeredClient);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		request = addPkceParameters(request);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verifyNoInteractions(filterChain);
+
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
+		assertThat(response.getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+
+		OAuth2Authorization authorization = authorizationCaptor.getValue();
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		assertThat(authorizationRequest.getClientId()).isEqualTo(registeredClient.getClientId());
+
+		assertThat(authorizationRequest.getAdditionalParameters())
+				.size()
+				.isEqualTo(2)
+				.returnToMap()
+				.containsEntry(PkceParameterNames.CODE_CHALLENGE, "code-challenge")
+				.containsEntry(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
+	}
+
 	private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
 			String parameterName, String errorCode) throws Exception {
 		doFilterWhenAuthorizationRequestInvalidParameterThenError(registeredClient, parameterName, errorCode, request -> {});
@@ -374,4 +580,30 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 		return request;
 	}
+
+	private static MockHttpServletRequest addPkceParameters(MockHttpServletRequest request) {
+		request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
+		request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
+
+		return request;
+	}
+
+	private RegisteredClient createClientRequireProofKey() {
+		ClientSettings clientSettings = new ClientSettings();
+		clientSettings.requireProofKey(true);
+
+		return TestRegisteredClients.registeredClient()
+				.clientSettings(clientSettings)
+				.build();
+	}
+
+	private RegisteredClient createClientDoNotRequireProofKey() {
+		ClientSettings clientSettings = new ClientSettings();
+		clientSettings.requireProofKey(false);
+
+		return TestRegisteredClients.registeredClient()
+				.clientSettings(clientSettings)
+				.build();
+	}
+
 }

+ 49 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java

@@ -34,6 +34,7 @@ import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -178,6 +179,12 @@ public class OAuth2TokenEndpointFilterTests {
 
 	@Test
 	public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
+		Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(clientPrincipal);
+		SecurityContextHolder.setContext(securityContext);
+
 		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
 				TestRegisteredClients.registeredClient().build());
 		request.removeParameter(OAuth2ParameterNames.CODE);
@@ -188,6 +195,12 @@ public class OAuth2TokenEndpointFilterTests {
 
 	@Test
 	public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
+		Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(clientPrincipal);
+		SecurityContextHolder.setContext(securityContext);
+
 		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
 				TestRegisteredClients.registeredClient().build());
 		request.addParameter(OAuth2ParameterNames.CODE, "code-2");
@@ -198,6 +211,12 @@ public class OAuth2TokenEndpointFilterTests {
 
 	@Test
 	public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
+		Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(clientPrincipal);
+		SecurityContextHolder.setContext(securityContext);
+
 		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
 				TestRegisteredClients.registeredClient().build());
 		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
@@ -206,6 +225,34 @@ public class OAuth2TokenEndpointFilterTests {
 				OAuth2ParameterNames.REDIRECT_URI, OAuth2ErrorCodes.INVALID_REQUEST, request);
 	}
 
+	@Test
+	public void doFilterWhenTokenRequestNotAuthenticatedAndMissingCodeVerifierThenInvalidRequestError() throws Exception {
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		SecurityContextHolder.setContext(securityContext);
+
+		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
+				TestRegisteredClients.registeredClient().build());
+		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com");
+
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request);
+	}
+
+	@Test
+	public void doFilterWhenTokenRequestNotAuthenticatedAndMultipleCodeVerifierThenInvalidRequestError() throws Exception {
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		SecurityContextHolder.setContext(securityContext);
+
+		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(
+				TestRegisteredClients.registeredClient().build());
+		request.addParameter(PkceParameterNames.CODE_VERIFIER, "one-verifier");
+		request.addParameter(PkceParameterNames.CODE_VERIFIER, "two-verifiers");
+		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example.com");
+
+		doFilterWhenTokenRequestInvalidParameterThenError(
+				PkceParameterNames.CODE_VERIFIER, OAuth2ErrorCodes.INVALID_REQUEST, request);
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationCodeTokenRequestValidThenAccessTokenResponse() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
@@ -359,6 +406,8 @@ public class OAuth2TokenEndpointFilterTests {
 		request.addParameter(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
 		request.addParameter(OAuth2ParameterNames.CODE, "code");
 		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, redirectUris[0]);
+		// The client does not need to send the client ID param, but we are resilient in case they do
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
 
 		return request;
 	}