瀏覽代碼

Remove OAuth2AuthorizationAttributeNames.STATE

Issue gh-213
Joe Grandja 4 年之前
父節點
當前提交
cee5aacc15

+ 3 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap;
 import org.springframework.lang.Nullable;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.util.Assert;
 
@@ -72,7 +73,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 					matchesAuthorizationCode(authorization, token) ||
 					matchesAccessToken(authorization, token) ||
 					matchesRefreshToken(authorization, token);
-		} else if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) {
+		} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
 			return matchesState(authorization, token);
 		} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
 			return matchesAuthorizationCode(authorization, token);
@@ -85,7 +86,7 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 	}
 
 	private static boolean matchesState(OAuth2Authorization authorization, String token) {
-		return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE));
+		return token.equals(authorization.getAttribute(OAuth2ParameterNames.STATE));
 	}
 
 	private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) {

+ 0 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java

@@ -28,11 +28,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
  */
 public interface OAuth2AuthorizationAttributeNames {
 
-	/**
-	 * The name of the attribute used for correlating the user consent request/response.
-	 */
-	String STATE = OAuth2Authorization.class.getName().concat(".STATE");
-
 	/**
 	 * The name of the attribute used for the {@link OAuth2AuthorizationRequest}.
 	 */

+ 4 - 4
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -200,7 +200,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		if (registeredClient.getClientSettings().requireUserConsent()) {
 			String state = this.stateGenerator.generateKey();
 			OAuth2Authorization authorization = builder
-					.attribute(OAuth2AuthorizationAttributeNames.STATE, state)
+					.attribute(OAuth2ParameterNames.STATE, state)
 					.build();
 			this.authorizationService.save(authorization);
 
@@ -266,7 +266,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization())
 				.token(authorizationCode)
 				.attributes(attrs -> {
-					attrs.remove(OAuth2AuthorizationAttributeNames.STATE);
+					attrs.remove(OAuth2ParameterNames.STATE);
 					attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes());
 				})
 				.build();
@@ -376,7 +376,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 			return;
 		}
 		OAuth2Authorization authorization = this.authorizationService.findByToken(
-				userConsentRequestContext.getState(), new TokenType(OAuth2AuthorizationAttributeNames.STATE));
+				userConsentRequestContext.getState(), new TokenType(OAuth2ParameterNames.STATE));
 		if (authorization == null) {
 			userConsentRequestContext.setError(
 					createError(OAuth2ErrorCodes.INVALID_REQUEST, OAuth2ParameterNames.STATE));
@@ -661,7 +661,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 			OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
 					OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 			String state = authorization.getAttribute(
-					OAuth2AuthorizationAttributeNames.STATE);
+					OAuth2ParameterNames.STATE);
 
 			StringBuilder builder = new StringBuilder();
 

+ 3 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -24,6 +24,7 @@ import org.junit.Test;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
@@ -110,12 +111,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
 				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
-				.attribute(OAuth2AuthorizationAttributeNames.STATE, state)
+				.attribute(OAuth2ParameterNames.STATE, state)
 				.build();
 		this.authorizationService.save(authorization);
 
 		OAuth2Authorization result = this.authorizationService.findByToken(
-				state, new TokenType(OAuth2AuthorizationAttributeNames.STATE));
+				state, new TokenType(OAuth2ParameterNames.STATE));
 		assertThat(authorization).isEqualTo(result);
 		result = this.authorizationService.findByToken(state, null);
 		assertThat(authorization).isEqualTo(result);

+ 11 - 11
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -569,7 +569,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
 				.isEqualTo(this.authentication);
 
-		String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE);
+		String state = authorization.getAttribute(OAuth2ParameterNames.STATE);
 		assertThat(state).isNotNull();
 
 		Set<String> authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES);
@@ -620,7 +620,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
 				.thenReturn(registeredClient);
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		this.authentication.setAuthenticated(false);
@@ -638,7 +638,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
 				.thenReturn(registeredClient);
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		this.authentication = new TestingAuthenticationToken("other-principal", "password");
@@ -662,7 +662,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName(this.authentication.getName())
 				.build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		doFilterWhenUserConsentRequestInvalidParameterThenError(
@@ -680,7 +680,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName(this.authentication.getName())
 				.build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		doFilterWhenUserConsentRequestInvalidParameterThenError(
@@ -698,7 +698,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName(this.authentication.getName())
 				.build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		doFilterWhenUserConsentRequestInvalidParameterThenError(
@@ -717,7 +717,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(otherRegisteredClient)
 				.principalName(this.authentication.getName())
 				.build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		doFilterWhenUserConsentRequestInvalidParameterThenError(
@@ -735,7 +735,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName(this.authentication.getName())
 				.build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		doFilterWhenUserConsentRequestInvalidParameterThenRedirect(
@@ -756,7 +756,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName(this.authentication.getName())
 				.build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		doFilterWhenUserConsentRequestInvalidParameterThenRedirect(
@@ -777,7 +777,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
 				.principalName(this.authentication.getName())
 				.build();
-		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2AuthorizationAttributeNames.STATE))))
+		when(this.authorizationService.findByToken(eq("state"), eq(new TokenType(OAuth2ParameterNames.STATE))))
 				.thenReturn(authorization);
 
 		MockHttpServletRequest request = createUserConsentRequest(registeredClient);
@@ -800,7 +800,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
 		assertThat(updatedAuthorization.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
 		assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull();
-		assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
+		assertThat(updatedAuthorization.<String>getAttribute(OAuth2ParameterNames.STATE)).isNull();
 		assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))
 				.isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST));
 		assertThat(updatedAuthorization.<Set<String>>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES))