Joe Grandja 4 жил өмнө
parent
commit
b7ddb837d6
18 өөрчлөгдсөн 585 нэмэгдсэн , 438 устгасан
  1. 0 48
      core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java
  2. 0 34
      core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java
  3. 0 92
      core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java
  4. 23 6
      oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java
  5. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/TokenType.java
  6. 61 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java
  7. 13 11
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  8. 35 30
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java
  9. 31 29
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java
  10. 17 22
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java
  11. 35 32
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java
  12. 188 0
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java
  13. 5 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
  14. 4 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
  15. 85 41
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java
  16. 30 25
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java
  17. 8 29
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java
  18. 49 35
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java

+ 0 - 48
core/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationService.java

@@ -1,48 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization;
-
-import org.springframework.util.Assert;
-
-/**
- * An {@link OAuth2TokenRevocationService} that revokes tokens.
- *
- * @author Vivek Babu
- * @see OAuth2AuthorizationService
- * @since 0.0.1
- */
-public final class DefaultOAuth2TokenRevocationService implements OAuth2TokenRevocationService {
-
-	private OAuth2AuthorizationService authorizationService;
-
-	/**
-	 * Constructs an {@code DefaultOAuth2TokenRevocationService}.
-	 */
-	public DefaultOAuth2TokenRevocationService(OAuth2AuthorizationService authorizationService) {
-		Assert.notNull(authorizationService, "authorizationService cannot be null");
-		this.authorizationService = authorizationService;
-	}
-
-	@Override
-	public void revoke(String token, TokenType tokenType) {
-		final OAuth2Authorization authorization = this.authorizationService.findByTokenAndTokenType(token, tokenType);
-		if (authorization != null) {
-			final OAuth2Authorization revokedAuthorization = OAuth2Authorization.from(authorization)
-					.revoked(true).build();
-			this.authorizationService.save(revokedAuthorization);
-		}
-	}
-}

+ 0 - 34
core/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenRevocationService.java

@@ -1,34 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization;
-
-/**
- * Implementations of this interface are responsible for the revocation of
- * OAuth2 tokens.
- *
- * @author Vivek Babu
- * @since 0.0.1
- */
-public interface OAuth2TokenRevocationService {
-
-	/**
-	 * Revokes the given token.
-	 *
-	 * @param token the token to be revoked
-	 * @param tokenType the type of token to be revoked
-	 */
-	void revoke(String token, TokenType tokenType);
-}

+ 0 - 92
core/src/test/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenRevocationServiceTests.java

@@ -1,92 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.mockito.ArgumentCaptor;
-import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
-import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
-
-import java.time.Instant;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
-
-/**
- * Tests for {@link DefaultOAuth2TokenRevocationService}.
- *
- * @author Vivek Babu
- */
-public class DefaultOAuth2TokenRevocationServiceTests {
-	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
-	private static final String PRINCIPAL_NAME = "principal";
-	private static final String AUTHORIZATION_CODE = "code";
-	private DefaultOAuth2TokenRevocationService revocationService;
-	private OAuth2AuthorizationService authorizationService;
-
-	@Before
-	public void setup() {
-		this.authorizationService = mock(OAuth2AuthorizationService.class);
-		this.revocationService = new DefaultOAuth2TokenRevocationService(authorizationService);
-	}
-
-	@Test
-	public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new DefaultOAuth2TokenRevocationService(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("authorizationService cannot be null");
-	}
-
-	@Test
-	public void revokeWhenTokenNotFound() {
-		this.revocationService.revoke("token", TokenType.ACCESS_TOKEN);
-		verify(authorizationService, times(1)).findByTokenAndTokenType(eq("token"),
-				eq(TokenType.ACCESS_TOKEN));
-		verify(authorizationService, times(0)).save(any());
-	}
-
-	@Test
-	public void revokeWhenTokenFound() {
-		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-				"token", Instant.now().minusSeconds(60), Instant.now());
-		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
-				.principalName(PRINCIPAL_NAME)
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
-				.accessToken(accessToken)
-				.build();
-		when(authorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN)))
-				.thenReturn(authorization);
-		this.revocationService.revoke("token", TokenType.ACCESS_TOKEN);
-
-		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
-		verify(this.authorizationService).save(authorizationCaptor.capture());
-		final OAuth2Authorization savedAuthorization = authorizationCaptor.getValue();
-		assertThat(savedAuthorization.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
-		assertThat((String) savedAuthorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE))
-				.isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE));
-		assertThat(savedAuthorization.getAccessToken()).isEqualTo(authorization.getAccessToken());
-		assertThat(savedAuthorization.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
-		assertThat(savedAuthorization.isRevoked()).isTrue();
-	}
-}

+ 23 - 6
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java

@@ -31,11 +31,13 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationProvider;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.web.JwkSetEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2ClientAuthenticationFilter;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
+import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter;
 import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
 import org.springframework.security.web.authentication.HttpStatusEntryPoint;
 import org.springframework.security.web.authentication.preauth.AbstractPreAuthenticatedProcessingFilter;
@@ -73,6 +75,8 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 					HttpMethod.POST.name()));
 	private final RequestMatcher tokenEndpointMatcher = new AntPathRequestMatcher(
 			OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI, HttpMethod.POST.name());
+	private final RequestMatcher tokenRevocationEndpointMatcher = new AntPathRequestMatcher(
+			OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI, HttpMethod.POST.name());
 	private final RequestMatcher jwkSetEndpointMatcher = new AntPathRequestMatcher(
 			JwkSetEndpointFilter.DEFAULT_JWK_SET_ENDPOINT_URI, HttpMethod.GET.name());
 
@@ -118,8 +122,8 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 	 * @return a {@code List} of {@link RequestMatcher}'s for the authorization server endpoints
 	 */
 	public List<RequestMatcher> getEndpointMatchers() {
-		return Arrays.asList(this.authorizationEndpointMatcher,
-				this.tokenEndpointMatcher, this.jwkSetEndpointMatcher);
+		return Arrays.asList(this.authorizationEndpointMatcher, this.tokenEndpointMatcher,
+				this.tokenRevocationEndpointMatcher, this.jwkSetEndpointMatcher);
 	}
 
 	@Override
@@ -145,11 +149,17 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 						jwtEncoder);
 		builder.authenticationProvider(postProcess(clientCredentialsAuthenticationProvider));
 
