Browse Source

Add convenience method for invalidating an OAuth2Token

Closes gh-1717
Joe Grandja 11 months ago
parent
commit
8edbc26b18
13 changed files with 64 additions and 101 deletions
  1. 28 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
  2. 1 31
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java
  3. 7 6
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  4. 4 7
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationConsentAuthenticationProvider.java
  5. 4 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceCodeAuthenticationProvider.java
  6. 2 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java
  7. 2 2
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java
  8. 1 31
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcAuthenticationProviderUtils.java
  9. 5 5
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java
  10. 2 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIntrospectionAuthenticationProviderTests.java
  11. 2 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProviderTests.java
  12. 2 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java
  13. 4 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProviderTests.java

+ 28 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -479,7 +479,6 @@ public class OAuth2Authorization implements Serializable {
 		 * @return the {@link Builder}
 		 */
 		public <T extends OAuth2Token> Builder token(T token, Consumer<Map<String, Object>> metadataConsumer) {
-
 			Assert.notNull(token, "token cannot be null");
 			Map<String, Object> metadata = Token.defaultMetadata();
 			Token<?> existingToken = this.tokens.get(token.getClass());
@@ -492,6 +491,33 @@ public class OAuth2Authorization implements Serializable {
 			return this;
 		}
 
+		/**
+		 * Invalidates the {@link OAuth2Token token}.
+		 * @param token the token
+		 * @param <T> the type of the token
+		 * @return the {@link Builder}
+		 * @since 1.4
+		 */
+		public <T extends OAuth2Token> Builder invalidate(T token) {
+			Assert.notNull(token, "token cannot be null");
+			if (this.tokens.get(token.getClass()) == null) {
+				return this;
+			}
+			token(token, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
+			if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
+				Token<?> accessToken = this.tokens.get(OAuth2AccessToken.class);
+				token(accessToken.getToken(),
+						(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
+
+				Token<?> authorizationCode = this.tokens.get(OAuth2AuthorizationCode.class);
+				if (authorizationCode != null && !authorizationCode.isInvalidated()) {
+					token(authorizationCode.getToken(),
+							(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
+				}
+			}
+			return this;
+		}
+
 		protected final Builder tokens(Map<Class<? extends OAuth2Token>, Token<?>> tokens) {
 			this.tokens = new HashMap<>(tokens);
 			return this;

+ 1 - 31
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,10 +21,8 @@ import org.springframework.security.oauth2.core.ClaimAccessor;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;
 
@@ -50,34 +48,6 @@ final class OAuth2AuthenticationProviderUtils {
 		throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_CLIENT);
 	}
 
-	static <T extends OAuth2Token> OAuth2Authorization invalidate(OAuth2Authorization authorization, T token) {
-
-		// @formatter:off
-		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
-				.token(token,
-						(metadata) ->
-								metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
-
-		if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
-			authorizationBuilder.token(
-					authorization.getAccessToken().getToken(),
-					(metadata) ->
-							metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
-
-			OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
-					authorization.getToken(OAuth2AuthorizationCode.class);
-			if (authorizationCode != null && !authorizationCode.isInvalidated()) {
-				authorizationBuilder.token(
-						authorizationCode.getToken(),
-						(metadata) ->
-								metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
-			}
-		}
-		// @formatter:on
-
-		return authorizationBuilder.build();
-	}
-
 	static <T extends OAuth2Token> OAuth2AccessToken accessToken(OAuth2Authorization.Builder builder, T token,
 			OAuth2TokenContext accessTokenContext) {
 

+ 7 - 6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -144,8 +144,9 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 			if (!authorizationCode.isInvalidated()) {
 				// Invalidate the authorization code given that a different client is
 				// attempting to use it
-				authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization,
-						authorizationCode.getToken());
+				authorization = OAuth2Authorization.from(authorization)
+					.invalidate(authorizationCode.getToken())
+					.build();
 				this.authorizationService.save(authorization);
 				if (this.logger.isWarnEnabled()) {
 					this.logger.warn(LogMessage.format("Invalidated authorization code used by registered client '%s'",
@@ -172,7 +173,7 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 				if (token != null) {
 					// Invalidate the access (and refresh) token as the client is
 					// attempting to use the authorization code more than once
-					authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
+					authorization = OAuth2Authorization.from(authorization).invalidate(token.getToken()).build();
 					this.authorizationService.save(authorization);
 					if (this.logger.isWarnEnabled()) {
 						this.logger.warn(LogMessage.format(
@@ -284,10 +285,10 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 			idToken = null;
 		}
 
-		authorization = authorizationBuilder.build();
-
 		// Invalidate the authorization code as it can only be used once
-		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken());
+		authorizationBuilder.invalidate(authorizationCode.getToken());
+
+		authorization = authorizationBuilder.build();
 
 		this.authorizationService.save(authorization);
 

+ 4 - 7
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceAuthorizationConsentAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -187,10 +187,8 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
 				}
 			}
 			authorization = OAuth2Authorization.from(authorization)
-				.token((deviceCodeToken.getToken()),
-						(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
-				.token((userCodeToken.getToken()),
-						(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
+				.invalidate(deviceCodeToken.getToken())
+				.invalidate(userCodeToken.getToken())
 				.attributes((attrs) -> attrs.remove(OAuth2ParameterNames.STATE))
 				.build();
 			this.authorizationService.save(authorization);
@@ -210,8 +208,7 @@ public final class OAuth2DeviceAuthorizationConsentAuthenticationProvider implem
 
 		authorization = OAuth2Authorization.from(authorization)
 			.authorizedScopes(authorizedScopes)
-			.token((userCodeToken.getToken()),
-					(metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
+			.invalidate(userCodeToken.getToken())
 			.attributes((attrs) -> attrs.remove(OAuth2ParameterNames.STATE))
 			.attributes((attrs) -> attrs.remove(OAuth2ParameterNames.SCOPE))
 			.build();

+ 4 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceCodeAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -124,7 +124,7 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
 			if (!deviceCode.isInvalidated()) {
 				// Invalidate the device code given that a different client is attempting
 				// to use it
-				authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, deviceCode.getToken());
+				authorization = OAuth2Authorization.from(authorization).invalidate(deviceCode.getToken()).build();
 				this.authorizationService.save(authorization);
 				if (this.logger.isWarnEnabled()) {
 					this.logger.warn(LogMessage.format("Invalidated device code used by registered client '%s'",
@@ -172,7 +172,7 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
 		// restarting to avoid unnecessary polling.
 		if (deviceCode.isExpired()) {
 			// Invalidate the device code
-			authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, deviceCode.getToken());
+			authorization = OAuth2Authorization.from(authorization).invalidate(deviceCode.getToken()).build();
 			this.authorizationService.save(authorization);
 			if (this.logger.isWarnEnabled()) {
 				this.logger.warn(LogMessage.format("Invalidated device code used by registered client '%s'",
@@ -200,8 +200,7 @@ public final class OAuth2DeviceCodeAuthenticationProvider implements Authenticat
 		// @formatter:off
 		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
 				// Invalidate the device code as it can only be used (successfully) once
-				.token(deviceCode.getToken(), (metadata) ->
-						metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
+				.invalidate(deviceCode.getToken());
 		// @formatter:on
 
 		// ----- Access token -----

+ 2 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2DeviceVerificationAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -166,8 +166,7 @@ public final class OAuth2DeviceVerificationAuthenticationProvider implements Aut
 		authorization = OAuth2Authorization.from(authorization)
 				.principalName(principal.getName())
 				.authorizedScopes(requestedScopes)
-				.token(userCode.getToken(), (metadata) -> metadata
-						.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
+				.invalidate(userCode.getToken())
 				.attribute(Principal.class.getName(), principal)
 				.attributes((attributes) -> attributes.remove(OAuth2ParameterNames.SCOPE))
 				.build();

+ 2 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@ public final class OAuth2TokenRevocationAuthenticationProvider implements Authen
 		}
 
 		OAuth2Authorization.Token<OAuth2Token> token = authorization.getToken(tokenRevocationAuthentication.getToken());
-		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
+		authorization = OAuth2Authorization.from(authorization).invalidate(token.getToken()).build();
 		this.authorizationService.save(authorization);
 
 		if (this.logger.isTraceEnabled()) {

+ 1 - 31
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcAuthenticationProviderUtils.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,10 +18,8 @@ package org.springframework.security.oauth2.server.authorization.oidc.authentica
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.oauth2.core.ClaimAccessor;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.settings.OAuth2TokenFormat;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;
 
@@ -36,34 +34,6 @@ final class OidcAuthenticationProviderUtils {
 	private OidcAuthenticationProviderUtils() {
 	}
 
-	static <T extends OAuth2Token> OAuth2Authorization invalidate(OAuth2Authorization authorization, T token) {
-
-		// @formatter:off
-		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
-				.token(token,
-						(metadata) ->
-								metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
-
-		if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
-			authorizationBuilder.token(
-					authorization.getAccessToken().getToken(),
-					(metadata) ->
-							metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
-
-			OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
-					authorization.getToken(OAuth2AuthorizationCode.class);
-			if (authorizationCode != null && !authorizationCode.isInvalidated()) {
-				authorizationBuilder.token(
-						authorizationCode.getToken(),
-						(metadata) ->
-								metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
-			}
-		}
-		// @formatter:on
-
-		return authorizationBuilder.build();
-	}
-
 	static <T extends OAuth2Token> OAuth2AccessToken accessToken(OAuth2Authorization.Builder builder, T token,
 			OAuth2TokenContext accessTokenContext) {
 

+ 5 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -260,12 +260,12 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 		OAuth2Authorization registeredClientAuthorization = registerAccessToken(registeredClient);
 
 		// Invalidate the "initial" access token as it can only be used once
-		authorization = OidcAuthenticationProviderUtils.invalidate(authorization,
-				authorization.getAccessToken().getToken());
+		OAuth2Authorization.Builder builder = OAuth2Authorization.from(authorization)
+			.invalidate(authorization.getAccessToken().getToken());
 		if (authorization.getRefreshToken() != null) {
-			authorization = OidcAuthenticationProviderUtils.invalidate(authorization,
-					authorization.getRefreshToken().getToken());
+			builder.invalidate(authorization.getRefreshToken().getToken());
 		}
+		authorization = builder.build();
 		this.authorizationService.save(authorization);
 
 		if (this.logger.isTraceEnabled()) {

+ 2 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIntrospectionAuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -147,7 +147,7 @@ public class OAuth2TokenIntrospectionAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		OAuth2AccessToken accessToken = authorization.getAccessToken().getToken();
-		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, accessToken);
+		authorization = OAuth2Authorization.from(authorization).invalidate(accessToken).build();
 		given(this.authorizationService.findByToken(eq(accessToken.getTokenValue()), isNull()))
 			.willReturn(authorization);
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient,

+ 2 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientConfigurationAuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -176,8 +176,8 @@ public class OidcClientConfigurationAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations
 			.authorization(registeredClient, jwtAccessToken, jwt.getClaims())
+			.invalidate(jwtAccessToken)
 			.build();
-		authorization = OidcAuthenticationProviderUtils.invalidate(authorization, jwtAccessToken);
 		given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()),
 				eq(OAuth2TokenType.ACCESS_TOKEN)))
 			.willReturn(authorization);

+ 2 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2023 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -250,8 +250,8 @@ public class OidcClientRegistrationAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations
 			.authorization(registeredClient, jwtAccessToken, jwt.getClaims())
+			.invalidate(jwtAccessToken)
 			.build();
-		authorization = OidcAuthenticationProviderUtils.invalidate(authorization, jwtAccessToken);
 		given(this.authorizationService.findByToken(eq(jwtAccessToken.getTokenValue()),
 				eq(OAuth2TokenType.ACCESS_TOKEN)))
 			.willReturn(authorization);

+ 4 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcUserInfoAuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -133,8 +133,9 @@ public class OidcUserInfoAuthenticationProviderTests {
 	public void authenticateWhenAccessTokenNotActiveThenThrowOAuth2AuthenticationException() {
 		String tokenValue = "token";
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
-		authorization = OidcAuthenticationProviderUtils.invalidate(authorization,
-				authorization.getAccessToken().getToken());
+		authorization = OAuth2Authorization.from(authorization)
+			.invalidate(authorization.getAccessToken().getToken())
+			.build();
 		given(this.authorizationService.findByToken(eq(tokenValue), eq(OAuth2TokenType.ACCESS_TOKEN)))
 			.willReturn(authorization);