Joe Grandja пре 1 година
родитељ
комит
f0a6a4c0bf
13 измењених фајлова са 128 додато и 125 уклоњено
  1. 1 1
      docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java
  2. 23 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java
  3. 0 12
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java
  4. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java
  5. 1 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java
  6. 4 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java
  7. 12 10
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java
  8. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java
  9. 4 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java
  10. 0 33
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java
  11. 21 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java
  12. 60 36
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java
  13. 0 25
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java

+ 1 - 1
docs/src/docs/asciidoc/examples/src/test/java/sample/AuthorizationCodeGrantFlow.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

+ 23 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/web/authentication/OidcClientRegistrationAuthenticationConverter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
  */
 package org.springframework.security.oauth2.server.authorization.oidc.web.authentication;
 
+import java.util.Map;
+
 import javax.servlet.http.HttpServletRequest;
 
 import org.springframework.http.converter.HttpMessageConverter;
@@ -30,6 +32,8 @@ import org.springframework.security.oauth2.server.authorization.oidc.authenticat
 import org.springframework.security.oauth2.server.authorization.oidc.http.converter.OidcClientRegistrationHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.oidc.web.OidcClientRegistrationEndpointFilter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
 
 /**
@@ -65,14 +69,30 @@ public final class OidcClientRegistrationAuthenticationConverter implements Auth
 			return new OidcClientRegistrationAuthenticationToken(principal, clientRegistration);
 		}
 
+		MultiValueMap<String, String> parameters = getQueryParameters(request);
+
 		// client_id (REQUIRED)
-		String clientId = request.getParameter(OAuth2ParameterNames.CLIENT_ID);
+		String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);
 		if (!StringUtils.hasText(clientId) ||
-				request.getParameterValues(OAuth2ParameterNames.CLIENT_ID).length != 1) {
+				parameters.get(OAuth2ParameterNames.CLIENT_ID).size() != 1) {
 			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_REQUEST);
 		}
 
 		return new OidcClientRegistrationAuthenticationToken(principal, clientId);
 	}
 
+	private static MultiValueMap<String, String> getQueryParameters(HttpServletRequest request) {
+		Map<String, String[]> parameterMap = request.getParameterMap();
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
+		parameterMap.forEach((key, values) -> {
+			String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
+			if (queryString.contains(key) && values.length > 0) {
+				for (String value : values) {
+					parameters.add(key, value);
+				}
+			}
+		});
+		return parameters;
+	}
+
 }

+ 0 - 12
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverter.java

@@ -23,7 +23,6 @@ import org.springframework.lang.Nullable;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
-import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
@@ -48,17 +47,6 @@ public final class ClientSecretPostAuthenticationConverter implements Authentica
 	@Nullable
 	@Override
 	public Authentication convert(HttpServletRequest request) {
-		String queryString = request.getQueryString();
-		if (StringUtils.hasText(queryString) &&
-				(queryString.contains(OAuth2ParameterNames.CLIENT_ID) ||
-						queryString.contains(OAuth2ParameterNames.CLIENT_SECRET))) {
-			OAuth2Error error = new OAuth2Error(
-					OAuth2ErrorCodes.INVALID_REQUEST,
-					"Client credentials MUST NOT be included in the request URI.",
-					null);
-			throw new OAuth2AuthenticationException(error);
-		}
-
 		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
 
 		// client_id (REQUIRED)

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/JwtClientAssertionAuthenticationConverter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

+ 1 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeAuthenticationConverter.java

@@ -48,6 +48,7 @@ public final class OAuth2AuthorizationCodeAuthenticationConverter implements Aut
 	@Override
 	public Authentication convert(HttpServletRequest request) {
 		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
+
 		// grant_type (REQUIRED)
 		String grantType = parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE);
 		if (!AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(grantType)) {

+ 4 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2AuthorizationCodeRequestAuthenticationConverter.java

@@ -66,7 +66,10 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationConverter impleme
 			return null;
 		}
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getQueryParameters(request);
+		MultiValueMap<String, String> parameters =
+				"GET".equals(request.getMethod()) ?
+						OAuth2EndpointUtils.getQueryParameters(request) :
+						OAuth2EndpointUtils.getFormParameters(request);
 
 		// response_type (REQUIRED)
 		String responseType = parameters.getFirst(OAuth2ParameterNames.RESPONSE_TYPE);

+ 12 - 10
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2EndpointUtils.java

@@ -39,6 +39,7 @@ import org.springframework.util.StringUtils;
  */
 final class OAuth2EndpointUtils {
 	static final String ACCESS_TOKEN_REQUEST_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
+
 	private OAuth2EndpointUtils() {
 	}
 
@@ -46,9 +47,9 @@ final class OAuth2EndpointUtils {
 		Map<String, String[]> parameterMap = request.getParameterMap();
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameterMap.forEach((key, values) -> {
+			String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
 			// If not query parameter then it's a form parameter
-			if ((!StringUtils.hasText(request.getQueryString()) && values.length > 0)
-					|| (!request.getQueryString().contains(key) && values.length > 0)) {
+			if (!queryString.contains(key) && values.length > 0) {
 				for (String value : values) {
 					parameters.add(key, value);
 				}
@@ -61,8 +62,8 @@ final class OAuth2EndpointUtils {
 		Map<String, String[]> parameterMap = request.getParameterMap();
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameterMap.forEach((key, values) -> {
-			if (StringUtils.hasText(request.getQueryString())
-					&& request.getQueryString().contains(key) && values.length > 0) {
+			String queryString = StringUtils.hasText(request.getQueryString()) ? request.getQueryString() : "";
+			if (queryString.contains(key) && values.length > 0) {
 				for (String value : values) {
 					parameters.add(key, value);
 				}
@@ -75,7 +76,10 @@ final class OAuth2EndpointUtils {
 		if (!matchesAuthorizationCodeGrantRequest(request)) {
 			return Collections.emptyMap();
 		}
-		MultiValueMap<String, String> multiValueParameters = getFormParameters(request);
+		MultiValueMap<String, String> multiValueParameters =
+				"GET".equals(request.getMethod()) ?
+						getQueryParameters(request) :
+						getFormParameters(request);
 		for (String exclusion : exclusions) {
 			multiValueParameters.remove(exclusion);
 		}
@@ -88,16 +92,14 @@ final class OAuth2EndpointUtils {
 	}
 
 	static boolean matchesAuthorizationCodeGrantRequest(HttpServletRequest request) {
-		MultiValueMap<String, String> parameters = getFormParameters(request);
 		return AuthorizationGrantType.AUTHORIZATION_CODE.getValue().equals(
-				parameters.getFirst(OAuth2ParameterNames.GRANT_TYPE)) &&
-				parameters.getFirst(OAuth2ParameterNames.CODE) != null;
+				request.getParameter(OAuth2ParameterNames.GRANT_TYPE)) &&
+				request.getParameter(OAuth2ParameterNames.CODE) != null;
 	}
 
 	static boolean matchesPkceTokenRequest(HttpServletRequest request) {
-		MultiValueMap<String, String> parameters = getFormParameters(request);
 		return matchesAuthorizationCodeGrantRequest(request) &&
-				parameters.getFirst(PkceParameterNames.CODE_VERIFIER) != null;
+				request.getParameter(PkceParameterNames.CODE_VERIFIER) != null;
 	}
 
 	static void throwError(String errorCode, String parameterName, String errorUri) {

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2TokenRevocationAuthenticationConverter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

+ 4 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/PublicClientAuthenticationConverter.java

@@ -53,7 +53,10 @@ public final class PublicClientAuthenticationConverter implements Authentication
 			return null;
 		}
 
-		MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getFormParameters(request);
+		MultiValueMap<String, String> parameters =
+				"GET".equals(request.getMethod()) ?
+						OAuth2EndpointUtils.getQueryParameters(request) :
+						OAuth2EndpointUtils.getFormParameters(request);
 
 		// client_id (REQUIRED for public clients)
 		String clientId = parameters.getFirst(OAuth2ParameterNames.CLIENT_ID);

+ 0 - 33
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2ClientCredentialsGrantTests.java

@@ -59,7 +59,6 @@ import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.jose.TestJwks;
 import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
@@ -98,7 +97,6 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
-import org.springframework.web.util.UriComponentsBuilder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
@@ -232,37 +230,6 @@ public class OAuth2ClientCredentialsGrantTests {
 		verify(jwtCustomizer).customize(any());
 	}
 
-	// gh-1378
-	@Test
-	public void requestWhenTokenRequestWithClientCredentialsInQueryParamThenInvalidRequest() throws Exception {
-		this.spring.register(AuthorizationServerConfiguration.class).autowire();
-
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
-		this.registeredClientRepository.save(registeredClient);
-
-		String tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI)
-				.queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
-				.toUriString();
-
-		this.mvc.perform(post(tokenEndpointUri)
-						.param(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret())
-						.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
-						.param(OAuth2ParameterNames.SCOPE, "scope1 scope2"))
-				.andExpect(status().isBadRequest())
-				.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST));
-
-		tokenEndpointUri = UriComponentsBuilder.fromUriString(DEFAULT_TOKEN_ENDPOINT_URI)
-				.queryParam(OAuth2ParameterNames.CLIENT_SECRET, registeredClient.getClientSecret())
-				.toUriString();
-
-		this.mvc.perform(post(tokenEndpointUri)
-						.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
-						.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
-						.param(OAuth2ParameterNames.SCOPE, "scope1 scope2"))
-				.andExpect(status().isBadRequest())
-				.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_REQUEST));
-	}
-
 	@Test
 	public void requestWhenTokenEndpointCustomizedThenUsed() throws Exception {
 		this.spring.register(AuthorizationServerConfigurationCustomTokenEndpoint.class).autowire();

+ 21 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/web/OidcClientRegistrationEndpointFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -61,6 +61,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
+import org.springframework.web.util.UriComponentsBuilder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -327,6 +328,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "");
+		updateQueryString(request);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
@@ -342,6 +344,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 		request.setServletPath(requestUri);
 		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id");
 		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-id2");
+		updateQueryString(request);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
@@ -388,6 +391,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1");
+		updateQueryString(request);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
@@ -421,6 +425,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId());
+		updateQueryString(request);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
@@ -463,6 +468,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client-id");
+		updateQueryString(request);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
@@ -492,6 +498,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.setParameter(OAuth2ParameterNames.CLIENT_ID, expectedClientRegistrationResponse.getClientId());
+		updateQueryString(request);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
@@ -513,6 +520,7 @@ public class OidcClientRegistrationEndpointFilterTests {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		request.setParameter(OAuth2ParameterNames.CLIENT_ID, "client1");
+		updateQueryString(request);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
@@ -522,6 +530,18 @@ public class OidcClientRegistrationEndpointFilterTests {
 				any(OAuth2AuthenticationException.class));
 	}
 
+	private static void updateQueryString(MockHttpServletRequest request) {
+		UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
+		request.getParameterMap().forEach((key, values) -> {
+			if (values.length > 0) {
+				for (String value : values) {
+					uriBuilder.queryParam(key, value);
+				}
+			}
+		});
+		request.setQueryString(uriBuilder.build().getQuery());
+	}
+
 	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
 		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
 				response.getContentAsByteArray(), HttpStatus.valueOf(response.getStatus()));

+ 60 - 36
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -37,7 +37,6 @@ import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
-import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -59,8 +58,8 @@ import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetails;
-import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
 import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponentsBuilder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -173,7 +172,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.RESPONSE_TYPE,
 				OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE));
+				request -> {
+					request.removeParameter(OAuth2ParameterNames.RESPONSE_TYPE);
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -182,7 +184,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.RESPONSE_TYPE,
 				OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
+				request -> {
+					request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -191,7 +196,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.RESPONSE_TYPE,
 				OAuth2ErrorCodes.UNSUPPORTED_RESPONSE_TYPE,
-				request -> request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token"));
+				request -> {
+					request.setParameter(OAuth2ParameterNames.RESPONSE_TYPE, "id_token");
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -200,7 +208,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.CLIENT_ID,
 				OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.removeParameter(OAuth2ParameterNames.CLIENT_ID));
+				request -> {
+					request.removeParameter(OAuth2ParameterNames.CLIENT_ID);
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -209,7 +220,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.CLIENT_ID,
 				OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2"));
+				request -> {
+					request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-2");
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -218,7 +232,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.REDIRECT_URI,
 				OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com"));
+				request -> {
+					request.addParameter(OAuth2ParameterNames.REDIRECT_URI, "https://example2.com");
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -227,7 +244,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.SCOPE,
 				OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.addParameter(OAuth2ParameterNames.SCOPE, "scope2"));
+				request -> {
+					request.addParameter(OAuth2ParameterNames.SCOPE, "scope2");
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -236,7 +256,10 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				TestRegisteredClients.registeredClient().build(),
 				OAuth2ParameterNames.STATE,
 				OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.addParameter(OAuth2ParameterNames.STATE, "state2"));
+				request -> {
+					request.addParameter(OAuth2ParameterNames.STATE, "state2");
+					updateQueryString(request);
+				});
 	}
 
 	@Test
@@ -266,13 +289,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				request -> {
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE, "code-challenge");
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE, "another-code-challenge");
-					String originalQueryString = request.getQueryString();
-					if (StringUtils.hasText(originalQueryString)) {
-						String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE)
-								.concat("=code-challenge").concat("&")
-								.concat(PkceParameterNames.CODE_CHALLENGE).concat("=another-code-challenge");
-						request.setQueryString(newQueryString);
-					}
+					updateQueryString(request);
 				});
 	}
 