+		OAuth2TokenRevocationAuthenticationProvider tokenRevocationAuthenticationProvider =
+				new OAuth2TokenRevocationAuthenticationProvider(
+						getAuthorizationService(builder));
+		builder.authenticationProvider(postProcess(tokenRevocationAuthenticationProvider));
+
 		ExceptionHandlingConfigurer<B> exceptionHandling = builder.getConfigurer(ExceptionHandlingConfigurer.class);
 		if (exceptionHandling != null) {
-			// Register the default AuthenticationEntryPoint for the token endpoint
+			// Register the default AuthenticationEntryPoint for the token endpoint and token revocation endpoint
 			exceptionHandling.defaultAuthenticationEntryPointFor(
-					new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED), this.tokenEndpointMatcher);
+					new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED),
+					new OrRequestMatcher(this.tokenEndpointMatcher, this.tokenRevocationEndpointMatcher));
 		}
 	}
 
@@ -160,8 +170,10 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 
 		AuthenticationManager authenticationManager = builder.getSharedObject(AuthenticationManager.class);
 
-		OAuth2ClientAuthenticationFilter clientAuthenticationFilter = new OAuth2ClientAuthenticationFilter(
-				authenticationManager, this.tokenEndpointMatcher);
+		OAuth2ClientAuthenticationFilter clientAuthenticationFilter =
+				new OAuth2ClientAuthenticationFilter(
+						authenticationManager,
+						new OrRequestMatcher(this.tokenEndpointMatcher, this.tokenRevocationEndpointMatcher));
 		builder.addFilterAfter(postProcess(clientAuthenticationFilter), AbstractPreAuthenticatedProcessingFilter.class);
 
 		OAuth2AuthorizationEndpointFilter authorizationEndpointFilter =
@@ -175,6 +187,11 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 						authenticationManager,
 						getAuthorizationService(builder));
 		builder.addFilterAfter(postProcess(tokenEndpointFilter), FilterSecurityInterceptor.class);
+
+		OAuth2TokenRevocationEndpointFilter tokenRevocationEndpointFilter =
+				new OAuth2TokenRevocationEndpointFilter(
+						authenticationManager);
+		builder.addFilterAfter(postProcess(tokenRevocationEndpointFilter), OAuth2TokenEndpointFilter.class);
 	}
 
 	private static <B extends HttpSecurityBuilder<B>> RegisteredClientRepository getRegisteredClientRepository(B builder) {

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/TokenType.java

@@ -15,7 +15,6 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
-import org.springframework.security.oauth2.server.authorization.Version;
 import org.springframework.util.Assert;
 
 import java.io.Serializable;
@@ -26,6 +25,7 @@ import java.io.Serializable;
 public final class TokenType implements Serializable {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 	public static final TokenType ACCESS_TOKEN = new TokenType("access_token");
+	public static final TokenType REFRESH_TOKEN = new TokenType("refresh_token");
 	public static final TokenType AUTHORIZATION_CODE = new TokenType("authorization_code");
 	private final String value;
 

+ 61 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java

@@ -0,0 +1,61 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import org.springframework.security.authentication.AuthenticationProvider;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
+
+/**
+ * Utility methods for the OAuth 2.0 {@link AuthenticationProvider}'s.
+ *
+ * @author Joe Grandja
+ * @since 0.0.3
+ */
+final class OAuth2AuthenticationProviderUtils {
+
+	private OAuth2AuthenticationProviderUtils() {
+	}
+
+	static <T extends AbstractOAuth2Token> OAuth2Authorization invalidate(
+			OAuth2Authorization authorization, T token) {
+
+		OAuth2Tokens.Builder builder = OAuth2Tokens.from(authorization.getTokens())
+				.token(token, OAuth2TokenMetadata.builder().invalidated().build());
+
+		if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
+			builder.token(
+					authorization.getTokens().getAccessToken(),
+					OAuth2TokenMetadata.builder().invalidated().build());
+			OAuth2AuthorizationCode authorizationCode =
+					authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+			if (authorizationCode != null &&
+					!authorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()) {
+				builder.token(
+						authorizationCode,
+						OAuth2TokenMetadata.builder().invalidated().build());
+			}
+		}
+
+		return OAuth2Authorization.from(authorization)
+				.tokens(builder.build())
+				.build();
+	}
+}

+ 13 - 11
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -105,14 +105,17 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode);
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
 				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 
 		if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
-			// Invalidate the authorization code given that a different client is attempting to use it
-			authorization.getTokens().invalidate(authorizationCode);
-			this.authorizationService.save(authorization);
+			if (!authorizationCodeMetadata.isInvalidated()) {
+				// Invalidate the authorization code given that a different client is attempting to use it
+				authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode);
+				this.authorizationService.save(authorization);
+			}
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
@@ -121,9 +124,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
-		OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode);
 		if (authorizationCodeMetadata.isInvalidated()) {
-			// Prevent the same client from using the authorization code more than once
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
@@ -154,15 +155,16 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
 				jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
 
-		OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens())
-				.accessToken(accessToken)
-				.build();
-		tokens.invalidate(authorizationCode);		// Invalidate the authorization code as it can only be used once
-
 		authorization = OAuth2Authorization.from(authorization)
-				.tokens(tokens)
+				.tokens(OAuth2Tokens.from(authorization.getTokens())
+						.accessToken(accessToken)
+						.build())
 				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
 				.build();
+
+		// Invalidate the authorization code as it can only be used once
+		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode);
+
 		this.authorizationService.save(authorization);
 
 		return new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);

+ 35 - 30
core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java → oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java

