Joe Grandja преди 2 години
родител
ревизия
4eb25c163f

+ 16 - 11
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -20,6 +20,7 @@ import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 
 import javax.servlet.FilterChain;
@@ -252,13 +253,11 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		String state = authorizationConsentAuthentication.getState();
 
 		if (hasConsentUri()) {
-			UriComponentsBuilder uriBuilder = UriComponentsBuilder.fromUriString(resolveConsentUri(request))
+			String redirectUri = UriComponentsBuilder.fromUriString(resolveConsentUri(request))
 					.queryParam(OAuth2ParameterNames.SCOPE, String.join(" ", requestedScopes))
 					.queryParam(OAuth2ParameterNames.CLIENT_ID, clientId)
-					.queryParam(OAuth2ParameterNames.STATE, "{state}");
-			HashMap<String, String> queryParameters = new HashMap<>(1);
-			queryParameters.put(OAuth2ParameterNames.STATE, state);
-			String redirectUri = uriBuilder.build(queryParameters).toString();
+					.queryParam(OAuth2ParameterNames.STATE, state)
+					.toUriString();
 			this.redirectStrategy.sendRedirect(request, response, redirectUri);
 		} else {
 			DefaultConsentPage.displayConsent(request, response, clientId, principal, requestedScopes, authorizedScopes, state);
@@ -290,12 +289,15 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 				.fromUriString(authorizationCodeRequestAuthentication.getRedirectUri())
 				.queryParam(OAuth2ParameterNames.CODE, authorizationCodeRequestAuthentication.getAuthorizationCode().getTokenValue());
+		String redirectUri;
 		if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
 			uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}");
+			Map<String, String> queryParams = new HashMap<>();
+			queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
+			redirectUri = uriBuilder.build(queryParams).toString();
+		} else {
+			redirectUri = uriBuilder.toUriString();
 		}
-		HashMap<String, String> queryParams = new HashMap<>();
-		queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
-		String redirectUri = uriBuilder.build(queryParams).toString();
 		this.redirectStrategy.sendRedirect(request, response, redirectUri);
 	}
 
@@ -323,12 +325,15 @@ public final class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilte
 		if (StringUtils.hasText(error.getUri())) {
 			uriBuilder.queryParam(OAuth2ParameterNames.ERROR_URI, error.getUri());
 		}
+		String redirectUri;
 		if (StringUtils.hasText(authorizationCodeRequestAuthentication.getState())) {
 			uriBuilder.queryParam(OAuth2ParameterNames.STATE, "{state}");
+			Map<String, String> queryParams = new HashMap<>();
+			queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
+			redirectUri = uriBuilder.build(queryParams).toString();
+		} else {
+			redirectUri = uriBuilder.toUriString();
 		}
-		HashMap<String, String> queryParams = new HashMap<>();
-		queryParams.put(OAuth2ParameterNames.STATE, authorizationCodeRequestAuthentication.getState());
-		String redirectUri = uriBuilder.build(queryParams).toString();
 		this.redirectStrategy.sendRedirect(request, response, redirectUri);
 	}
 

+ 9 - 18
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -39,53 +39,44 @@ import org.springframework.util.CollectionUtils;
 public class TestOAuth2Authorizations {
 
 	public static OAuth2Authorization.Builder authorization() {
-		return authorization(TestRegisteredClients.registeredClient().build(), "state");
-	}
-
-	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient, String state) {
-		return authorization(registeredClient, Collections.emptyMap(), state);
+		return authorization(TestRegisteredClients.registeredClient().build());
 	}
 
 	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient) {
-		return authorization(registeredClient, Collections.emptyMap(), "state");
+		return authorization(registeredClient, Collections.emptyMap());
 	}
 
 	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
-			Map<String, Object> authorizationRequestAdditionalParameters, String state) {
+			Map<String, Object> authorizationRequestAdditionalParameters) {
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 				"code", Instant.now(), Instant.now().plusSeconds(120));
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(
 				OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
-		return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters, state);
-	}
-
-	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
-			Map<String, Object> authorizationRequestAdditionalParameters) {
-		return authorization(registeredClient, authorizationRequestAdditionalParameters, "state");
+		return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters);
 	}
 
 	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
 			OAuth2AuthorizationCode authorizationCode) {
-		return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap(), "state");
+		return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap());
 	}
 
 	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
 			OAuth2AccessToken accessToken, Map<String, Object> accessTokenClaims) {
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 				"code", Instant.now(), Instant.now().plusSeconds(120));
-		return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap(), "state");
+		return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap());
 	}
 
 	private static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
 			OAuth2AuthorizationCode authorizationCode, OAuth2AccessToken accessToken,
-			Map<String, Object> accessTokenClaims, Map<String, Object> authorizationRequestAdditionalParameters, String state) {
+			Map<String, Object> accessTokenClaims, Map<String, Object> authorizationRequestAdditionalParameters) {
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
 				.authorizationUri("https://provider.com/oauth2/authorize")
 				.clientId(registeredClient.getClientId())
 				.redirectUri(registeredClient.getRedirectUris().iterator().next())
 				.scopes(registeredClient.getScopes())
 				.additionalParameters(authorizationRequestAdditionalParameters)
-				.state(state)
+				.state("state")
 				.build();
 		OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.id("id")
@@ -93,7 +84,7 @@ public class TestOAuth2Authorizations {
 				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 				.authorizedScopes(authorizationRequest.getScopes())
 				.token(authorizationCode)
-				.attribute(OAuth2ParameterNames.STATE, state)
+				.attribute(OAuth2ParameterNames.STATE, "consent-state")
 				.attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
 				.attribute(Principal.class.getName(),
 						new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"));

+ 22 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@@ -69,6 +69,7 @@ import org.springframework.security.crypto.password.PasswordEncoder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
@@ -293,7 +294,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
 		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
-		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state="+STATE_URL_ENCODED);
+		assertThat(redirectedUrl).matches("https://example.com\\?code=.{15,}&state=" + STATE_URL_ENCODED);
 
 		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
 		OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
@@ -502,9 +503,16 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.build();
 		this.registeredClientRepository.save(registeredClient);
 
-		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, STATE_URL_UNENCODED)
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName("user")
-				.attribute(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED)
+				.build();
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationRequest updatedAuthorizationRequest =
+				OAuth2AuthorizationRequest.from(authorizationRequest)
+						.state(STATE_URL_UNENCODED)
+						.build();
+		authorization = OAuth2Authorization.from(authorization)
+				.attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest)
 				.build();
 		this.authorizationService.save(authorization);
 
@@ -512,7 +520,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.param(OAuth2ParameterNames.SCOPE, "message.read")
 				.param(OAuth2ParameterNames.SCOPE, "message.write")
-				.param(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED)
+				.param(OAuth2ParameterNames.STATE, authorization.<String>getAttribute(OAuth2ParameterNames.STATE))
 				.with(user("user")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
@@ -584,14 +592,22 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.build();
 		this.registeredClientRepository.save(registeredClient);
 
-		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, STATE_URL_UNENCODED)
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
+				.build();
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationRequest updatedAuthorizationRequest =
+				OAuth2AuthorizationRequest.from(authorizationRequest)
+						.state(STATE_URL_UNENCODED)
+						.build();
+		authorization = OAuth2Authorization.from(authorization)
+				.attribute(OAuth2AuthorizationRequest.class.getName(), updatedAuthorizationRequest)
 				.build();
 		this.authorizationService.save(authorization);
 
 		MvcResult mvcResult = this.mvc.perform(post(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
 				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
 				.param("authority", "authority-1 authority-2")
-				.param(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED)
+				.param(OAuth2ParameterNames.STATE, authorization.<String>getAttribute(OAuth2ParameterNames.STATE))
 				.with(user("principal")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();

+ 9 - 12
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -85,9 +85,6 @@ public class OAuth2AuthorizationEndpointFilterTests {
 	private static final String AUTHORIZATION_URI = "https://provider.com/oauth2/authorize";
 	private static final String STATE = "state";
 	private static final String REMOTE_ADDRESS = "remote-address";
-	private static final String STATE_URL_UNENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ+004pwm9j55li7BoydXYysH4enZMF21Q";
-	private static final String STATE_URL_ENCODED = "awrD0fCnEcTUPFgmyy2SU89HZNcnAJ60ZW6l39YI0KyVjmIZ%2B004pwm9j55li7BoydXYysH4enZMF21Q";
-
 	private AuthenticationManager authenticationManager;
 	private OAuth2AuthorizationEndpointFilter filter;
 	private TestingAuthenticationToken principal;
@@ -287,7 +284,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
 				new OAuth2AuthorizationCodeRequestAuthenticationToken(
 						AUTHORIZATION_URI, registeredClient.getClientId(), principal,
-						registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes(), null);
+						registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes(), null);
 		OAuth2Error error = new OAuth2Error("errorCode", "errorDescription", "errorUri");
 		when(this.authenticationManager.authenticate(any()))
 				.thenThrow(new OAuth2AuthorizationCodeRequestAuthenticationException(error, authorizationCodeRequestAuthentication));
@@ -302,7 +299,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		verifyNoInteractions(filterChain);
 
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
-		assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?error=errorCode&error_description=errorDescription&error_uri=errorUri&state=" + STATE_URL_ENCODED);
+		assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?error=errorCode&error_description=errorDescription&error_uri=errorUri&state=state");
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.principal);
 	}
 
@@ -446,7 +443,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2AuthorizationConsentAuthenticationToken authorizationConsentAuthenticationResult =
 				new OAuth2AuthorizationConsentAuthenticationToken(
 						AUTHORIZATION_URI, registeredClient.getClientId(), principal,
-						STATE_URL_UNENCODED, new HashSet<>(), null);	// No scopes previously approved
+						STATE, new HashSet<>(), null);	// No scopes previously approved
 		authorizationConsentAuthenticationResult.setAuthenticated(true);
 		when(this.authenticationManager.authenticate(any()))
 				.thenReturn(authorizationConsentAuthenticationResult);
@@ -462,7 +459,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		verifyNoInteractions(filterChain);
 
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
-		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/oauth2/custom-consent?scope=scope1%20scope2&client_id=client-1&state=" + STATE_URL_ENCODED);
+		assertThat(response.getRedirectedUrl()).isEqualTo("http://localhost/oauth2/custom-consent?scope=scope1%20scope2&client_id=client-1&state=state");
 	}
 
 	@Test
