Browse Source

Fix to ensure endpoints distinguish between form and query parameters

Closes gh-1451
Greg Li 1 year ago
parent
commit
4bc0df5ef8
15 changed files with 105 additions and 67 deletions
  1. 1 1
      docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java
  2. 12 12
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java
  3. 4 4
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java
  4. 2 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java
  5. 2 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java
  6. 3 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java
  7. 3 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java
  8. 27 8
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java
  9. 3 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java
  10. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java
  11. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java
  12. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java
  13. 15 14
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java
  14. 1 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java
  15. 29 10
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

+ 1 - 1
docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java

@@ -94,7 +94,7 @@ public class AuthorizationCodeGrantFlow {
 		parameters.set(OAuth2ParameterNames.STATE, "state");
 
 		MvcResult mvcResult = this.mockMvc.perform(get("/oauth2/authorize")
-				.params(parameters)
+				.queryParams(parameters)
 				.with(user(this.username).roles("USER")))
 				.andExpect(status().isOk())
 				.andExpect(header().string("content-type", containsString(MediaType.TEXT_HTML_VALUE)))

+ 12 - 12
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java

@@ -48,7 +48,18 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
 	@Nullable
 	@Override
 	public Authentication convert(HttpServletRequest request) {
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+		String queryString = request.getQueryString();
+		if (StringUtils.hasText(queryString) &&
+				(queryString.contains(OAuth2ParameterNames.CLIENT_ID) ||
+						queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) {
+			OAuth2Error error = new OAuth2Error(
+					OAuth2ErrorCodes.INVALID_REQUEST,
+					"Client credentials MUST NOT be included in the request URI.",
+					null);
+			throw new OAuth2AuthenticationException(error);
+		}
+
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
 
 		// client_id (REQUIRED)
 		String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
@@ -70,17 +81,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
 			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
 		}
 
-		String queryString = request.getQueryString();
-		if (StringUtils.hasText(queryString) &&
-				(queryString.contains(OAuth2ParameterNames.CLIENT_ID) ||
-						queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) {
-			OAuth2Error error = new OAuth2Error(
-					OAuth2ErrorCodes.INVALID_REQUEST,
-					"Client credentials MUST NOT be included in the request URI.",
-					null);
-			throw new OAuth2AuthenticationException(error);
-		}
-
 		Map<String, Object> additionalParameters = OAuth2EndpointUtils.getParametersIfMatchesAuthorizationCodeGrantRequest(request,
 				OAuth2ParameterNames.CLIENT_ID,
 				OAuth2ParameterNames.CLIENT_SECRET);

+ 4 - 4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java

@@ -48,13 +48,13 @@ public final class JwtClientAssertionAuthenticationConverter implements Authenti
 	@Nullable
 	@Override
 	public Authentication convert(HttpServletRequest request) {
-		if (request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null ||
-				request.getParameter(OAuth2ParameterNames.CLIENT_ASSERTION) == null) {
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
+
+		if (parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE) == null ||
+				parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION) == null) {
 			return null;
 		}
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
-
 		// client_assertion_type (REQUIRED)
 		String clientAssertionType = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE);
 		if (parameters.get(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE).size() != 1) {

+ 2 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java

@@ -47,16 +47,15 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut
 	@Nullable
 	@Override
 	public Authentication convert(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
 		// grant_type (REQUIRED)
-		String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
+		String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
 		if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {
 			return null;
 		}
 
 		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
-
 		// code (REQUIRED)
 		String code = parameters.getFirst(OAuth2ParameterNames.CODE);
 		if (!StringUtils.hasText(code) ||

+ 2 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java

@@ -66,10 +66,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme
 			return null;
 		}
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getQueryParameters(request);
 
 		// response_type (REQUIRED)
-		String responseType = request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE);
+		String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);
 		if (!StringUtils.hasText(responseType) ||
 				parameters.get(OAuth2ParameterNames.RESPONSE_TYPE).size() != 1) {
 			throwError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.RESPONSE_TYPE);

+ 3 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationConsentAuthenticationConverter.java

@@ -54,13 +54,13 @@ public final class OAuth2AuthorizationConsentAuthenticationConverter implements
 
 	@Override
 	public Authentication convert(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
+
 		if (!"POST".equals(request.getMethod()) ||
-				request.getParameter(OAuth2ParameterNames.RESPONSE_TYPE) != null) {
+				parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE) != null) {
 			return null;
 		}
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
-
 		String authorizationUri = request.getRequestURL().toString();
 
 		// client_id (REQUIRED)

+ 3 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ClientCredentialsAuthenticationConverter.java

@@ -50,16 +50,16 @@ public final class OAuth2ClientCredentialsAuthenticationConverter implements Aut
 	@Nullable
 	@Override
 	public Authentication convert(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
+
 		// grant_type (REQUIRED)
-		String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
+		String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
 		if (!AuthorizationGrantType.CLIENT_CREDENTIALS.getValue().equals(grantType)) {
 			return null;
 		}
 
 		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
-
 		// scope (OPTIONAL)
 		String scope = parameters.getFirst(OAuth2ParameterNames.SCOPE);
 		if (StringUtils.hasText(scope) &&

+ 27 - 8
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java

@@ -28,24 +28,41 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
+import org.springframework.util.StringUtils;
 
 /**
  * Utility methods for the OAuth 2.0 Protocol Endpoints.
  *
  * @author Joe Grandja
+ * @author Greg Li
  * @since 0.1.2
  */
 final class OAuth2EndpointUtils {
 	static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
-
 	private OAuth2EndpointUtils() {
 	}
 
-	static MultiValueMap<String, String> getParameters(HttpServletRequest request) {
+	static MultiValueMap<String, String> getFormParameters(HttpServletRequest request) {
+		Map<String, String[]> parameterMap = request.getParameterMap();
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
+		parameterMap.forEach((key, values) -> {
+			// If not query parameter then it's a form parameter
+			if ((!StringUtils.hasText(request.getQueryString()) && values.length > 0)
+					|| (!request.getQueryString().contains(key) && values.length > 0)) {
+				for (String value : values) {
+					parameters.add(key, value);
+				}
+			}
+		});
+		return parameters;
+	}
+
+	static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
 		Map<String, String[]> parameterMap = request.getParameterMap();
-		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>(parameterMap.size());
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameterMap.forEach((key, values) -> {
-			if (values.length > 0) {
+			if (StringUtils.hasText(request.getQueryString())
+					&& request.getQueryString().contains(key) && values.length > 0) {
 				for (String value : values) {
 					parameters.add(key, value);
 				}
@@ -58,7 +75,7 @@ final class OAuth2EndpointUtils {
 		if (!matchesAuthorizationCodeGrantRequest(request)) {
 			return Collections.emptyMap();
 		}
-		MultiValueMap<String, String> multiValueParameters = getParameters(request);
+		MultiValueMap<String, String> multiValueParameters = getFormParameters(request);
 		for (String exclusion : exclusions) {
 			multiValueParameters.remove(exclusion);
 		}
@@ -71,14 +88,16 @@ final class OAuth2EndpointUtils {
 	}
 
 	static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = getFormParameters(request);
 		return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(
-				request.getParameter(OAuth2ParameterNames.GRANT_TYPE)) &&
-				request.getParameter(OAuth2ParameterNames.CODE) != null;
+				parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) &&
+				parameters.getFirst(OAuth2ParameterNames.CODE) != null;
 	}
 
 	static boolean matchesPkceTokenRequest(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = getFormParameters(request);
 		return matchesAuthorizationCodeGrantRequest(request) &&
-				request.getParameter(PkceParameterNames.CODE_VERIFIER) != null;
+				parameters.getFirst(PkceParameterNames.CODE_VERIFIER) != null;
 	}
 
 	static void throwError(String errorCode, String parameterName, String errorUri) {

+ 3 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2RefreshTokenAuthenticationConverter.java

@@ -50,16 +50,16 @@ public final class OAuth2RefreshTokenAuthenticationConverter implements Authenti
 	@Nullable
 	@Override
 	public Authentication convert(HttpServletRequest request) {
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
+
 		// grant_type (REQUIRED)
-		String grantType = request.getParameter(OAuth2ParameterNames.GRANT_TYPE);
+		String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
 		if (!AuthorizationGrantType.REFRESH_TOKEN.getValue().equals(grantType)) {
 			return null;
 		}
 
 		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
-
 		// refresh_token (REQUIRED)
 		String refreshToken = parameters.getFirst(OAuth2ParameterNames.REFRESH_TOKEN);
 		if (!StringUtils.hasText(refreshToken) ||

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenIntrospectionAuthenticationConverter.java

@@ -49,7 +49,7 @@ public final class OAuth2TokenIntrospectionAuthenticationConverter implements Au
 	public Authentication convert(HttpServletRequest request) {
 		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
 
 		// token (REQUIRED)
 		String token = parameters.getFirst(OAuth2ParameterNames.TOKEN);

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java

@@ -46,7 +46,7 @@ public final class OAuth2TokenRevocationAuthenticationConverter implements Authe
 	public Authentication convert(HttpServletRequest request) {
 		Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
 
 		// token (REQUIRED)
 		String token = parameters.getFirst(OAuth2ParameterNames.TOKEN);

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java

@@ -53,7 +53,7 @@ public final class PublicClientAuthenticationConverter implements Authentication
 			return null;
 		}
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
 
 		// client_id (REQUIRED for public clients)
 		String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);

+ 15 - 14
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@@ -153,6 +153,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  * @author Daniel Garnier-Moiroux
  * @author Dmitriy Dubson
  * @author Steve Riesenberg
+ * @author Greg Li
  */
 @ExtendWith(SpringTestContextExtension.class)
 public class OAuth2AuthorizationCodeGrantTests {
@@ -255,7 +256,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		this.registeredClientRepository.save(registeredClient);
 
 		this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(getAuthorizationRequestParameters(registeredClient)))
+				.queryParams(getAuthorizationRequestParameters(registeredClient)))
 				.andExpect(status().isUnauthorized())
 				.andReturn();
 	}
@@ -297,7 +298,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
 		MvcResult mvcResult = this.mvc.perform(get(authorizationEndpointUri)
-				.params(authorizationRequestParameters)
+				.queryParams(authorizationRequestParameters)
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
@@ -389,9 +390,9 @@ public class OAuth2AuthorizationCodeGrantTests {
 		this.registeredClientRepository.save(registeredClient);
 
 		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(getAuthorizationRequestParameters(registeredClient))
-				.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
-				.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
+				.queryParams(getAuthorizationRequestParameters(registeredClient))
+				.queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
+				.queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
@@ -434,9 +435,9 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
 		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(authorizationRequestParameters)
-				.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
-				.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
+				.queryParams(authorizationRequestParameters)
+				.queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
+				.queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
@@ -473,7 +474,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
 		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(authorizationRequestParameters)
+				.queryParams(authorizationRequestParameters)
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
@@ -519,7 +520,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		this.registeredClientRepository.save(registeredClient);
 
 		String consentPage = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(getAuthorizationRequestParameters(registeredClient))
+				.queryParams(getAuthorizationRequestParameters(registeredClient))
 				.with(user("user")))
 				.andExpect(status().is2xxSuccessful())
 				.andReturn()
@@ -602,7 +603,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		this.registeredClientRepository.save(registeredClient);
 
 		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(getAuthorizationRequestParameters(registeredClient))
+				.queryParams(getAuthorizationRequestParameters(registeredClient))
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
@@ -737,9 +738,9 @@ public class OAuth2AuthorizationCodeGrantTests {
 		this.registeredClientRepository.save(registeredClient);
 
 		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(getAuthorizationRequestParameters(registeredClient))
-				.param(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
-				.param(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
+				.queryParams(getAuthorizationRequestParameters(registeredClient))
+				.queryParam(PkceParameterNames.CODE_CHALLENGE, S256_CODE_CHALLENGE)
+				.queryParam(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256")
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();

+ 1 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java

@@ -184,7 +184,7 @@ public class OidcTests {
 
 		MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
 		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
-				.params(authorizationRequestParameters)
+				.queryParams(authorizationRequestParameters)
 				.with(user("user").roles("A", "B")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();

+ 29 - 10
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -37,6 +37,7 @@ import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -58,6 +59,7 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetails;
+import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
 import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -78,6 +80,7 @@ import static org.mockito.Mockito.when;
  * @author Daniel Garnier-Moiroux
  * @author Anoop Garlapati
  * @author Dmitriy Dubson
+ * @author Greg Li
  * @since 0.0.1
  */
 public class OAuth2AuthorizationEndpointFilterTests {
@@ -263,6 +266,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				request -> {
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge");
+					String originalQueryString = request.getQueryString();
+					if (StringUtils.hasText(originalQueryString)) {
+						String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE)
+								.concat("=code-challenge").concat("&")
+								.concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge");
+						request.setQueryString(newQueryString);
+					}
 				});
 	}
 
@@ -275,6 +285,13 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				request -> {
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
+					String originalQueryString = request.getQueryString();
+					if (StringUtils.hasText(originalQueryString)) {
+						String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD)
+								.concat("=S256").concat("&")
+								.concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256");
+						request.setQueryString(newQueryString);
+					}
 				});
 	}
 
@@ -557,6 +574,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 		request.addParameter("custom-param", "custom-value-1", "custom-value-2");
+		String newQueryString = request.getQueryString().concat("custom-param")
+				.concat("=custom-value-1").concat("&")
+				.concat("custom-param").concat("=custom-value-2");
+		request.setQueryString(newQueryString);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
@@ -646,17 +667,15 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 	private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
 		String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
-		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
-		request.setServletPath(requestUri);
+		MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri)
+				.queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue())
+				.queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
+				.queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next())
+				.queryParam(OAuth2ParameterNames.SCOPE,
+						StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "))
+				.queryParam(OAuth2ParameterNames.STATE, "state")
+				.buildRequest(new MockServletContext());
 		request.setRemoteAddr(REMOTE_ADDRESS);
-
-		request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
-		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
-		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
-		request.addParameter(OAuth2ParameterNames.SCOPE,
-				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
-		request.addParameter(OAuth2ParameterNames.STATE, "state");
-
 		return request;
 	}