@@ -18,78 +18,83 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenRevocationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
 
 /**
- * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Token Revocation.
+ * An {@link AuthenticationProvider} implementation for OAuth 2.0 Token Revocation.
  *
  * @author Vivek Babu
- * @since 0.0.1
+ * @author Joe Grandja
+ * @since 0.0.3
  * @see OAuth2TokenRevocationAuthenticationToken
  * @see OAuth2AuthorizationService
- * @see OAuth2TokenRevocationService
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7009#section-2.1">Section 2.1 Revocation Request</a>
  */
 public class OAuth2TokenRevocationAuthenticationProvider implements AuthenticationProvider {
-
-	private OAuth2AuthorizationService authorizationService;
-	private OAuth2TokenRevocationService tokenRevocationService;
+	private final OAuth2AuthorizationService authorizationService;
 
 	/**
 	 * Constructs an {@code OAuth2TokenRevocationAuthenticationProvider} using the provided parameters.
 	 *
 	 * @param authorizationService the authorization service
-	 * @param tokenRevocationService the token revocation service
 	 */
-	public OAuth2TokenRevocationAuthenticationProvider(OAuth2AuthorizationService authorizationService,
-			OAuth2TokenRevocationService tokenRevocationService) {
+	public OAuth2TokenRevocationAuthenticationProvider(OAuth2AuthorizationService authorizationService) {
 		Assert.notNull(authorizationService, "authorizationService cannot be null");
-		Assert.notNull(tokenRevocationService, "tokenRevocationService cannot be null");
 		this.authorizationService = authorizationService;
-		this.tokenRevocationService = tokenRevocationService;
 	}
 
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
-		OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthenticationToken =
+		OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthentication =
 				(OAuth2TokenRevocationAuthenticationToken) authentication;
 
 		OAuth2ClientAuthenticationToken clientPrincipal = null;
-		if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(tokenRevocationAuthenticationToken.getPrincipal()
-				.getClass())) {
-			clientPrincipal = (OAuth2ClientAuthenticationToken) tokenRevocationAuthenticationToken.getPrincipal();
+		if (OAuth2ClientAuthenticationToken.class.isAssignableFrom(tokenRevocationAuthentication.getPrincipal().getClass())) {
+			clientPrincipal = (OAuth2ClientAuthenticationToken) tokenRevocationAuthentication.getPrincipal();
 		}
 		if (clientPrincipal == null || !clientPrincipal.isAuthenticated()) {
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
 		}
+		RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
 
-		final RegisteredClient registeredClient = clientPrincipal.getRegisteredClient();
-		final String tokenTypeHint = tokenRevocationAuthenticationToken.getTokenTypeHint();
-		final String token = tokenRevocationAuthenticationToken.getToken();
-		final OAuth2Authorization authorization = authorizationService.findByTokenAndTokenType(token,
-				TokenType.ACCESS_TOKEN);
-
-		OAuth2TokenRevocationAuthenticationToken successfulAuthentication =
-				new OAuth2TokenRevocationAuthenticationToken(token, registeredClient, tokenTypeHint);
+		TokenType tokenType = null;
+		String tokenTypeHint = tokenRevocationAuthentication.getTokenTypeHint();
+		if (StringUtils.hasText(tokenTypeHint)) {
+			if (TokenType.REFRESH_TOKEN.getValue().equals(tokenTypeHint)) {
+				tokenType = TokenType.REFRESH_TOKEN;
+			} else if (TokenType.ACCESS_TOKEN.getValue().equals(tokenTypeHint)) {
+				tokenType = TokenType.ACCESS_TOKEN;
+			} else {
+				// TODO Add OAuth2ErrorCodes.UNSUPPORTED_TOKEN_TYPE
+				throw new OAuth2AuthenticationException(new OAuth2Error("unsupported_token_type"));
+			}
+		}
 
+		OAuth2Authorization authorization = this.authorizationService.findByToken(
+				tokenRevocationAuthentication.getToken(), tokenType);
 		if (authorization == null) {
-			return successfulAuthentication;
+			// Return the authentication request when token not found
+			return tokenRevocationAuthentication;
 		}
 
-		if (!registeredClient.getClientId().equals(authorization.getRegisteredClientId())) {
-			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+		if (!registeredClient.getId().equals(authorization.getRegisteredClientId())) {
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
 		}
 
-		tokenRevocationService.revoke(token, TokenType.ACCESS_TOKEN);
-		return successfulAuthentication;
+		AbstractOAuth2Token token = authorization.getTokens().getToken(tokenRevocationAuthentication.getToken());
+		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token);
+		this.authorizationService.save(authorization);
+
+		return new OAuth2TokenRevocationAuthenticationToken(token, clientPrincipal);
 	}
 
 	@Override

+ 31 - 29
core/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java → oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationToken.java

@@ -18,53 +18,64 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 import org.springframework.lang.Nullable;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
 import org.springframework.security.oauth2.server.authorization.Version;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.util.Assert;
 
 import java.util.Collections;
 
 /**
- * An {@link Authentication} implementation used for OAuth 2.0 Client Authentication.
+ * An {@link Authentication} implementation used for OAuth 2.0 Token Revocation.
  *
  * @author Vivek Babu
- * @since 0.0.1
+ * @author Joe Grandja
+ * @since 0.0.3
  * @see AbstractAuthenticationToken
- * @see RegisteredClient
  * @see OAuth2TokenRevocationAuthenticationProvider
  */
 public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthenticationToken {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+	private final String token;
+	private final Authentication clientPrincipal;
 	private final String tokenTypeHint;
-	private Authentication clientPrincipal;
-	private String token;
-	private RegisteredClient registeredClient;
 
+	/**
+	 * Constructs an {@code OAuth2TokenRevocationAuthenticationToken} using the provided parameters.
+	 *
+	 * @param token the token
+	 * @param clientPrincipal the authenticated client principal
+	 * @param tokenTypeHint the token type hint
+	 */
 	public OAuth2TokenRevocationAuthenticationToken(String token,
 			Authentication clientPrincipal, @Nullable String tokenTypeHint) {
 		super(Collections.emptyList());
-		Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
 		Assert.hasText(token, "token cannot be empty");
+		Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
 		this.token = token;
 		this.clientPrincipal = clientPrincipal;
 		this.tokenTypeHint = tokenTypeHint;
 	}
 
-	public OAuth2TokenRevocationAuthenticationToken(String token,
-			RegisteredClient registeredClient, @Nullable String tokenTypeHint) {
+	/**
+	 * Constructs an {@code OAuth2TokenRevocationAuthenticationToken} using the provided parameters.
+	 *
+	 * @param revokedToken the revoked token
+	 * @param clientPrincipal the authenticated client principal
+	 */
+	public OAuth2TokenRevocationAuthenticationToken(AbstractOAuth2Token revokedToken,
+			Authentication clientPrincipal) {
 		super(Collections.emptyList());
-		Assert.notNull(registeredClient, "registeredClient cannot be null");
-		Assert.hasText(token, "token cannot be empty");
-		this.token = token;
-		this.registeredClient = registeredClient;
-		this.tokenTypeHint = tokenTypeHint;
-		setAuthenticated(true);
+		Assert.notNull(revokedToken, "revokedToken cannot be null");
+		Assert.notNull(clientPrincipal, "clientPrincipal cannot be null");
+		this.token = revokedToken.getTokenValue();
+		this.clientPrincipal = clientPrincipal;
+		this.tokenTypeHint = null;
+		setAuthenticated(true);		// Indicates that the token was authenticated and revoked
 	}
 
 	@Override
 	public Object getPrincipal() {
-		return this.clientPrincipal != null ? this.clientPrincipal : this.registeredClient
-				.getClientId();
+		return this.clientPrincipal;
 	}
 
 	@Override
@@ -86,17 +97,8 @@ public class OAuth2TokenRevocationAuthenticationToken extends AbstractAuthentica
 	 *
 	 * @return the token type hint
 	 */
+	@Nullable
 	public String getTokenTypeHint() {
-		return tokenTypeHint;
-	}
-
-	/**
-	 * Returns the {@link RegisteredClient registered client}.
-	 *
-	 * @return the {@link RegisteredClient}
-	 */
-	public @Nullable
-	RegisteredClient getRegisteredClient() {
-		return this.registeredClient;
+		return this.tokenTypeHint;
 	}
 }

+ 17 - 22
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java

@@ -83,41 +83,36 @@ public class OAuth2Tokens implements Serializable {
 	}
 
 	/**
-	 * Returns the token metadata associated to the provided {@code token}.
+	 * Returns the token specified by {@code token}.
 	 *
 	 * @param token the token
 	 * @param <T> the type of the token
-	 * @return the token metadata, or {@code null} if not available
+	 * @return the token, or {@code null} if not available
 	 */
 	@Nullable
-	public <T extends AbstractOAuth2Token> OAuth2TokenMetadata getTokenMetadata(T token) {
-		Assert.notNull(token, "token cannot be null");
-		OAuth2TokenHolder tokenHolder = this.tokens.get(token.getClass());
-		return (tokenHolder != null && tokenHolder.getToken().equals(token)) ?
-				tokenHolder.getTokenMetadata() : null;
-	}
-
-	/**
-	 * Invalidates all tokens.
-	 */
-	public void invalidate() {
-		this.tokens.values().forEach(tokenHolder -> invalidate(tokenHolder.getToken()));
+	@SuppressWarnings("unchecked")
+	public <T extends AbstractOAuth2Token> T getToken(String token) {
+		Assert.hasText(token, "token cannot be empty");
+		OAuth2TokenHolder tokenHolder = this.tokens.values().stream()
+				.filter(holder -> holder.getToken().getTokenValue().equals(token))
+				.findFirst()
+				.orElse(null);
+		return tokenHolder != null ? (T) tokenHolder.getToken() : null;
 	}
 
 	/**
-	 * Invalidates the token matching the provided {@code token}.
+	 * Returns the token metadata associated to the provided {@code token}.
 	 *
 	 * @param token the token
 	 * @param <T> the type of the token
+	 * @return the token metadata, or {@code null} if not available
 	 */
-	public <T extends AbstractOAuth2Token> void invalidate(T token) {
+	@Nullable
+	public <T extends AbstractOAuth2Token> OAuth2TokenMetadata getTokenMetadata(T token) {
 		Assert.notNull(token, "token cannot be null");
-		this.tokens.computeIfPresent(token.getClass(),
-				(tokenType, tokenHolder) ->
-						new OAuth2TokenHolder(
-								tokenHolder.getToken(),
-								OAuth2TokenMetadata.builder().invalidated().build())
-		);
+		OAuth2TokenHolder tokenHolder = this.tokens.get(token.getClass());
+		return (tokenHolder != null && tokenHolder.getToken().equals(token)) ?
+				tokenHolder.getTokenMetadata() : null;
 	}
 
 	@Override

+ 35 - 32
core/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java → oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java

@@ -27,10 +27,10 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
-import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.StringUtils;
@@ -43,31 +43,30 @@ import javax.servlet.http.HttpServletResponse;
 import java.io.IOException;
 
 /**
- * A {@code Filter} for the OAuth 2.0 Token Revocation,
- * which handles the processing of the OAuth 2.0 Token Revocation Request.
+ * A {@code Filter} for the OAuth 2.0 Token Revocation endpoint.
  *
  * @author Vivek Babu
- * @see OAuth2AuthorizationService
- * @see OAuth2Authorization
+ * @author Joe Grandja
+ * @see OAuth2TokenRevocationAuthenticationProvider
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7009#section-2">Section 2 Token Revocation</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc7009#section-2.1">Section 2.1 Revocation Request</a>
- * @since 0.0.1
+ * @since 0.0.3
  */
 public class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFilter {
+	static final String TOKEN_PARAM_NAME = "token";
+	static final String TOKEN_TYPE_HINT_PARAM_NAME = "token_type_hint";
 
 	/**
-	 * The default endpoint {@code URI} for token revocation request.
+	 * The default endpoint {@code URI} for token revocation requests.
 	 */
 	public static final String DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI = "/oauth2/revoke";
-	private static final String TOKEN_TYPE_HINT = "token_type_hint";
-	private static final String TOKEN = "token";
-	private final AntPathRequestMatcher revocationEndpointMatcher;
 
+	private final AuthenticationManager authenticationManager;
+	private final RequestMatcher tokenRevocationEndpointMatcher;
 	private final Converter<HttpServletRequest, Authentication> tokenRevocationAuthenticationConverter =
-			new OAuth2TokenRevocationEndpointFilter.TokenRevocationAuthenticationConverter();
+			new DefaultTokenRevocationAuthenticationConverter();
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 			new OAuth2ErrorHttpMessageConverter();
-	private final AuthenticationManager authenticationManager;
 
 	/**
 	 * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters.
@@ -82,30 +81,30 @@ public class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFilter {
 	 * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters.
 	 *
 	 * @param authenticationManager the authentication manager
-	 * @param revocationEndpointUri the endpoint {@code URI} for revocation requests
+	 * @param tokenRevocationEndpointUri the endpoint {@code URI} for token revocation requests
 	 */
 	public OAuth2TokenRevocationEndpointFilter(AuthenticationManager authenticationManager,
-			String revocationEndpointUri) {
+			String tokenRevocationEndpointUri) {
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
-		Assert.hasText(revocationEndpointUri, "revocationEndpointUri cannot be empty");
+		Assert.hasText(tokenRevocationEndpointUri, "tokenRevocationEndpointUri cannot be empty");
 		this.authenticationManager = authenticationManager;
-		this.revocationEndpointMatcher = new AntPathRequestMatcher(
-				revocationEndpointUri, HttpMethod.POST.name());
+		this.tokenRevocationEndpointMatcher = new AntPathRequestMatcher(
+				tokenRevocationEndpointUri, HttpMethod.POST.name());
 	}
 
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 
-		if (!this.revocationEndpointMatcher.matches(request)) {
+		if (!this.tokenRevocationEndpointMatcher.matches(request)) {
 			filterChain.doFilter(request, response);
 			return;
 		}
 
 		try {
-			Authentication tokenRevocationRequestAuthentication =
-					this.tokenRevocationAuthenticationConverter.convert(request);
-			this.authenticationManager.authenticate(tokenRevocationRequestAuthentication);
+			this.authenticationManager.authenticate(
+					this.tokenRevocationAuthenticationConverter.convert(request));
+			response.setStatus(HttpStatus.OK.value());
 		} catch (OAuth2AuthenticationException ex) {
 			SecurityContextHolder.clearContext();
 			sendErrorResponse(response, ex.getError());
@@ -118,30 +117,34 @@ public class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFilter {
 		this.errorHttpResponseConverter.write(error, null, httpResponse);
 	}
 
-	private static OAuth2AuthenticationException throwError(String errorCode, String parameterName) {
-		OAuth2Error error = new OAuth2Error(errorCode, "Token Revocation Request Parameter: " + parameterName,
+	private static void throwError(String errorCode, String parameterName) {
+		OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Token Revocation Parameter: " + parameterName,
 				"https://tools.ietf.org/html/rfc7009#section-2.1");
 		throw new OAuth2AuthenticationException(error);
 	}
 
-	private static class TokenRevocationAuthenticationConverter implements
-			Converter<HttpServletRequest, Authentication> {
+	private static class DefaultTokenRevocationAuthenticationConverter
+			implements Converter<HttpServletRequest, Authentication> {
 
 		@Override
 		public Authentication convert(HttpServletRequest request) {
-			MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
-
 			Authentication clientPrincipal = SecurityContextHolder.getContext().getAuthentication();
 
+			MultiValueMap<String, String> parameters = OAuth2EndpointUtils.getParameters(request);
+
 			// token (REQUIRED)
-			String token = parameters.getFirst(TOKEN);
+			String token = parameters.getFirst(TOKEN_PARAM_NAME);
 			if (!StringUtils.hasText(token) ||
-					parameters.get(TOKEN).size() != 1) {
-				throwError(OAuth2ErrorCodes.INVALID_REQUEST, TOKEN);
+					parameters.get(TOKEN_PARAM_NAME).size() != 1) {
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, TOKEN_PARAM_NAME);
 			}
 
 			// token_type_hint (OPTIONAL)
-			String tokenTypeHint = parameters.getFirst(TOKEN_TYPE_HINT);
+			String tokenTypeHint = parameters.getFirst(TOKEN_TYPE_HINT_PARAM_NAME);
+			if (StringUtils.hasText(tokenTypeHint) &&
+					parameters.get(TOKEN_TYPE_HINT_PARAM_NAME).size() != 1) {
+				throwError(OAuth2ErrorCodes.INVALID_REQUEST, TOKEN_TYPE_HINT_PARAM_NAME);
+			}
 
 			return new OAuth2TokenRevocationAuthenticationToken(token, clientPrincipal, tokenTypeHint);
 		}

+ 188 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java

@@ -0,0 +1,188 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.annotation.web.configurers.oauth2.server.authorization;
+
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Import;
+import org.springframework.http.HttpHeaders;
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
+import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.crypto.keys.KeyManager;
+import org.springframework.security.crypto.keys.StaticKeyGeneratingKeyManager;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter;
+import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+
+import java.net.URLEncoder;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.reset;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
+
+/**
+ * Integration tests for the OAuth 2.0 Token Revocation endpoint.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2TokenRevocationTests {
+	private static RegisteredClientRepository registeredClientRepository;
+	private static OAuth2AuthorizationService authorizationService;
+	private static KeyManager keyManager;
+
+	@Rule
+	public final SpringTestRule spring = new SpringTestRule();
+
+	@Autowired
+	private MockMvc mvc;
+
+	@BeforeClass
+	public static void init() {
+		registeredClientRepository = mock(RegisteredClientRepository.class);
+		authorizationService = mock(OAuth2AuthorizationService.class);
+		keyManager = new StaticKeyGeneratingKeyManager();
+	}
+
+	@Before
+	public void setup() {
+		reset(registeredClientRepository);
+		reset(authorizationService);
+	}
+
+	@Test
+	public void requestWhenRevokeRefreshTokenThenRevoked() throws Exception {
+		this.spring.register(AuthorizationServerConfiguration.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		OAuth2RefreshToken token = authorization.getTokens().getRefreshToken();
+		TokenType tokenType = TokenType.REFRESH_TOKEN;
+		when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization);
+
+		this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI)
+				.params(getTokenRevocationRequestParameters(token, tokenType))
+				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
+						registeredClient.getClientId(), registeredClient.getClientSecret())))
+				.andExpect(status().isOk());
+
+		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
+		verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType));
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(authorizationService).save(authorizationCaptor.capture());
+
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue();
+		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
+	}
+
+	@Test
+	public void requestWhenRevokeAccessTokenThenRevoked() throws Exception {
+		this.spring.register(AuthorizationServerConfiguration.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		OAuth2AccessToken token = authorization.getTokens().getAccessToken();
+		TokenType tokenType = TokenType.ACCESS_TOKEN;
+		when(authorizationService.findByToken(eq(token.getTokenValue()), eq(tokenType))).thenReturn(authorization);
+
+		this.mvc.perform(MockMvcRequestBuilders.post(OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI)
+				.params(getTokenRevocationRequestParameters(token, tokenType))
+				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
+						registeredClient.getClientId(), registeredClient.getClientSecret())))
+				.andExpect(status().isOk());
+
+		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
+		verify(authorizationService).findByToken(eq(token.getTokenValue()), eq(tokenType));
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(authorizationService).save(authorizationCaptor.capture());
+
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
+		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse();
+	}
+
+	private static MultiValueMap<String, String> getTokenRevocationRequestParameters(AbstractOAuth2Token token, TokenType tokenType) {
+		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
+		// TODO Use OAuth2ParameterNames
+		parameters.set("token", token.getTokenValue());
+		parameters.set("token_type_hint", tokenType.getValue());
+		return parameters;
+	}
+
+	private static String encodeBasicAuth(String clientId, String secret) throws Exception {
+		clientId = URLEncoder.encode(clientId, StandardCharsets.UTF_8.name());
+		secret = URLEncoder.encode(secret, StandardCharsets.UTF_8.name());
+		String credentialsString = clientId + ":" + secret;
+		byte[] encodedBytes = Base64.getEncoder().encode(credentialsString.getBytes(StandardCharsets.UTF_8));
+		return new String(encodedBytes, StandardCharsets.UTF_8);
+	}
+
+	@EnableWebSecurity
+	@Import(OAuth2AuthorizationServerConfiguration.class)
+	static class AuthorizationServerConfiguration {
+
+		@Bean
+		RegisteredClientRepository registeredClientRepository() {
+			return registeredClientRepository;
+		}
+
+		@Bean
+		OAuth2AuthorizationService authorizationService() {
+			return authorizationService;
+		}
+
+		@Bean
+		KeyManager keyManager() {
+			return keyManager;
+		}
+	}
+}

+ 5 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.server.authorization;
 
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
@@ -23,6 +24,7 @@ import org.springframework.security.oauth2.server.authorization.token.OAuth2Auth
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 import java.util.Collections;
 import java.util.Map;
 
@@ -46,6 +48,8 @@ public class TestOAuth2Authorizations {
 				"code", Instant.now(), Instant.now().plusSeconds(120));
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(
 				OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
+				"refresh-token", Instant.now(), Instant.now().plus(1, ChronoUnit.HOURS));
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
 				.authorizationUri("https://provider.com/oauth2/authorize")
 				.clientId(registeredClient.getClientId())
@@ -56,7 +60,7 @@ public class TestOAuth2Authorizations {
 				.build();
 		return OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName("principal")
-				.tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).build())
+				.tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build())
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes());
 	}

+ 4 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -38,6 +38,7 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
@@ -186,9 +187,10 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 				AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120));
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization()
-				.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
+				.tokens(OAuth2Tokens.builder()
+						.token(authorizationCode, OAuth2TokenMetadata.builder().invalidated().build())
+						.build())
 				.build();
-		authorization.getTokens().invalidate(authorizationCode);
 		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 

+ 85 - 41
core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java → oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java

@@ -17,12 +17,14 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenRevocationService;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -30,8 +32,10 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -39,38 +43,27 @@ import static org.mockito.Mockito.when;
  * Tests for {@link OAuth2TokenRevocationAuthenticationProvider}.
  *
  * @author Vivek Babu
+ * @author Joe Grandja
  */
 public class OAuth2TokenRevocationAuthenticationProviderTests {
 	private RegisteredClient registeredClient;
-	private OAuth2AuthorizationService oAuth2AuthorizationService;
-	private OAuth2TokenRevocationService oAuth2TokenRevocationService;
+	private OAuth2AuthorizationService authorizationService;
 	private OAuth2TokenRevocationAuthenticationProvider authenticationProvider;
 
 	@Before
 	public void setUp() {
 		this.registeredClient = TestRegisteredClients.registeredClient().build();
-		this.oAuth2AuthorizationService = mock(OAuth2AuthorizationService.class);
-		this.oAuth2TokenRevocationService = mock(OAuth2TokenRevocationService.class);
-		this.authenticationProvider = new OAuth2TokenRevocationAuthenticationProvider(oAuth2AuthorizationService,
-				oAuth2TokenRevocationService);
+		this.authorizationService = mock(OAuth2AuthorizationService.class);
+		this.authenticationProvider = new OAuth2TokenRevocationAuthenticationProvider(this.authorizationService);
 	}
 
 	@Test
 	public void constructorWhenAuthorizationServiceNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(null,
-				oAuth2TokenRevocationService))
+		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("authorizationService cannot be null");
 	}
 
-	@Test
-	public void constructorWhenRevocationServiceNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationProvider(oAuth2AuthorizationService,
-				null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("tokenRevocationService cannot be null");
-	}
-
 	@Test
 	public void supportsWhenTypeOAuth2TokenRevocationAuthenticationTokenThenReturnTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2TokenRevocationAuthenticationToken.class)).isTrue();
@@ -81,7 +74,7 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 		TestingAuthenticationToken clientPrincipal = new TestingAuthenticationToken(
 				this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", clientPrincipal, "access_token");
+				"token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@@ -92,9 +85,9 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 	@Test
 	public void authenticateWhenClientPrincipalNotAuthenticatedThenThrowOAuth2AuthenticationException() {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
-				this.registeredClient.getClientId(), this.registeredClient.getClientSecret());
+				this.registeredClient.getClientId(), this.registeredClient.getClientSecret(), null);
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", clientPrincipal, "access_token");
+				"token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
@@ -103,48 +96,99 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenInvalidTokenThenAuthenticate() {
+	public void authenticateWhenInvalidTokenTypeThenThrowOAuth2AuthenticationException() {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", clientPrincipal, "access_token");
+				"token", clientPrincipal, "unsupported_token_type");
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo("unsupported_token_type");
+	}
+
+	@Test
+	public void authenticateWhenInvalidTokenThenNotRevoked() {
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
+				"token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
 		OAuth2TokenRevocationAuthenticationToken authenticationResult =
 				(OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
-		assertThat(authenticationResult.isAuthenticated()).isTrue();
-		assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(this.registeredClient.getClientId());
-		assertThat(authenticationResult.getRegisteredClient()).isEqualTo(this.registeredClient);
+		assertThat(authenticationResult.isAuthenticated()).isFalse();
+		verify(this.authorizationService, never()).save(any());
 	}
 
 	@Test
-	public void authenticateWhenAuthorizationIssuedToAnotherClientThenThrowOAuth2AuthenticationException() {
-		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
-		when(this.oAuth2AuthorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN)))
+	public void authenticateWhenTokenIssuedToAnotherClientThenThrowOAuth2AuthenticationException() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
+				TestRegisteredClients.registeredClient2().build()).build();
+		when(this.authorizationService.findByToken(
+				eq("token"),
+				eq(TokenType.ACCESS_TOKEN)))
 				.thenReturn(authorization);
-		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
-				TestRegisteredClients.registeredClient2().build());
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", clientPrincipal, "access_token");
+				"token", clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
+
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
 				.extracting("errorCode")
-				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+				.isEqualTo(OAuth2ErrorCodes.INVALID_CLIENT);
 	}
 
 	@Test
-	public void authenticateWhenValidAccessTokenThenInvalidateTokenAndAuthenticate() {
+	public void authenticateWhenValidRefreshTokenThenRevoked() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
+				this.registeredClient).build();
+		when(this.authorizationService.findByToken(
+				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(TokenType.REFRESH_TOKEN)))
+				.thenReturn(authorization);
+
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", clientPrincipal, "access_token");
-		OAuth2Authorization mockAuthorization = mock(OAuth2Authorization.class);
-		when(oAuth2AuthorizationService.findByTokenAndTokenType(eq("token"), eq(TokenType.ACCESS_TOKEN))).
-				thenReturn(mockAuthorization);
-		when(mockAuthorization.getRegisteredClientId()).thenReturn(this.registeredClient.getClientId());
+				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, TokenType.REFRESH_TOKEN.getValue());
+
 		OAuth2TokenRevocationAuthenticationToken authenticationResult =
 				(OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
-		verify(this.oAuth2TokenRevocationService).revoke(eq("token"), eq(TokenType.ACCESS_TOKEN));
+		assertThat(authenticationResult.isAuthenticated()).isTrue();
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue();
+		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
+	}
+
+	@Test
+	public void authenticateWhenValidAccessTokenThenRevoked() {
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
+				this.registeredClient).build();
+		when(this.authorizationService.findByToken(
+				eq(authorization.getTokens().getAccessToken().getTokenValue()),
+				eq(TokenType.ACCESS_TOKEN)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
+				authorization.getTokens().getAccessToken().getTokenValue(), clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
 
+		OAuth2TokenRevocationAuthenticationToken authenticationResult =
+				(OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
 		assertThat(authenticationResult.isAuthenticated()).isTrue();
-		assertThat(authenticationResult.getPrincipal().toString()).isEqualTo(this.registeredClient.getClientId());
-		assertThat(authenticationResult.getRegisteredClient()).isEqualTo(this.registeredClient);
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
+		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse();
 	}
 }

+ 30 - 25
core/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java → oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationTokenTests.java

@@ -16,10 +16,13 @@
 package org.springframework.security.oauth2.server.authorization.authentication;
 
 import org.junit.Test;
-import org.springframework.security.core.Authentication;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 
+import java.time.Duration;
+import java.time.Instant;
+
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
@@ -27,62 +30,64 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
  * Tests for {@link OAuth2TokenRevocationAuthenticationToken}.
  *
  * @author Vivek Babu
+ * @author Joe Grandja
  */
 public class OAuth2TokenRevocationAuthenticationTokenTests {
-	private OAuth2TokenRevocationAuthenticationToken clientPrincipal = new OAuth2TokenRevocationAuthenticationToken(
-			"Token", TestRegisteredClients.registeredClient().build(), null);
-	private RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+	private String token = "token";
+	private OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+			TestRegisteredClients.registeredClient().build());
+	private String tokenTypeHint = TokenType.ACCESS_TOKEN.getValue();
+	private OAuth2AccessToken accessToken = new OAuth2AccessToken(
+			OAuth2AccessToken.TokenType.BEARER, this.token,
+			Instant.now(), Instant.now().plus(Duration.ofHours(1)));
 
 	@Test
 	public void constructorWhenTokenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null,
-				this.clientPrincipal, "hint"))
+		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, this.clientPrincipal, this.tokenTypeHint))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("token cannot be empty");
 	}
 
 	@Test
 	public void constructorWhenClientPrincipalNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken("token",
-				(Authentication) null, "hint"))
+		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(this.token, null, this.tokenTypeHint))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("clientPrincipal cannot be null");
 	}
 
 	@Test
-	public void constructorWhenTokenNullRegisteredClientPresentThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, registeredClient, "hint"))
+	public void constructorWhenRevokedTokenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(null, this.clientPrincipal))
 				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("token cannot be empty");
+				.hasMessage("revokedToken cannot be null");
 	}
 
 	@Test
-	public void constructorWhenRegisteredClientNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken("token",
-				(RegisteredClient) null, "hint"))
+	public void constructorWhenRevokedTokenAndClientPrincipalNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2TokenRevocationAuthenticationToken(this.accessToken, null))
 				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("registeredClient cannot be null");
+				.hasMessage("clientPrincipal cannot be null");
 	}
 
 	@Test
-	public void constructorWhenTokenAndClientPrincipalProvidedThenCreated() {
+	public void constructorWhenTokenProvidedThenCreated() {
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", this.clientPrincipal, "token_hint");
+				this.token, this.clientPrincipal, this.tokenTypeHint);
+		assertThat(authentication.getToken()).isEqualTo(this.token);
 		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
+		assertThat(authentication.getTokenTypeHint()).isEqualTo(this.tokenTypeHint);
 		assertThat(authentication.getCredentials().toString()).isEmpty();
-		assertThat(authentication.getToken()).isEqualTo("token");
-		assertThat(authentication.getTokenTypeHint()).isEqualTo("token_hint");
 		assertThat(authentication.isAuthenticated()).isFalse();
 	}
 
 	@Test
-	public void constructorWhenTokenAndRegisteredProvidedThenCreated() {
+	public void constructorWhenRevokedTokenProvidedThenCreated() {
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				"token", this.registeredClient, "token_hint");
-		assertThat(authentication.getPrincipal()).isEqualTo(this.registeredClient.getClientId());
+				this.accessToken, this.clientPrincipal);
+		assertThat(authentication.getToken()).isEqualTo(this.accessToken.getTokenValue());
+		assertThat(authentication.getPrincipal()).isEqualTo(this.clientPrincipal);
+		assertThat(authentication.getTokenTypeHint()).isNull();
 		assertThat(authentication.getCredentials().toString()).isEmpty();
-		assertThat(authentication.getToken()).isEqualTo("token");
-		assertThat(authentication.getTokenTypeHint()).isEqualTo("token_hint");
 		assertThat(authentication.isAuthenticated()).isTrue();
 	}
 }

+ 8 - 29
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java

@@ -82,11 +82,18 @@ public class OAuth2TokensTests {
 
 	@Test
 	public void getTokenWhenTokenTypeNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken(null))
+		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((Class<OAuth2AccessToken>) null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("tokenType cannot be null");
 	}
 
+	@Test
+	public void getTokenWhenTokenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((String) null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("token cannot be empty");
+	}
+
 	@Test
 	public void getTokenMetadataWhenTokenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getTokenMetadata(null))
@@ -185,32 +192,4 @@ public class OAuth2TokensTests {
 				this.accessToken.getScopes());
 		assertThat(tokens.getTokenMetadata(otherAccessToken)).isNull();
 	}
-
-	@Test
-	public void invalidateWhenAllTokensThenAllInvalidated() {
-		OAuth2Tokens tokens = OAuth2Tokens.builder()
-				.accessToken(this.accessToken)
-				.refreshToken(this.refreshToken)
-				.token(this.idToken)
-				.build();
-		tokens.invalidate();
-
-		assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue();
-		assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isTrue();
-		assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isTrue();
-	}
-
-	@Test
-	public void invalidateWhenTokenProvidedThenInvalidated() {
-		OAuth2Tokens tokens = OAuth2Tokens.builder()
-				.accessToken(this.accessToken)
-				.refreshToken(this.refreshToken)
-				.token(this.idToken)
-				.build();
-		tokens.invalidate(this.accessToken);
-
-		assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue();
-		assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isFalse();
-		assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isFalse();
-	}
 }

+ 49 - 35
core/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java → oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilterTests.java

@@ -18,7 +18,6 @@ package org.springframework.security.oauth2.server.authorization.web;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-import org.mockito.ArgumentCaptor;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.mock.http.client.MockClientHttpResponse;
@@ -28,9 +27,11 @@ import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
+import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -39,6 +40,10 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
 import javax.servlet.FilterChain;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashSet;
 import java.util.function.Consumer;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -48,15 +53,16 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
