瀏覽代碼

Improve OAuth2Authorization model

This commit removes OAuth2Tokens and OAuth2TokenMetadata and consolidates the code into OAuth2Authorization.

Closes gh-213
Joe Grandja 4 年之前
父節點
當前提交
bffcbc5440
共有 24 個文件被更改,包括 390 次插入968 次删除
  1. 15 10
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java
  2. 210 57
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
  3. 20 18
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java
  4. 17 18
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  5. 2 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java
  6. 9 11
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
  7. 3 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java
  8. 0 169
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadata.java
  9. 0 292
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java
  10. 2 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
  11. 7 7
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java
  12. 3 3
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java
  13. 10 10
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java
  14. 3 3
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java
  15. 11 10
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java
  16. 16 13
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
  17. 3 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
  18. 16 19
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
  19. 3 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java
  20. 24 30
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java
  21. 13 12
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java
  22. 0 74
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadataTests.java
  23. 0 195
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java
  24. 3 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

+ 15 - 10
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -15,15 +15,17 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
-import org.springframework.lang.Nullable;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
-import org.springframework.util.Assert;
-
 import java.io.Serializable;
 import java.util.Map;
 import java.util.Objects;
 import java.util.concurrent.ConcurrentHashMap;
 
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.util.Assert;
+
 /**
  * An {@link OAuth2AuthorizationService} that stores {@link OAuth2Authorization}'s in-memory.
  *
@@ -87,18 +89,21 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 	}
 
 	private static boolean matchesAuthorizationCode(OAuth2Authorization authorization, String token) {
-		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
-		return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
+				authorization.getToken(OAuth2AuthorizationCode.class);
+		return authorizationCode != null && authorizationCode.getToken().getTokenValue().equals(token);
 	}
 
 	private static boolean matchesAccessToken(OAuth2Authorization authorization, String token) {
-		return authorization.getTokens().getAccessToken() != null &&
-				authorization.getTokens().getAccessToken().getTokenValue().equals(token);
+		OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
+				authorization.getToken(OAuth2AccessToken.class);
+		return accessToken != null && accessToken.getToken().getTokenValue().equals(token);
 	}
 
 	private static boolean matchesRefreshToken(OAuth2Authorization authorization, String token) {
-		return authorization.getTokens().getRefreshToken() != null &&
-				authorization.getTokens().getRefreshToken().getTokenValue().equals(token);
+		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken =
+				authorization.getToken(OAuth2RefreshToken.class);
+		return refreshToken != null && refreshToken.getToken().getTokenValue().equals(token);
 	}
 
 	private static class OAuth2AuthorizationId implements Serializable {

+ 210 - 57
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,11 +15,6 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
-import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
-import org.springframework.util.Assert;
-
 import java.io.Serializable;
 import java.util.Collections;
 import java.util.HashMap;
@@ -27,26 +22,32 @@ import java.util.Map;
 import java.util.Objects;
 import java.util.function.Consumer;
 
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
+
 /**
- * A representation of an OAuth 2.0 Authorization,
- * which holds state related to the authorization granted to the {@link #getRegisteredClientId() client}
- * by the {@link #getPrincipalName() resource owner}.
+ * A representation of an OAuth 2.0 Authorization, which holds state related to the authorization granted
+ * to a {@link #getRegisteredClientId() client}, by the {@link #getPrincipalName() resource owner}
+ * or itself in the case of the {@code client_credentials} grant type.
  *
  * @author Joe Grandja
  * @author Krisztian Toth
  * @since 0.0.1
  * @see RegisteredClient
- * @see OAuth2Tokens
+ * @see AbstractOAuth2Token
+ * @see OAuth2AccessToken
+ * @see OAuth2RefreshToken
  */
 public class OAuth2Authorization implements Serializable {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 	private String registeredClientId;
 	private String principalName;
-	private OAuth2Tokens tokens;
-
-	@Deprecated
-	private OAuth2AccessToken accessToken;
-
+	private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens;
 	private Map<String, Object> attributes;
 
 	protected OAuth2Authorization() {
@@ -62,31 +63,64 @@ public class OAuth2Authorization implements Serializable {
 	}
 
 	/**
-	 * Returns the resource owner's {@code Principal} name.
+	 * Returns the {@code Principal} name of the resource owner (or client).
 	 *
-	 * @return the resource owner's {@code Principal} name
+	 * @return the {@code Principal} name of the resource owner (or client)
 	 */
 	public String getPrincipalName() {
 		return this.principalName;
 	}
 
 	/**
-	 * Returns the {@link OAuth2Tokens}.
+	 * Returns the {@link Token} of type {@link OAuth2AccessToken}.
 	 *
-	 * @return the {@link OAuth2Tokens}
+	 * @return the {@link Token} of type {@link OAuth2AccessToken}
 	 */
-	public OAuth2Tokens getTokens() {
-		return this.tokens;
+	public Token<OAuth2AccessToken> getAccessToken() {
+		return getToken(OAuth2AccessToken.class);
 	}
 
 	/**
-	 * Returns the {@link OAuth2AccessToken access token} credential.
+	 * Returns the {@link Token} of type {@link OAuth2RefreshToken}.
 	 *
-	 * @return the {@link OAuth2AccessToken}
+	 * @return the {@link Token} of type {@link OAuth2RefreshToken}, or {@code null} if not available
 	 */
-	@Deprecated
-	public OAuth2AccessToken getAccessToken() {
-		return getTokens().getAccessToken();
+	@Nullable
+	public Token<OAuth2RefreshToken> getRefreshToken() {
+		return getToken(OAuth2RefreshToken.class);
+	}
+
+	/**
+	 * Returns the {@link Token} of type {@code tokenType}.
+	 *
+	 * @param tokenType the token type
+	 * @param <T> the type of the token
+	 * @return the {@link Token}, or {@code null} if not available
+	 */
+	@Nullable
+	@SuppressWarnings("unchecked")
+	public <T extends AbstractOAuth2Token> Token<T> getToken(Class<T> tokenType) {
+		Assert.notNull(tokenType, "tokenType cannot be null");
+		Token<?> token = this.tokens.get(tokenType);
+		return token != null ? (Token<T>) token : null;
+	}
+
+	/**
+	 * Returns the {@link Token} matching the {@code tokenValue}.
+	 *
+	 * @param tokenValue the token value
+	 * @param <T> the type of the token
+	 * @return the {@link Token}, or {@code null} if not available
+	 */
+	@Nullable
+	@SuppressWarnings("unchecked")
+	public <T extends AbstractOAuth2Token> Token<T> getToken(String tokenValue) {
+		Assert.hasText(tokenValue, "tokenValue cannot be empty");
+		Token<?> token = this.tokens.values().stream()
+				.filter(t -> t.getToken().getTokenValue().equals(tokenValue))
+				.findFirst()
+				.orElse(null);
+		return token != null ? (Token<T>) token : null;
 	}
 
 	/**
@@ -103,8 +137,9 @@ public class OAuth2Authorization implements Serializable {
 	 *
 	 * @param name the name of the attribute
 	 * @param <T> the type of the attribute
-	 * @return the value of the attribute associated to the authorization, or {@code null} if not available
+	 * @return the value of an attribute associated to the authorization, or {@code null} if not available
 	 */
+	@Nullable
 	@SuppressWarnings("unchecked")
 	public <T> T getAttribute(String name) {
 		Assert.hasText(name, "name cannot be empty");
@@ -143,41 +178,131 @@ public class OAuth2Authorization implements Serializable {
 	}
 
 	/**
-	 * Returns a new {@link Builder}, initialized with the values from the provided {@code authorization}.
+	 * Returns a new {@link Builder}, initialized with the values from the provided {@code OAuth2Authorization}.
 	 *
-	 * @param authorization the authorization used for initializing the {@link Builder}
+	 * @param authorization the {@code OAuth2Authorization} used for initializing the {@link Builder}
 	 * @return the {@link Builder}
 	 */
 	public static Builder from(OAuth2Authorization authorization) {
 		Assert.notNull(authorization, "authorization cannot be null");
 		return new Builder(authorization.getRegisteredClientId())
 				.principalName(authorization.getPrincipalName())
-				.tokens(OAuth2Tokens.from(authorization.getTokens()).build())
+				.tokens(authorization.tokens)
 				.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
 	}
 
+	/**
+	 * A holder of an OAuth 2.0 Token and it's associated metadata.
+	 *
+	 * @author Joe Grandja
+	 * @since 0.1.0
+	 */
+	public static class Token<T extends AbstractOAuth2Token> implements Serializable {
+		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+		protected static final String TOKEN_METADATA_BASE = "metadata.token.";
+
+		/**
+		 * The name of the metadata that indicates if the token has been invalidated.
+		 */
+		public static final String INVALIDATED_METADATA_NAME = TOKEN_METADATA_BASE.concat("invalidated");
+
+		private final T token;
+		private final Map<String, Object> metadata;
+
+		protected Token(T token) {
+			this(token, defaultMetadata());
+		}
+
+		protected Token(T token, Map<String, Object> metadata) {
+			this.token = token;
+			this.metadata = Collections.unmodifiableMap(metadata);
+		}
+
+		/**
+		 * Returns the token of type {@link AbstractOAuth2Token}.
+		 *
+		 * @return the token of type {@link AbstractOAuth2Token}
+		 */
+		public T getToken() {
+			return this.token;
+		}
+
+		/**
+		 * Returns {@code true} if the token has been invalidated (e.g. revoked).
+		 * The default is {@code false}.
+		 *
+		 * @return {@code true} if the token has been invalidated, {@code false} otherwise
+		 */
+		public boolean isInvalidated() {
+			return Boolean.TRUE.equals(getMetadata(INVALIDATED_METADATA_NAME));
+		}
+
+		/**
+		 * Returns the value of the metadata associated to the token.
+		 *
+		 * @param name the name of the metadata
+		 * @param <V> the value type of the metadata
+		 * @return the value of the metadata, or {@code null} if not available
+		 */
+		@Nullable
+		@SuppressWarnings("unchecked")
+		public <V> V getMetadata(String name) {
+			Assert.hasText(name, "name cannot be empty");
+			return (V) this.metadata.get(name);
+		}
+
+		/**
+		 * Returns the metadata associated to the token.
+		 *
+		 * @return a {@code Map} of the metadata
+		 */
+		public Map<String, Object> getMetadata() {
+			return this.metadata;
+		}
+
+		protected static Map<String, Object> defaultMetadata() {
+			Map<String, Object> metadata = new HashMap<>();
+			metadata.put(INVALIDATED_METADATA_NAME, false);
+			return metadata;
+		}
+
+		@Override
+		public boolean equals(Object obj) {
+			if (this == obj) {
+				return true;
+			}
+			if (obj == null || getClass() != obj.getClass()) {
+				return false;
+			}
+			Token<?> that = (Token<?>) obj;
+			return Objects.equals(this.token, that.token) &&
+					Objects.equals(this.metadata, that.metadata);
+		}
+
+		@Override
+		public int hashCode() {
+			return Objects.hash(this.token, this.metadata);
+		}
+	}
+
 	/**
 	 * A builder for {@link OAuth2Authorization}.
 	 */
 	public static class Builder implements Serializable {
 		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-		private String registeredClientId;
+		private final String registeredClientId;
 		private String principalName;
-		private OAuth2Tokens tokens;
-
-		@Deprecated
-		private OAuth2AccessToken accessToken;
-
-		private Map<String, Object> attributes = new HashMap<>();
+		private Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens = new HashMap<>();
+		private final Map<String, Object> attributes = new HashMap<>();
 
 		protected Builder(String registeredClientId) {
 			this.registeredClientId = registeredClientId;
 		}
 
 		/**
-		 * Sets the resource owner's {@code Principal} name.
+		 * Sets the {@code Principal} name of the resource owner (or client).
 		 *
-		 * @param principalName the resource owner's {@code Principal} name
+		 * @param principalName the {@code Principal} name of the resource owner (or client)
 		 * @return the {@link Builder}
 		 */
 		public Builder principalName(String principalName) {
@@ -186,25 +311,60 @@ public class OAuth2Authorization implements Serializable {
 		}
 
 		/**
-		 * Sets the {@link OAuth2Tokens}.
+		 * Sets the {@link OAuth2AccessToken access token}.
 		 *
-		 * @param tokens the {@link OAuth2Tokens}
+		 * @param accessToken the {@link OAuth2AccessToken}
 		 * @return the {@link Builder}
 		 */
-		public Builder tokens(OAuth2Tokens tokens) {
-			this.tokens = tokens;
-			return this;
+		public Builder accessToken(OAuth2AccessToken accessToken) {
+			return token(accessToken);
 		}
 
 		/**
-		 * Sets the {@link OAuth2AccessToken access token} credential.
+		 * Sets the {@link OAuth2RefreshToken refresh token}.
 		 *
-		 * @param accessToken the {@link OAuth2AccessToken}
+		 * @param refreshToken the {@link OAuth2RefreshToken}
 		 * @return the {@link Builder}
 		 */
-		@Deprecated
-		public Builder accessToken(OAuth2AccessToken accessToken) {
-			this.accessToken = accessToken;
+		public Builder refreshToken(OAuth2RefreshToken refreshToken) {
+			return token(refreshToken);
+		}
+
+		/**
+		 * Sets the {@link AbstractOAuth2Token token}.
+		 *
+		 * @param token the token
+		 * @param <T> the type of the token
+		 * @return the {@link Builder}
+		 */
+		public <T extends AbstractOAuth2Token> Builder token(T token) {
+			return token(token, (metadata) -> {});
+		}
+
+		/**
+		 * Sets the {@link AbstractOAuth2Token token} and associated metadata.
+		 *
+		 * @param token the token
+		 * @param metadataConsumer a {@code Consumer} of the metadata {@code Map}
+		 * @param <T> the type of the token
+		 * @return the {@link Builder}
+		 */
+		public <T extends AbstractOAuth2Token> Builder token(T token,
+				Consumer<Map<String, Object>> metadataConsumer) {
+
+			Assert.notNull(token, "token cannot be null");
+			Map<String, Object> metadata = Token.defaultMetadata();
+			metadataConsumer.accept(metadata);
+			Class<? extends AbstractOAuth2Token> tokenClass = token.getClass();
+			if (tokenClass.equals(OAuth2RefreshToken2.class)) {
+				tokenClass = OAuth2RefreshToken.class;
+			}
+			this.tokens.put(tokenClass, new Token<>(token, metadata));
+			return this;
+		}
+
+		protected final Builder tokens(Map<Class<? extends AbstractOAuth2Token>, Token<?>> tokens) {
+			this.tokens = new HashMap<>(tokens);
 			return this;
 		}
 
@@ -245,14 +405,7 @@ public class OAuth2Authorization implements Serializable {
 			OAuth2Authorization authorization = new OAuth2Authorization();
 			authorization.registeredClientId = this.registeredClientId;
 			authorization.principalName = this.principalName;
-			if (this.tokens == null) {
-				OAuth2Tokens.Builder builder = OAuth2Tokens.builder();
-				if (this.accessToken != null) {
-					builder.accessToken(this.accessToken);
-				}
-				this.tokens = builder.build();
-			}
-			authorization.tokens = this.tokens;
+			authorization.tokens = Collections.unmodifiableMap(this.tokens);
 			authorization.attributes = Collections.unmodifiableMap(this.attributes);
 			return authorization;
 		}

+ 20 - 18
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthenticationProviderUtils.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,8 +24,6 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 /**
  * Utility methods for the OAuth 2.0 {@link AuthenticationProvider}'s.
@@ -52,25 +50,29 @@ final class OAuth2AuthenticationProviderUtils {
 	static <T extends AbstractOAuth2Token> OAuth2Authorization invalidate(
 			OAuth2Authorization authorization, T token) {
 
-		OAuth2Tokens.Builder builder = OAuth2Tokens.from(authorization.getTokens())
-				.token(token, OAuth2TokenMetadata.builder().invalidated().build());
+		// @formatter:off
+		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
+				.token(token,
+						(metadata) ->
+								metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
 
 		if (OAuth2RefreshToken.class.isAssignableFrom(token.getClass())) {
-			builder.token(
-					authorization.getTokens().getAccessToken(),
-					OAuth2TokenMetadata.builder().invalidated().build());
-			OAuth2AuthorizationCode authorizationCode =
-					authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
-			if (authorizationCode != null &&
-					!authorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()) {
-				builder.token(
-						authorizationCode,
-						OAuth2TokenMetadata.builder().invalidated().build());
+			authorizationBuilder.token(
+					authorization.getAccessToken().getToken(),
+					(metadata) ->
+							metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
+
+			OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
+					authorization.getToken(OAuth2AuthorizationCode.class);
+			if (authorizationCode != null && !authorizationCode.isInvalidated()) {
+				authorizationBuilder.token(
+						authorizationCode.getToken(),
+						(metadata) ->
+								metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true));
 			}
 		}
+		// @formatter:on
 
-		return OAuth2Authorization.from(authorization)
-				.tokens(builder.build())
-				.build();
+		return authorizationBuilder.build();
 	}
 }

+ 17 - 18
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -37,16 +37,14 @@ import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
-import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
@@ -104,16 +102,16 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 		if (authorization == null) {
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
-		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
-		OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode);
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
+				authorization.getToken(OAuth2AuthorizationCode.class);
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
 				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 
 		if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
-			if (!authorizationCodeMetadata.isInvalidated()) {
+			if (!authorizationCode.isInvalidated()) {
 				// Invalidate the authorization code given that a different client is attempting to use it
-				authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode);
+				authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken());
 				this.authorizationService.save(authorization);
 			}
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
@@ -124,7 +122,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
-		if (authorizationCodeMetadata.isInvalidated()) {
+		if (authorizationCode.isInvalidated()) {
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
@@ -143,14 +141,11 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
 				jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
-		OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens())
-				.accessToken(accessToken);
 
 		OAuth2RefreshToken refreshToken = null;
 		if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
 			refreshToken = OAuth2RefreshTokenAuthenticationProvider.generateRefreshToken(
 					registeredClient.getTokenSettings().refreshTokenTimeToLive());
-			tokensBuilder.refreshToken(refreshToken);
 		}
 
 		OidcIdToken idToken = null;
@@ -170,17 +165,21 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 
 			idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
 					jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
-			tokensBuilder.token(idToken);
 		}
 
-		OAuth2Tokens tokens = tokensBuilder.build();
-		authorization = OAuth2Authorization.from(authorization)
-				.tokens(tokens)
-				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
-				.build();
+		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
+				.accessToken(accessToken)
+				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt);
+		if (refreshToken != null) {
+			authorizationBuilder.refreshToken(refreshToken);
+		}
+		if (idToken != null) {
+			authorizationBuilder.token(idToken);
+		}
+		authorization = authorizationBuilder.build();
 
 		// Invalidate the authorization code as it can only be used once
-		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode);
+		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken());
 
 		this.authorizationService.save(authorization);
 

+ 2 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java

@@ -32,13 +32,12 @@ import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
-import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 
@@ -125,7 +124,7 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
 
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName(clientPrincipal.getName())
-				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
+				.token(accessToken)
 				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
 				.build();
 		this.authorizationService.save(authorization);

+ 9 - 11
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@@ -37,16 +37,14 @@ import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
-import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 
 import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
@@ -114,7 +112,8 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.UNAUTHORIZED_CLIENT));
 		}
 
-		Instant refreshTokenExpiresAt = authorization.getTokens().getRefreshToken().getExpiresAt();
+		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
+		Instant refreshTokenExpiresAt = refreshToken.getToken().getExpiresAt();
 		if (refreshTokenExpiresAt.isBefore(Instant.now())) {
 			// As per https://tools.ietf.org/html/rfc6749#section-5.2
 			// invalid_grant: The provided authorization grant (e.g., authorization code,
@@ -134,10 +133,7 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
 			scopes = authorizedScopes;
 		}
 
-		OAuth2RefreshToken refreshToken = authorization.getTokens().getRefreshToken();
-		OAuth2TokenMetadata refreshTokenMetadata = authorization.getTokens().getTokenMetadata(refreshToken);
-
-		if (refreshTokenMetadata.isInvalidated()) {
+		if (refreshToken.isInvalidated()) {
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
@@ -159,18 +155,20 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
 
 		TokenSettings tokenSettings = registeredClient.getTokenSettings();
 
+		OAuth2RefreshToken currentRefreshToken = refreshToken.getToken();
 		if (!tokenSettings.reuseRefreshTokens()) {
-			refreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive());
+			currentRefreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive());
 		}
 
 		authorization = OAuth2Authorization.from(authorization)
-				.tokens(OAuth2Tokens.from(authorization.getTokens()).accessToken(accessToken).refreshToken(refreshToken).build())
+				.accessToken(accessToken)
+				.refreshToken(currentRefreshToken)
 				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
 				.build();
 		this.authorizationService.save(authorization);
 
 		return new OAuth2AccessTokenAuthenticationToken(
-				registeredClient, clientPrincipal, accessToken, refreshToken);
+				registeredClient, clientPrincipal, accessToken, currentRefreshToken);
 	}
 
 	@Override

+ 3 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProvider.java

@@ -72,11 +72,11 @@ public class OAuth2TokenRevocationAuthenticationProvider implements Authenticati
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_CLIENT));
 		}
 
-		AbstractOAuth2Token token = authorization.getTokens().getToken(tokenRevocationAuthentication.getToken());
-		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token);
+		OAuth2Authorization.Token<AbstractOAuth2Token> token = authorization.getToken(tokenRevocationAuthentication.getToken());
+		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, token.getToken());
 		this.authorizationService.save(authorization);
 
-		return new OAuth2TokenRevocationAuthenticationToken(token, clientPrincipal);
+		return new OAuth2TokenRevocationAuthenticationToken(token.getToken(), clientPrincipal);
 	}
 
 	@Override

+ 0 - 169
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadata.java

@@ -1,169 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.token;
-
-import org.springframework.security.oauth2.server.authorization.Version;
-import org.springframework.util.Assert;
-
-import java.io.Serializable;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Objects;
-import java.util.function.Consumer;
-
-/**
- * Holds metadata associated to an OAuth 2.0 Token.
- *
- * @author Joe Grandja
- * @since 0.0.3
- * @see OAuth2Tokens
- */
-public class OAuth2TokenMetadata implements Serializable {
-	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-	protected static final String TOKEN_METADATA_BASE = "metadata.token.";
-
-	/**
-	 * The name of the metadata that indicates if the token has been invalidated.
-	 */
-	public static final String INVALIDATED = TOKEN_METADATA_BASE.concat("invalidated");
-
-	private final Map<String, Object> metadata;
-
-	protected OAuth2TokenMetadata(Map<String, Object> metadata) {
-		this.metadata = Collections.unmodifiableMap(new HashMap<>(metadata));
-	}
-
-	/**
-	 * Returns {@code true} if the token has been invalidated (e.g. revoked).
-	 * The default is {@code false}.
-	 *
-	 * @return {@code true} if the token has been invalidated, {@code false} otherwise
-	 */
-	public boolean isInvalidated() {
-		return getMetadata(INVALIDATED);
-	}
-
-	/**
-	 * Returns the value of the metadata associated to the token.
-	 *
-	 * @param name the name of the metadata
-	 * @param <T> the type of the metadata
-	 * @return the value of the metadata, or {@code null} if not available
-	 */
-	@SuppressWarnings("unchecked")
-	public <T> T getMetadata(String name) {
-		Assert.hasText(name, "name cannot be empty");
-		return (T) this.metadata.get(name);
-	}
-
-	/**
-	 * Returns the metadata associated to the token.
-	 *
-	 * @return a {@code Map} of the metadata
-	 */
-	public Map<String, Object> getMetadata() {
-		return this.metadata;
-	}
-
-	@Override
-	public boolean equals(Object obj) {
-		if (this == obj) {
-			return true;
-		}
-		if (obj == null || getClass() != obj.getClass()) {
-			return false;
-		}
-		OAuth2TokenMetadata that = (OAuth2TokenMetadata) obj;
-		return Objects.equals(this.metadata, that.metadata);
-	}
-
-	@Override
-	public int hashCode() {
-		return Objects.hash(this.metadata);
-	}
-
-	/**
-	 * Returns a new {@link Builder}.
-	 *
-	 * @return the {@link Builder}
-	 */
-	public static Builder builder() {
-		return new Builder();
-	}
-
-	/**
-	 * A builder for {@link OAuth2TokenMetadata}.
-	 */
-	public static class Builder implements Serializable {
-		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-		private final Map<String, Object> metadata = defaultMetadata();
-
-		protected Builder() {
-		}
-
-		/**
-		 * Set the token as invalidated (e.g. revoked).
-		 *
-		 * @return the {@link Builder}
-		 */
-		public Builder invalidated() {
-			metadata(INVALIDATED, true);
-			return this;
-		}
-
-		/**
-		 * Adds a metadata associated to the token.
-		 *
-		 * @param name the name of the metadata
-		 * @param value the value of the metadata
-		 * @return the {@link Builder}
-		 */
-		public Builder metadata(String name, Object value) {
-			Assert.hasText(name, "name cannot be empty");
-			Assert.notNull(value, "value cannot be null");
-			this.metadata.put(name, value);
-			return this;
-		}
-
-		/**
-		 * A {@code Consumer} of the metadata {@code Map}
-		 * allowing the ability to add, replace, or remove.
-		 *
-		 * @param metadataConsumer a {@link Consumer} of the metadata {@code Map}
-		 * @return the {@link Builder}
-		 */
-		public Builder metadata(Consumer<Map<String, Object>> metadataConsumer) {
-			metadataConsumer.accept(this.metadata);
-			return this;
-		}
-
-		/**
-		 * Builds a new {@link OAuth2TokenMetadata}.
-		 *
-		 * @return the {@link OAuth2TokenMetadata}
-		 */
-		public OAuth2TokenMetadata build() {
-			return new OAuth2TokenMetadata(this.metadata);
-		}
-
-		protected static Map<String, Object> defaultMetadata() {
-			Map<String, Object> metadata = new HashMap<>();
-			metadata.put(INVALIDATED, false);
-			return metadata;
-		}
-	}
-}

+ 0 - 292
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java

@@ -1,292 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.token;
-
-import org.springframework.lang.Nullable;
-import org.springframework.security.oauth2.core.AbstractOAuth2Token;
-import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
-import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.Version;
-import org.springframework.util.Assert;
-
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Objects;
-
-/**
- * A container for OAuth 2.0 Tokens.
- *
- * @author Joe Grandja
- * @since 0.0.3
- * @see OAuth2Authorization
- * @see OAuth2TokenMetadata
- * @see AbstractOAuth2Token
- * @see OAuth2AccessToken
- * @see OAuth2RefreshToken
- */
-public class OAuth2Tokens implements Serializable {
-	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-	private final Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens;
-
-	protected OAuth2Tokens(Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens) {
-		this.tokens = new HashMap<>(tokens);
-	}
-
-	/**
-	 * Returns the {@link OAuth2AccessToken access token}.
-	 *
-	 * @return the {@link OAuth2AccessToken}, or {@code null} if not available
-	 */
-	@Nullable
-	public OAuth2AccessToken getAccessToken() {
-		return getToken(OAuth2AccessToken.class);
-	}
-
-	/**
-	 * Returns the {@link OAuth2RefreshToken refresh token}.
-	 *
-	 * @return the {@link OAuth2RefreshToken}, or {@code null} if not available
-	 */
-	@Nullable
-	public OAuth2RefreshToken getRefreshToken() {
-		OAuth2RefreshToken refreshToken = getToken(OAuth2RefreshToken.class);
-		return refreshToken != null ? refreshToken : getToken(OAuth2RefreshToken2.class);
-	}
-
-	/**
-	 * Returns the token specified by {@code tokenType}.
-	 *
-	 * @param tokenType the token type
-	 * @param <T> the type of the token
-	 * @return the token, or {@code null} if not available
-	 */
-	@Nullable
-	@SuppressWarnings("unchecked")
-	public <T extends AbstractOAuth2Token> T getToken(Class<T> tokenType) {
-		Assert.notNull(tokenType, "tokenType cannot be null");
-		OAuth2TokenHolder tokenHolder = this.tokens.get(tokenType);
-		return tokenHolder != null ? (T) tokenHolder.getToken() : null;
-	}
-
-	/**
-	 * Returns the token specified by {@code token}.
-	 *
-	 * @param token the token
-	 * @param <T> the type of the token
-	 * @return the token, or {@code null} if not available
-	 */
-	@Nullable
-	@SuppressWarnings("unchecked")
-	public <T extends AbstractOAuth2Token> T getToken(String token) {
-		Assert.hasText(token, "token cannot be empty");
-		OAuth2TokenHolder tokenHolder = this.tokens.values().stream()
-				.filter(holder -> holder.getToken().getTokenValue().equals(token))
-				.findFirst()
-				.orElse(null);
-		return tokenHolder != null ? (T) tokenHolder.getToken() : null;
-	}
-
-	/**
-	 * Returns the token metadata associated to the provided {@code token}.
-	 *
-	 * @param token the token
-	 * @param <T> the type of the token
-	 * @return the token metadata, or {@code null} if not available
-	 */
-	@Nullable
-	public <T extends AbstractOAuth2Token> OAuth2TokenMetadata getTokenMetadata(T token) {
-		Assert.notNull(token, "token cannot be null");
-		OAuth2TokenHolder tokenHolder = this.tokens.get(token.getClass());
-		return (tokenHolder != null && tokenHolder.getToken().equals(token)) ?
-				tokenHolder.getTokenMetadata() : null;
-	}
-
-	@Override
-	public boolean equals(Object obj) {
-		if (this == obj) {
-			return true;
-		}
-		if (obj == null || getClass() != obj.getClass()) {
-			return false;
-		}
-		OAuth2Tokens that = (OAuth2Tokens) obj;
-		return Objects.equals(this.tokens, that.tokens);
-	}
-
-	@Override
-	public int hashCode() {
-		return Objects.hash(this.tokens);
-	}
-
-	/**
-	 * Returns a new {@link Builder}.
-	 *
-	 * @return the {@link Builder}
-	 */
-	public static Builder builder() {
-		return new Builder();
-	}
-
-	/**
-	 * Returns a new {@link Builder}, initialized with the values from the provided {@code tokens}.
-	 *
-	 * @param tokens the tokens used for initializing the {@link Builder}
-	 * @return the {@link Builder}
-	 */
-	public static Builder from(OAuth2Tokens tokens) {
-		Assert.notNull(tokens, "tokens cannot be null");
-		return new Builder(tokens.tokens);
-	}
-
-	/**
-	 * A builder for {@link OAuth2Tokens}.
-	 */
-	public static class Builder implements Serializable {
-		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-		private Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens;
-
-		protected Builder() {
-			this.tokens = new HashMap<>();
-		}
-
-		protected Builder(Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens) {
-			this.tokens = new HashMap<>(tokens);
-		}
-
-		/**
-		 * Sets the {@link OAuth2AccessToken access token}.
-		 *
-		 * @param accessToken the {@link OAuth2AccessToken}
-		 * @return the {@link Builder}
-		 */
-		public Builder accessToken(OAuth2AccessToken accessToken) {
-			return addToken(accessToken, null);
-		}
-
-		/**
-		 * Sets the {@link OAuth2AccessToken access token} and associated {@link OAuth2TokenMetadata token metadata}.
-		 *
-		 * @param accessToken the {@link OAuth2AccessToken}
-		 * @param tokenMetadata the {@link OAuth2TokenMetadata}
-		 * @return the {@link Builder}
-		 */
-		public Builder accessToken(OAuth2AccessToken accessToken, OAuth2TokenMetadata tokenMetadata) {
-			return addToken(accessToken, tokenMetadata);
-		}
-
-		/**
-		 * Sets the {@link OAuth2RefreshToken refresh token}.
-		 *
-		 * @param refreshToken the {@link OAuth2RefreshToken}
-		 * @return the {@link Builder}
-		 */
-		public Builder refreshToken(OAuth2RefreshToken refreshToken) {
-			return addToken(refreshToken, null);
-		}
-
-		/**
-		 * Sets the {@link OAuth2RefreshToken refresh token} and associated {@link OAuth2TokenMetadata token metadata}.
-		 *
-		 * @param refreshToken the {@link OAuth2RefreshToken}
-		 * @param tokenMetadata the {@link OAuth2TokenMetadata}
-		 * @return the {@link Builder}
-		 */
-		public Builder refreshToken(OAuth2RefreshToken refreshToken, OAuth2TokenMetadata tokenMetadata) {
-			return addToken(refreshToken, tokenMetadata);
-		}
-
-		/**
-		 * Sets the token.
-		 *
-		 * @param token the token
-		 * @param <T> the type of the token
-		 * @return the {@link Builder}
-		 */
-		public <T extends AbstractOAuth2Token> Builder token(T token) {
-			return addToken(token, null);
-		}
-
-		/**
-		 * Sets the token and associated {@link OAuth2TokenMetadata token metadata}.
-		 *
-		 * @param token the token
-		 * @param tokenMetadata the {@link OAuth2TokenMetadata}
-		 * @param <T> the type of the token
-		 * @return the {@link Builder}
-		 */
-		public <T extends AbstractOAuth2Token> Builder token(T token, OAuth2TokenMetadata tokenMetadata) {
-			return addToken(token, tokenMetadata);
-		}
-
-		protected Builder addToken(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) {
-			Assert.notNull(token, "token cannot be null");
-			if (tokenMetadata == null) {
-				tokenMetadata = OAuth2TokenMetadata.builder().build();
-			}
-			this.tokens.put(token.getClass(), new OAuth2TokenHolder(token, tokenMetadata));
-			return this;
-		}
-
-		/**
-		 * Builds a new {@link OAuth2Tokens}.
-		 *
-		 * @return the {@link OAuth2Tokens}
-		 */
-		public OAuth2Tokens build() {
-			return new OAuth2Tokens(this.tokens);
-		}
-	}
-
-	protected static class OAuth2TokenHolder implements Serializable {
-		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-		private final AbstractOAuth2Token token;
-		private final OAuth2TokenMetadata tokenMetadata;
-
-		protected OAuth2TokenHolder(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) {
-			this.token = token;
-			this.tokenMetadata = tokenMetadata;
-		}
-
-		protected AbstractOAuth2Token getToken() {
-			return this.token;
-		}
-
-		protected OAuth2TokenMetadata getTokenMetadata() {
-			return this.tokenMetadata;
-		}
-
-		@Override
-		public boolean equals(Object obj) {
-			if (this == obj) {
-				return true;
-			}
-			if (obj == null || getClass() != obj.getClass()) {
-				return false;
-			}
-			OAuth2TokenHolder that = (OAuth2TokenHolder) obj;
-			return Objects.equals(this.token, that.token) &&
-					Objects.equals(this.tokenMetadata, that.tokenMetadata);
-		}
-
-		@Override
-		public int hashCode() {
-			return Objects.hash(this.token, this.tokenMetadata);
-		}
-	}
-}

+ 2 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -54,7 +54,6 @@ import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.util.matcher.AndRequestMatcher;
@@ -213,7 +212,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 			OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 					this.codeGenerator.generateKey(), issuedAt, expiresAt);
 			OAuth2Authorization authorization = builder
-					.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
+					.token(authorizationCode)
 					.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes())
 					.build();
 			this.authorizationService.save(authorization);
@@ -264,7 +263,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 				this.codeGenerator.generateKey(), issuedAt, expiresAt);
 		OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization())
-				.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
+				.token(authorizationCode)
 				.attributes(attrs -> {
 					attrs.remove(OAuth2AuthorizationAttributeNames.STATE);
 					attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes());

+ 7 - 7
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -198,7 +198,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(authorizationService.findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
@@ -225,7 +225,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(authorizationService.findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
@@ -252,7 +252,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE));
 		verify(authorizationService).save(any());
 
@@ -286,7 +286,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 
 		when(authorizationService.findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
@@ -303,7 +303,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService, times(2)).findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE));
 		verify(authorizationService, times(2)).save(any());
 	}
@@ -318,7 +318,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(authorizationService.findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
@@ -343,7 +343,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 			OAuth2Authorization authorization) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
-		parameters.set(OAuth2ParameterNames.CODE, authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue());
+		parameters.set(OAuth2ParameterNames.CODE, authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue());
 		parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
 		return parameters;
 	}

+ 3 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java

@@ -126,7 +126,7 @@ public class OAuth2RefreshTokenGrantTests {
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
@@ -146,7 +146,7 @@ public class OAuth2RefreshTokenGrantTests {
 
 		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN));
 		verify(authorizationService).save(any());
 
@@ -169,7 +169,7 @@ public class OAuth2RefreshTokenGrantTests {
 	private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());
-		parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, authorization.getTokens().getRefreshToken().getTokenValue());
+		parameters.set(OAuth2ParameterNames.REFRESH_TOKEN, authorization.getRefreshToken().getToken().getTokenValue());
 		return parameters;
 	}
 

+ 10 - 10
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenRevocationTests.java

@@ -104,7 +104,7 @@ public class OAuth2TokenRevocationTests {
 				.thenReturn(registeredClient);
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
-		OAuth2RefreshToken token = authorization.getTokens().getRefreshToken();
+		OAuth2RefreshToken token = authorization.getRefreshToken().getToken();
 		TokenType tokenType = TokenType.REFRESH_TOKEN;
 		when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization);
 
@@ -121,10 +121,10 @@ public class OAuth2TokenRevocationTests {
 		verify(authorizationService).save(authorizationCaptor.capture());
 
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
-		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue();
-		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
+		assertThat(refreshToken.isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
+		assertThat(accessToken.isInvalidated()).isTrue();
 	}
 
 	@Test
@@ -147,7 +147,7 @@ public class OAuth2TokenRevocationTests {
 				.thenReturn(registeredClient);
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
-		OAuth2AccessToken token = authorization.getTokens().getAccessToken();
+		OAuth2AccessToken token = authorization.getAccessToken().getToken();
 		TokenType tokenType = TokenType.ACCESS_TOKEN;
 		when(authorizationService.findByToken(eq(token.getTokenValue()), isNull())).thenReturn(authorization);
 
@@ -164,10 +164,10 @@ public class OAuth2TokenRevocationTests {
 		verify(authorizationService).save(authorizationCaptor.capture());
 
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
-		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
-		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse();
+		OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
+		assertThat(accessToken.isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
+		assertThat(refreshToken.isInvalidated()).isFalse();
 	}
 
 	private static MultiValueMap<String, String> getTokenRevocationRequestParameters(AbstractOAuth2Token token, TokenType tokenType) {

+ 3 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java

@@ -183,7 +183,7 @@ public class OidcTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 
 		when(authorizationService.findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
@@ -204,7 +204,7 @@ public class OidcTests {
 
 		verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).findByToken(
-				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
+				eq(authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE));
 		verify(authorizationService, times(2)).save(any());
 
@@ -238,7 +238,7 @@ public class OidcTests {
 			OAuth2Authorization authorization) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
-		parameters.set(OAuth2ParameterNames.CODE, authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue());
+		parameters.set(OAuth2ParameterNames.CODE, authorization.getToken(OAuth2AuthorizationCode.class).getToken().getTokenValue());
 		parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
 		return parameters;
 	}

+ 11 - 10
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,17 +15,17 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+
 import org.junit.Before;
 import org.junit.Test;
+
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
-
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -59,7 +59,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void saveWhenAuthorizationProvidedThenSaved() {
 		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
+				.token(AUTHORIZATION_CODE)
 				.build();
 		this.authorizationService.save(expectedAuthorization);
 
@@ -79,7 +79,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void removeWhenAuthorizationProvidedThenRemoved() {
 		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
+				.token(AUTHORIZATION_CODE)
 				.build();
 
 		this.authorizationService.save(expectedAuthorization);
@@ -120,7 +120,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void findByTokenWhenAuthorizationCodeExistsThenFound() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
+				.token(AUTHORIZATION_CODE)
 				.build();
 		this.authorizationService.save(authorization);
 
@@ -137,7 +137,8 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 				"access-token", Instant.now().minusSeconds(60), Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(accessToken).build())
+				.token(AUTHORIZATION_CODE)
+				.accessToken(accessToken)
 				.build();
 		this.authorizationService.save(authorization);
 
@@ -153,7 +154,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().refreshToken(refreshToken).build())
+				.refreshToken(refreshToken)
 				.build();
 		this.authorizationService.save(authorization);
 

+ 16 - 13
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,16 +15,16 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+
 import org.junit.Test;
+
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
-
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -62,15 +62,16 @@ public class OAuth2AuthorizationTests {
 	public void fromWhenAuthorizationProvidedThenCopied() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build())
+				.token(AUTHORIZATION_CODE)
+				.accessToken(ACCESS_TOKEN)
 				.build();
 		OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
 
 		assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
 		assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
-		assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
-		assertThat(authorizationResult.getTokens().getToken(OAuth2AuthorizationCode.class))
-				.isEqualTo(authorization.getTokens().getToken(OAuth2AuthorizationCode.class));
+		assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
+		assertThat(authorizationResult.getToken(OAuth2AuthorizationCode.class))
+				.isEqualTo(authorization.getToken(OAuth2AuthorizationCode.class));
 		assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
 	}
 
@@ -103,13 +104,15 @@ public class OAuth2AuthorizationTests {
 	public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).refreshToken(REFRESH_TOKEN).build())
+				.token(AUTHORIZATION_CODE)
+				.accessToken(ACCESS_TOKEN)
+				.refreshToken(REFRESH_TOKEN)
 				.build();
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
-		assertThat(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isEqualTo(AUTHORIZATION_CODE);
-		assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN);
-		assertThat(authorization.getTokens().getRefreshToken()).isEqualTo(REFRESH_TOKEN);
+		assertThat(authorization.getToken(OAuth2AuthorizationCode.class).getToken()).isEqualTo(AUTHORIZATION_CODE);
+		assertThat(authorization.getAccessToken().getToken()).isEqualTo(ACCESS_TOKEN);
+		assertThat(authorization.getRefreshToken().getToken()).isEqualTo(REFRESH_TOKEN);
 	}
 }

