Pārlūkot izejas kodu

Ignore unknown token_type_hint

Closes gh-174
Joe Grandja 4 gadi atpakaļ
vecāks
revīzija
7f8aff7982

+ 30 - 10
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -63,23 +63,43 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 				.orElse(null);
 	}
 
-	private boolean hasToken(OAuth2Authorization authorization, String token, TokenType tokenType) {
-		if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) {
-			return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE));
+	private static boolean hasToken(OAuth2Authorization authorization, String token, @Nullable TokenType tokenType) {
+		if (tokenType == null) {
+			return matchesState(authorization, token) ||
+					matchesAuthorizationCode(authorization, token) ||
+					matchesAccessToken(authorization, token) ||
+					matchesRefreshToken(authorization, token);
+		} else if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) {
+			return matchesState(authorization, token);
 		} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
-			OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
-			return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
+			return matchesAuthorizationCode(authorization, token);
 		} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
-			return authorization.getTokens().getAccessToken() != null &&
-					authorization.getTokens().getAccessToken().getTokenValue().equals(token);
+			return matchesAccessToken(authorization, token);
 		} else if (TokenType.REFRESH_TOKEN.equals(tokenType)) {
-			return authorization.getTokens().getRefreshToken() != null &&
-					authorization.getTokens().getRefreshToken().getTokenValue().equals(token);
+			return matchesRefreshToken(authorization, token);
 		}
-
 		return false;
 	}
 
+	private static boolean matchesState(OAuth2Authorization authorization, String token) {
+		return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE));
+	}
+
+	private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) {
+		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
+	}
+
+	private static boolean matchesAccessToken(OAuth2Authorization authorization, String token) {
+		return authorization.getTokens().getAccessToken() != null &&
+				authorization.getTokens().getAccessToken().getTokenValue().equals(token);
+	}
+
+	private static boolean matchesRefreshToken(OAuth2Authorization authorization, String token) {
+		return authorization.getTokens().getRefreshToken() != null &&
+				authorization.getTokens().getRefreshToken().getTokenValue().equals(token);
+	}
+
 	private static class OAuth2AuthorizationId implements Serializable {
 		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 		private final String registeredClientId;

+ 0 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java

@@ -22,7 +22,6 @@ import org.springframework.security.oauth2.core.AbstractOAuth2Token;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.OAuth2ErrorCodes2;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
@@ -71,8 +70,6 @@ public class OAuth2TokenRevocationAuthenticationProvider implements Authenticati
 				tokenType = TokenType.REFRESH_TOKEN;
 			} else if (TokenType.ACCESS_TOKEN.getValue().equals(tokenTypeHint)) {
 				tokenType = TokenType.ACCESS_TOKEN;
-			} else {
-				throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes2.UNSUPPORTED_TOKEN_TYPE));
 			}
 		}
 

+ 13 - 5
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -101,7 +101,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	}
 
 	@Test
-	public void findByTokenWhenTokenTypeStateThenFound() {
+	public void findByTokenWhenStateExistsThenFound() {
 		String state = "state";
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
@@ -112,10 +112,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		OAuth2Authorization result = this.authorizationService.findByToken(
 				state, new TokenType(OAuth2AuthorizationAttributeNames.STATE));
 		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(state, null);
+		assertThat(authorization).isEqualTo(result);
 	}
 
 	@Test
-	public void findByTokenWhenTokenTypeAuthorizationCodeThenFound() {
+	public void findByTokenWhenAuthorizationCodeExistsThenFound() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
 				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
@@ -125,10 +127,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		OAuth2Authorization result = this.authorizationService.findByToken(
 				AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
 		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(AUTHORIZATION_CODE.getTokenValue(), null);
+		assertThat(authorization).isEqualTo(result);
 	}
 
 	@Test
-	public void findByTokenWhenTokenTypeAccessTokenThenFound() {
+	public void findByTokenWhenAccessTokenExistsThenFound() {
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
 				"access-token", Instant.now().minusSeconds(60), Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
@@ -138,12 +142,14 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		this.authorizationService.save(authorization);
 
 		OAuth2Authorization result = this.authorizationService.findByToken(
-				"access-token", TokenType.ACCESS_TOKEN);
+				accessToken.getTokenValue(), TokenType.ACCESS_TOKEN);
+		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(accessToken.getTokenValue(), null);
 		assertThat(authorization).isEqualTo(result);
 	}
 
 	@Test
-	public void findByTokenWhenTokenTypeRefreshTokenThenFound() {
+	public void findByTokenWhenRefreshTokenExistsThenFound() {
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
@@ -154,6 +160,8 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		OAuth2Authorization result = this.authorizationService.findByToken(
 				refreshToken.getTokenValue(), TokenType.REFRESH_TOKEN);
 		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(refreshToken.getTokenValue(), null);
+		assertThat(authorization).isEqualTo(result);
 	}
 
 	@Test

+ 0 - 14
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java

@@ -23,7 +23,6 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.OAuth2ErrorCodes2;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -97,19 +96,6 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
 	}
 
-	@Test
-	public void authenticateWhenInvalidTokenTypeThenThrowOAuth2AuthenticationException() {
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
-		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
-		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", clientPrincipal, OAuth2ErrorCodes2.UNSUPPORTED_TOKEN_TYPE);
-		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
-				.isInstanceOf(OAuth2AuthenticationException.class)
-				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
-				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes2.UNSUPPORTED_TOKEN_TYPE);
-	}
-
 	@Test
 	public void authenticateWhenInvalidTokenThenNotRevoked() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();