+import static org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter.TOKEN_PARAM_NAME;
+import static org.springframework.security.oauth2.server.authorization.web.OAuth2TokenRevocationEndpointFilter.TOKEN_TYPE_HINT_PARAM_NAME;
 
 /**
  * Tests for {@link OAuth2TokenRevocationEndpointFilter}.
  *
  * @author Vivek Babu
+ * @author Joe Grandja
  */
 public class OAuth2TokenRevocationEndpointFilterTests {
-	private static final String TOKEN = "token";
-	private static final String TOKEN_TYPE_HINT = "token_type_hint";
 	private AuthenticationManager authenticationManager;
 	private OAuth2TokenRevocationEndpointFilter filter;
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
@@ -81,14 +87,14 @@ public class OAuth2TokenRevocationEndpointFilterTests {
 	}
 
 	@Test
-	public void constructorWhenTokenEndpointUriNullThenThrowIllegalArgumentException() {
+	public void constructorWhenTokenRevocationEndpointUriNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> new OAuth2TokenRevocationEndpointFilter(this.authenticationManager, null))
 				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("revocationEndpointUri cannot be empty");
+				.hasMessage("tokenRevocationEndpointUri cannot be empty");
 	}
 
 	@Test
-	public void doFilterWhenNotRevocationRequestThenNotProcessed() throws Exception {
+	public void doFilterWhenNotTokenRevocationRequestThenNotProcessed() throws Exception {
 		String requestUri = "/path";
 		MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
 		request.setServletPath(requestUri);
@@ -101,8 +107,8 @@ public class OAuth2TokenRevocationEndpointFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenRevocationRequestGetThenNotProcessed() throws Exception {
-		String requestUri = OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI;
+	public void doFilterWhenTokenRevocationRequestGetThenNotProcessed() throws Exception {
+		String requestUri = OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
 		request.setServletPath(requestUri);
 		MockHttpServletResponse response = new MockHttpServletResponse();
@@ -114,54 +120,63 @@ public class OAuth2TokenRevocationEndpointFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenRevocationRequestMissingTokenThenInvalidRequestError() throws Exception {
-		doFilterWhenRevocationRequestInvalidParameterThenError(
-				TOKEN, OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> request.removeParameter(TOKEN));
+	public void doFilterWhenTokenRevocationRequestMissingTokenThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRevocationRequestInvalidParameterThenError(
+				TOKEN_PARAM_NAME,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.removeParameter(TOKEN_PARAM_NAME));
+	}
+
+	@Test
+	public void doFilterWhenTokenRevocationRequestMultipleTokenThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRevocationRequestInvalidParameterThenError(
+				TOKEN_PARAM_NAME,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.addParameter(TOKEN_PARAM_NAME, "token-2"));
 	}
 
 	@Test
-	public void doFilterWhenRevocationRequestMultipleTokenThenInvalidRequestError() throws Exception {
-		doFilterWhenRevocationRequestInvalidParameterThenError(
-				TOKEN, OAuth2ErrorCodes.INVALID_REQUEST,
-				request -> {
-					request.addParameter(TOKEN, "token-1");
-					request.addParameter(TOKEN, "token-2");
-				});
+	public void doFilterWhenTokenRevocationRequestMultipleTokenTypeHintThenInvalidRequestError() throws Exception {
+		doFilterWhenTokenRevocationRequestInvalidParameterThenError(
+				TOKEN_TYPE_HINT_PARAM_NAME,
+				OAuth2ErrorCodes.INVALID_REQUEST,
+				request -> request.addParameter(TOKEN_TYPE_HINT_PARAM_NAME, TokenType.ACCESS_TOKEN.getValue()));
 	}
 
 	@Test
-	public void doFilterWhenTokenRequestValidThenAccessTokenResponse() throws Exception {
+	public void doFilterWhenTokenRevocationRequestValidThenSuccessResponse() throws Exception {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER, "token",
+				Instant.now(), Instant.now().plus(Duration.ofHours(1)),
+				new HashSet<>(Arrays.asList("scope1", "scope2")));
+		OAuth2TokenRevocationAuthenticationToken tokenRevocationAuthentication =
+				new OAuth2TokenRevocationAuthenticationToken(
+						accessToken, clientPrincipal);
 
-		Authentication tokenRevocationAuthenticationSuccess = mock(Authentication.class);
-
-		when(this.authenticationManager.authenticate(any())).thenReturn(tokenRevocationAuthenticationSuccess);
+		when(this.authenticationManager.authenticate(any())).thenReturn(tokenRevocationAuthentication);
 
 		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
 		securityContext.setAuthentication(clientPrincipal);
 		SecurityContextHolder.setContext(securityContext);
 
-		MockHttpServletRequest request = createRevocationRequest();
+		MockHttpServletRequest request = createTokenRevocationRequest();
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
 
 		this.filter.doFilter(request, response, filterChain);
 
 		verifyNoInteractions(filterChain);
-
-		ArgumentCaptor<OAuth2TokenRevocationAuthenticationToken> tokenRevocationAuthenticationCaptor =
-				ArgumentCaptor.forClass(OAuth2TokenRevocationAuthenticationToken.class);
-		verify(this.authenticationManager).authenticate(tokenRevocationAuthenticationCaptor.capture());
+		verify(this.authenticationManager).authenticate(any());
 
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.OK.value());
 	}
 
-	private void doFilterWhenRevocationRequestInvalidParameterThenError(String parameterName, String errorCode,
+	private void doFilterWhenTokenRevocationRequestInvalidParameterThenError(String parameterName, String errorCode,
 			Consumer<MockHttpServletRequest> requestConsumer) throws Exception {
 
-		MockHttpServletRequest request = createRevocationRequest();
+		MockHttpServletRequest request = createTokenRevocationRequest();
 		requestConsumer.accept(request);
 		MockHttpServletResponse response = new MockHttpServletResponse();
 		FilterChain filterChain = mock(FilterChain.class);
@@ -173,7 +188,7 @@ public class OAuth2TokenRevocationEndpointFilterTests {
 		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
 		OAuth2Error error = readError(response);
 		assertThat(error.getErrorCode()).isEqualTo(errorCode);
-		assertThat(error.getDescription()).isEqualTo("Token Revocation Request Parameter: " + parameterName);
+		assertThat(error.getDescription()).isEqualTo("OAuth 2.0 Token Revocation Parameter: " + parameterName);
 	}
 
 	private OAuth2Error readError(MockHttpServletResponse response) throws Exception {
@@ -182,14 +197,13 @@ public class OAuth2TokenRevocationEndpointFilterTests {
 		return this.errorHttpResponseConverter.read(OAuth2Error.class, httpResponse);
 	}
 
-	private static MockHttpServletRequest createRevocationRequest() {
-
+	private static MockHttpServletRequest createTokenRevocationRequest() {
 		String requestUri = OAuth2TokenRevocationEndpointFilter.DEFAULT_TOKEN_REVOCATION_ENDPOINT_URI;
 		MockHttpServletRequest request = new MockHttpServletRequest("POST", requestUri);
 		request.setServletPath(requestUri);
 
-		request.addParameter(TOKEN, "token");
-		request.addParameter(TOKEN_TYPE_HINT, "access_token");
+		request.addParameter(TOKEN_PARAM_NAME, "token");
+		request.addParameter(TOKEN_TYPE_HINT_PARAM_NAME, TokenType.ACCESS_TOKEN.getValue());
 
 		return request;
 	}