@@ -285,13 +302,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				request -> {
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
 					request.addParameter(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
-					String originalQueryString = request.getQueryString();
-					if (StringUtils.hasText(originalQueryString)) {
-						String newQueryString = originalQueryString.concat(PkceParameterNames.CODE_CHALLENGE_METHOD)
-								.concat("=S256").concat("&")
-								.concat(PkceParameterNames.CODE_CHALLENGE_METHOD).concat("=S256");
-						request.setQueryString(newQueryString);
-					}
+					updateQueryString(request);
 				});
 	}
 
@@ -574,10 +585,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 		request.addParameter("custom-param", "custom-value-1", "custom-value-2");
-		String newQueryString = request.getQueryString().concat("custom-param")
-				.concat("=custom-value-1").concat("&")
-				.concat("custom-param").concat("=custom-value-2");
-		request.setQueryString(newQueryString);
+		updateQueryString(request);
 
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
@@ -623,6 +631,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
 		request.setMethod("POST");	// OpenID Connect supports POST method
+		request.setQueryString(null);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
@@ -667,15 +676,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
 
 	private static MockHttpServletRequest createAuthorizationRequest(RegisteredClient registeredClient) {
 		String requestUri = DEFAULT_AUTHORIZATION_ENDPOINT_URI;
-		MockHttpServletRequest request = MockMvcRequestBuilders.get(requestUri)
-				.queryParam(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue())
-				.queryParam(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
-				.queryParam(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next())
-				.queryParam(OAuth2ParameterNames.SCOPE,
-						StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "))
-				.queryParam(OAuth2ParameterNames.STATE, "state")
-				.buildRequest(new MockServletContext());
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
 		request.setRemoteAddr(REMOTE_ADDRESS);
+
+		request.addParameter(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
+		request.addParameter(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId());
+		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
+		request.addParameter(OAuth2ParameterNames.SCOPE,
+				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
+		updateQueryString(request);
+
 		return request;
 	}
 
@@ -692,6 +704,18 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		return request;
 	}
 
+	private static void updateQueryString(MockHttpServletRequest request) {
+		UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(request.getRequestURI());
+		request.getParameterMap().forEach((key, values) -> {
+			if (values.length > 0) {
+				for (String value : values) {
+					uriBuilder.queryParam(key, value);
+				}
+			}
+		});
+		request.setQueryString(uriBuilder.build().getQuery());
+	}
+
 	private static String scopeCheckbox(String scope) {
 		return MessageFormat.format(
 				"<input class=\"form-check-input\" type=\"checkbox\" name=\"scope\" value=\"{0}\" id=\"{0}\">",

+ 0 - 25
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/ClientSecretPostAuthenticationConverterTests.java

@@ -79,31 +79,6 @@ public class ClientSecretPostAuthenticationConverterTests {
 				.isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
 	}
 
-	// gh-1378
-	@Test
-	public void convertWhenClientCredentialsInQueryParamThenInvalidRequestError() {
-		MockHttpServletRequest request = new MockHttpServletRequest();
-		request.addParameter(OAuth2ParameterNames.CLIENT_ID, "client-1");
-		request.addParameter(OAuth2ParameterNames.CLIENT_SECRET, "client-secret");
-		request.setQueryString("client_id=client-1");
-		assertThatThrownBy(() -> this.converter.convert(request))
-				.isInstanceOf(OAuth2AuthenticationException.class)
-				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.satisfies(error -> {
-					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-					assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI.");
-				});
-
-		request.setQueryString("client_secret=client-secret");
-		assertThatThrownBy(() -> this.converter.convert(request))
-				.isInstanceOf(OAuth2AuthenticationException.class)
-				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.satisfies(error -> {
-					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.INVALID_REQUEST);
-					assertThat(error.getDescription()).isEqualTo("Client credentials MUST NOT be included in the request URI.");
-				});
-	}
-
 	@Test
 	public void convertWhenPostWithValidCredentialsThenReturnClientAuthenticationToken() {
 		MockHttpServletRequest request = new MockHttpServletRequest();