+ 3 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -28,7 +28,6 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 /**
  * @author Joe Grandja
@@ -62,7 +61,9 @@ public class TestOAuth2Authorizations {
 				.build();
 		return OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName("principal")
-				.tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build())
+				.token(authorizationCode)
+				.accessToken(accessToken)
+				.refreshToken(refreshToken)
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
 				.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL,
 						new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"))

+ 16 - 19
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -50,8 +50,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
 import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -172,8 +170,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
 		verify(this.authorizationService).save(authorizationCaptor.capture());
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
-		OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class);
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
+				updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
+		assertThat(authorizationCode.isInvalidated()).isTrue();
 	}
 
 	@Test
@@ -201,9 +200,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 				AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120));
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
-				.tokens(OAuth2Tokens.builder()
-						.token(authorizationCode, OAuth2TokenMetadata.builder().invalidated().build())
-						.build())
+				.token(authorizationCode, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
 				.build();
 		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
@@ -265,11 +262,11 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 
 		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
-		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken());
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
 		assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
-		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken());
-		OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class);
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue();
+		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
+		assertThat(authorizationCode.isInvalidated()).isTrue();
 	}
 
 	@Test
@@ -321,15 +318,15 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 
 		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
-		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken());
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
 		assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
-		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken());
-		OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class);
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue();
-		OidcIdToken idToken = updatedAuthorization.getTokens().getToken(OidcIdToken.class);
+		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
+		assertThat(authorizationCode.isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
 		assertThat(idToken).isNotNull();
 		assertThat(accessTokenAuthentication.getAdditionalParameters())
-				.containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getTokenValue()));
+				.containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue()));
 	}
 
 	@Test
@@ -362,12 +359,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		verify(this.authorizationService).save(authorizationCaptor.capture());
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
 
-		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken());
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
 		Instant expectedAccessTokenExpiresAt = accessTokenAuthentication.getAccessToken().getIssuedAt().plus(accessTokenTTL);
 		assertThat(accessTokenAuthentication.getAccessToken().getExpiresAt()).isBetween(
 				expectedAccessTokenExpiresAt.minusSeconds(1), expectedAccessTokenExpiresAt.plusSeconds(1));
 
-		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken());
+		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
 		Instant expectedRefreshTokenExpiresAt = accessTokenAuthentication.getRefreshToken().getIssuedAt().plus(refreshTokenTTL);
 		assertThat(accessTokenAuthentication.getRefreshToken().getExpiresAt()).isBetween(
 				expectedRefreshTokenExpiresAt.minusSeconds(1), expectedRefreshTokenExpiresAt.plusSeconds(1));

+ 3 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java

@@ -204,10 +204,10 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName());
-		assertThat(authorization.getTokens().getAccessToken()).isNotNull();
-		assertThat(authorization.getTokens().getAccessToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes());
+		assertThat(authorization.getAccessToken()).isNotNull();
+		assertThat(authorization.getAccessToken().getToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
-		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getAccessToken().getToken());
 	}
 
 	private static Jwt createJwt(Set<String> scope) {

+ 24 - 30
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -47,8 +47,6 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
-import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
@@ -120,13 +118,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -149,11 +147,11 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 
 		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
-		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken());
-		assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotEqualTo(authorization.getTokens().getAccessToken());
-		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken());
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
+		assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken());
+		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
 		// By default, refresh token is reused
-		assertThat(updatedAuthorization.getTokens().getRefreshToken()).isEqualTo(authorization.getTokens().getRefreshToken());
+		assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
 	}
 
 	@Test
@@ -163,13 +161,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 				.build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -178,8 +176,8 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		verify(this.authorizationService).save(authorizationCaptor.capture());
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
 
-		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getTokens().getRefreshToken());
-		assertThat(updatedAuthorization.getTokens().getRefreshToken()).isNotEqualTo(authorization.getTokens().getRefreshToken());
+		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
+		assertThat(updatedAuthorization.getRefreshToken()).isNotEqualTo(authorization.getRefreshToken());
 	}
 
 	@Test
@@ -187,7 +185,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
@@ -196,7 +194,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		Set<String> requestedScopes = new HashSet<>(authorizedScopes);
 		requestedScopes.remove("email");
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, requestedScopes);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes);
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -209,7 +207,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
@@ -218,7 +216,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		Set<String> requestedScopes = new HashSet<>(authorizedScopes);
 		requestedScopes.add("unauthorized");
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, requestedScopes);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, requestedScopes);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -276,14 +274,14 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
 				TestRegisteredClients.registeredClient2().build());
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -299,13 +297,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 				.build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -320,16 +318,15 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		OAuth2RefreshToken expiredRefreshToken = new OAuth2RefreshToken2(
 				"expired-refresh-token", Instant.now().minusSeconds(120), Instant.now().minusSeconds(60));
-		OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens()).refreshToken(expiredRefreshToken).build();
-		authorization = OAuth2Authorization.from(authorization).tokens(tokens).build();
+		authorization = OAuth2Authorization.from(authorization).token(expiredRefreshToken).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)
@@ -343,20 +340,17 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2(
 				"refresh-token", Instant.now().minusSeconds(120), Instant.now().plusSeconds(1000));
-		OAuth2TokenMetadata metadata = OAuth2TokenMetadata.builder().invalidated().build();
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
-				.tokens(OAuth2Tokens.builder()
-						.refreshToken(refreshToken, metadata)
-						.build())
+				.token(refreshToken, (metadata) -> metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, true))
 				.build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal);
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal);
 
 		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
 				.isInstanceOf(OAuth2AuthenticationException.class)

+ 13 - 12
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenRevocationAuthenticationProviderTests.java

@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.authentication;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -136,13 +137,13 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
 				registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getRefreshToken().getTokenValue()),
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				isNull()))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				authorization.getTokens().getRefreshToken().getTokenValue(), clientPrincipal, TokenType.REFRESH_TOKEN.getValue());
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, TokenType.REFRESH_TOKEN.getValue());
 
 		OAuth2TokenRevocationAuthenticationToken authenticationResult =
 				(OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -152,10 +153,10 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 		verify(this.authorizationService).save(authorizationCaptor.capture());
 
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
-		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isTrue();
-		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
+		assertThat(refreshToken.isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
+		assertThat(accessToken.isInvalidated()).isTrue();
 	}
 
 	@Test
@@ -164,13 +165,13 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
 				registeredClient).build();
 		when(this.authorizationService.findByToken(
-				eq(authorization.getTokens().getAccessToken().getTokenValue()),
+				eq(authorization.getAccessToken().getToken().getTokenValue()),
 				isNull()))
 				.thenReturn(authorization);
 
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2TokenRevocationAuthenticationToken authentication = new OAuth2TokenRevocationAuthenticationToken(
-				authorization.getTokens().getAccessToken().getTokenValue(), clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
+				authorization.getAccessToken().getToken().getTokenValue(), clientPrincipal, TokenType.ACCESS_TOKEN.getValue());
 
 		OAuth2TokenRevocationAuthenticationToken authenticationResult =
 				(OAuth2TokenRevocationAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -180,9 +181,9 @@ public class OAuth2TokenRevocationAuthenticationProviderTests {
 		verify(this.authorizationService).save(authorizationCaptor.capture());
 
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
-		OAuth2AccessToken accessToken = updatedAuthorization.getTokens().getAccessToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(accessToken).isInvalidated()).isTrue();
-		OAuth2RefreshToken refreshToken = updatedAuthorization.getTokens().getRefreshToken();
-		assertThat(updatedAuthorization.getTokens().getTokenMetadata(refreshToken).isInvalidated()).isFalse();
+		OAuth2Authorization.Token<OAuth2AccessToken> accessToken = updatedAuthorization.getAccessToken();
+		assertThat(accessToken.isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = updatedAuthorization.getRefreshToken();
+		assertThat(refreshToken.isInvalidated()).isFalse();
 	}
 }

+ 0 - 74
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadataTests.java

@@ -1,74 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.token;
-
-import org.junit.Test;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-
-/**
- * Tests for {@link OAuth2TokenMetadata}.
- *
- * @author Joe Grandja
- */
-public class OAuth2TokenMetadataTests {
-
-	@Test
-	public void metadataWhenNameNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() ->
-				OAuth2TokenMetadata.builder()
-						.metadata(null, "value"))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("name cannot be empty");
-	}
-
-	@Test
-	public void metadataWhenValueNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() ->
-				OAuth2TokenMetadata.builder()
-						.metadata("name", null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("value cannot be null");
-	}
-
-	@Test
-	public void getMetadataWhenNameNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2TokenMetadata.builder().build().getMetadata(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("name cannot be empty");
-	}
-
-	@Test
-	public void buildWhenDefaultThenDefaultsAreSet() {
-		OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder().build();
-		assertThat(tokenMetadata.getMetadata()).hasSize(1);
-		assertThat(tokenMetadata.isInvalidated()).isFalse();
-	}
-
-	@Test
-	public void buildWhenMetadataProvidedThenMetadataIsSet() {
-		OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder()
-				.invalidated()
-				.metadata("name1", "value1")
-				.metadata(metadata -> metadata.put("name2", "value2"))
-				.build();
-		assertThat(tokenMetadata.getMetadata()).hasSize(3);
-		assertThat(tokenMetadata.isInvalidated()).isTrue();
-		assertThat(tokenMetadata.<String>getMetadata("name1")).isEqualTo("value1");
-		assertThat(tokenMetadata.<String>getMetadata("name2")).isEqualTo("value2");
-	}
-}