@@ -542,7 +539,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
 				new OAuth2AuthorizationCodeRequestAuthenticationToken(
 						AUTHORIZATION_URI, registeredClient.getClientId(), principal, this.authorizationCode,
-						registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes());
+						registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
 		authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
 		when(this.authenticationManager.authenticate(any()))
 				.thenReturn(authorizationCodeRequestAuthenticationResult);
@@ -563,7 +560,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				.extracting(WebAuthenticationDetails::getRemoteAddress)
 				.isEqualTo(REMOTE_ADDRESS);
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
-		assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=" + STATE_URL_ENCODED);
+		assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state");
 	}
 
 	@Test
@@ -578,7 +575,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
 				new OAuth2AuthorizationCodeRequestAuthenticationToken(
 						AUTHORIZATION_URI, registeredClient.getClientId(), principal, this.authorizationCode,
-						registeredClient.getRedirectUris().iterator().next(), STATE_URL_UNENCODED, registeredClient.getScopes());
+						registeredClient.getRedirectUris().iterator().next(), STATE, registeredClient.getScopes());
 		authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
 		when(this.authenticationManager.authenticate(any()))
 				.thenReturn(authorizationCodeRequestAuthenticationResult);
@@ -594,7 +591,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		verifyNoInteractions(filterChain);
 
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.FOUND.value());
-		assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=" + STATE_URL_ENCODED);
+		assertThat(response.getRedirectedUrl()).isEqualTo("https://example.com?code=code&state=state");
 	}
 
 	private void doFilterWhenAuthorizationRequestInvalidParameterThenError(RegisteredClient registeredClient,
@@ -637,7 +634,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		request.addParameter(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
 		request.addParameter(OAuth2ParameterNames.SCOPE,
 				StringUtils.collectionToDelimitedString(registeredClient.getScopes(), " "));
-		request.addParameter(OAuth2ParameterNames.STATE, STATE_URL_UNENCODED);
+		request.addParameter(OAuth2ParameterNames.STATE, "state");
 
 		return request;
 	}