+ 0 - 195
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java

@@ -1,195 +0,0 @@
-/*
- * Copyright 2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.token;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.springframework.security.oauth2.core.OAuth2AccessToken;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
-import org.springframework.security.oauth2.core.oidc.OidcIdToken;
-
-import java.time.Duration;
-import java.time.Instant;
-import java.util.Arrays;
-import java.util.HashSet;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-
-/**
- * Tests for {@link OAuth2Tokens}.
- *
- * @author Joe Grandja
- */
-public class OAuth2TokensTests {
-	private OAuth2AccessToken accessToken;
-	private OAuth2RefreshToken refreshToken;
-	private OidcIdToken idToken;
-
-	@Before
-	public void setUp() {
-		Instant issuedAt = Instant.now();
-		this.accessToken = new OAuth2AccessToken(
-				OAuth2AccessToken.TokenType.BEARER,
-				"access-token",
-				issuedAt,
-				issuedAt.plus(Duration.ofMinutes(5)),
-				new HashSet<>(Arrays.asList("read", "write")));
-		this.refreshToken = new OAuth2RefreshToken(
-				"refresh-token",
-				issuedAt);
-		this.idToken = OidcIdToken.withTokenValue("id-token")
-				.issuer("https://provider.com")
-				.subject("subject")
-				.issuedAt(issuedAt)
-				.expiresAt(issuedAt.plus(Duration.ofMinutes(30)))
-				.build();
-	}
-
-	@Test
-	public void accessTokenWhenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.builder().accessToken(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("token cannot be null");
-	}
-
-	@Test
-	public void refreshTokenWhenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.builder().refreshToken(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("token cannot be null");
-	}
-
-	@Test
-	public void tokenWhenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.builder().token(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("token cannot be null");
-	}
-
-	@Test
-	public void getTokenWhenTokenTypeNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((Class<OAuth2AccessToken>) null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("tokenType cannot be null");
-	}
-
-	@Test
-	public void getTokenWhenTokenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken((String) null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("token cannot be empty");
-	}
-
-	@Test
-	public void getTokenMetadataWhenTokenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getTokenMetadata(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("token cannot be null");
-	}
-
-	@Test
-	public void fromWhenTokensNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> OAuth2Tokens.from(null))
-				.isInstanceOf(IllegalArgumentException.class)
-				.hasMessage("tokens cannot be null");
-	}
-
-	@Test
-	public void fromWhenTokensProvidedThenCopied() {
-		OAuth2Tokens tokens = OAuth2Tokens.builder()
-				.accessToken(this.accessToken)
-				.refreshToken(this.refreshToken)
-				.token(this.idToken)
-				.build();
-		OAuth2Tokens tokensResult = OAuth2Tokens.from(tokens).build();
-
-		assertThat(tokensResult.getAccessToken()).isEqualTo(tokens.getAccessToken());
-		assertThat(tokensResult.getTokenMetadata(tokensResult.getAccessToken()))
-				.isEqualTo(tokens.getTokenMetadata(tokens.getAccessToken()));
-
-		assertThat(tokensResult.getRefreshToken()).isEqualTo(tokens.getRefreshToken());
-		assertThat(tokensResult.getTokenMetadata(tokensResult.getRefreshToken()))
-				.isEqualTo(tokens.getTokenMetadata(tokens.getRefreshToken()));
-
-		assertThat(tokensResult.getToken(OidcIdToken.class)).isEqualTo(tokens.getToken(OidcIdToken.class));
-		assertThat(tokensResult.getTokenMetadata(tokensResult.getToken(OidcIdToken.class)))
-				.isEqualTo(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)));
-	}
-
-	@Test
-	public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() {
-		OAuth2Tokens tokens = OAuth2Tokens.builder()
-				.accessToken(this.accessToken)
-				.refreshToken(this.refreshToken)
-				.token(this.idToken)
-				.build();
-
-		assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
-		OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
-		assertThat(tokenMetadata.isInvalidated()).isFalse();
-
-		assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken);
-		tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken());
-		assertThat(tokenMetadata.isInvalidated()).isFalse();
-
-		assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken);
-		tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class));
-		assertThat(tokenMetadata.isInvalidated()).isFalse();
-	}
-
-	@Test
-	public void buildWhenTokenMetadataProvidedThenTokenMetadataIsSet() {
-		OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build();
-		OAuth2Tokens tokens = OAuth2Tokens.builder()
-				.accessToken(this.accessToken, expectedTokenMetadata)
-				.refreshToken(this.refreshToken, expectedTokenMetadata)
-				.token(this.idToken, expectedTokenMetadata)
-				.build();
-
-		assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
-		OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
-		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
-
-		assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken);
-		tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken());
-		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
-
-		assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken);
-		tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class));
-		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
-	}
-
-	@Test
-	public void getTokenMetadataWhenTokenNotFoundThenNull() {
-		OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build();
-		OAuth2Tokens tokens = OAuth2Tokens.builder()
-				.accessToken(this.accessToken, expectedTokenMetadata)
-				.build();
-
-		assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
-		OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
-		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
-
-		OAuth2AccessToken otherAccessToken = new OAuth2AccessToken(
-				this.accessToken.getTokenType(),
-				"other-access-token",
-				this.accessToken.getIssuedAt(),
-				this.accessToken.getExpiresAt(),
-				this.accessToken.getScopes());
-		assertThat(tokens.getTokenMetadata(otherAccessToken)).isNull();
-	}
-}

+ 3 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -470,7 +470,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
 				.isEqualTo(this.authentication);
 
-		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization.getToken(OAuth2AuthorizationCode.class);
 		assertThat(authorizationCode).isNotNull();
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@@ -519,7 +519,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
 				.isEqualTo(this.authentication);
 
-		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = authorization.getToken(OAuth2AuthorizationCode.class);
 		assertThat(authorizationCode).isNotNull();
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
@@ -795,7 +795,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
 		assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
-		assertThat(updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isNotNull();
+		assertThat(updatedAuthorization.getToken(OAuth2AuthorizationCode.class)).isNotNull();
 		assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
 		assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))
 				.isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST));