Эх сурвалжийг харах

Introduce OAuth2Tokens

Closes gh-137
Joe Grandja 4 жил өмнө
parent
commit
af60f3d4d0
14 өөрчлөгдсөн 774 нэмэгдсэн , 24 устгасан
  1. 2 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java
  2. 44 7
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
  3. 2 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  4. 2 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java
  5. 169 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadata.java
  6. 279 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java
  7. 2 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java
  8. 5 4
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
  9. 2 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
  10. 2 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
  11. 3 3
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java
  12. 74 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadataTests.java
  13. 187 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java
  14. 1 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

+ 2 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -16,7 +16,6 @@
 package org.springframework.security.oauth2.server.authorization;
 
 import org.springframework.lang.Nullable;
-import org.springframework.security.oauth2.server.authorization.Version;
 import org.springframework.util.Assert;
 
 import java.io.Serializable;
@@ -66,8 +65,8 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 		} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
 			return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE));
 		} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
-			return authorization.getAccessToken() != null &&
-					authorization.getAccessToken().getTokenValue().equals(token);
+			return authorization.getTokens().getAccessToken() != null &&
+					authorization.getTokens().getAccessToken().getTokenValue().equals(token);
 		}
 		return false;
 	}

+ 44 - 7
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -15,9 +15,9 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
-import org.springframework.security.oauth2.server.authorization.Version;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 
 import java.io.Serializable;
@@ -36,13 +36,17 @@ import java.util.function.Consumer;
  * @author Krisztian Toth
  * @since 0.0.1
  * @see RegisteredClient
- * @see OAuth2AccessToken
+ * @see OAuth2Tokens
  */
 public class OAuth2Authorization implements Serializable {
 	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 	private String registeredClientId;
 	private String principalName;
+	private OAuth2Tokens tokens;
+
+	@Deprecated
 	private OAuth2AccessToken accessToken;
+
 	private Map<String, Object> attributes;
 
 	protected OAuth2Authorization() {
@@ -66,13 +70,23 @@ public class OAuth2Authorization implements Serializable {
 		return this.principalName;
 	}
 
+	/**
+	 * Returns the {@link OAuth2Tokens}.
+	 *
+	 * @return the {@link OAuth2Tokens}
+	 */
+	public OAuth2Tokens getTokens() {
+		return this.tokens;
+	}
+
 	/**
 	 * Returns the {@link OAuth2AccessToken access token} credential.
 	 *
 	 * @return the {@link OAuth2AccessToken}
 	 */
+	@Deprecated
 	public OAuth2AccessToken getAccessToken() {
-		return this.accessToken;
+		return getTokens().getAccessToken();
 	}
 
 	/**
@@ -108,13 +122,13 @@ public class OAuth2Authorization implements Serializable {
 		OAuth2Authorization that = (OAuth2Authorization) obj;
 		return Objects.equals(this.registeredClientId, that.registeredClientId) &&
 				Objects.equals(this.principalName, that.principalName) &&
-				Objects.equals(this.accessToken, that.accessToken) &&
+				Objects.equals(this.tokens, that.tokens) &&
 				Objects.equals(this.attributes, that.attributes);
 	}
 
 	@Override
 	public int hashCode() {
-		return Objects.hash(this.registeredClientId, this.principalName, this.accessToken, this.attributes);
+		return Objects.hash(this.registeredClientId, this.principalName, this.tokens, this.attributes);
 	}
 
 	/**
@@ -138,7 +152,7 @@ public class OAuth2Authorization implements Serializable {
 		Assert.notNull(authorization, "authorization cannot be null");
 		return new Builder(authorization.getRegisteredClientId())
 				.principalName(authorization.getPrincipalName())
-				.accessToken(authorization.getAccessToken())
+				.tokens(authorization.getTokens())
 				.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
 	}
 
@@ -149,7 +163,11 @@ public class OAuth2Authorization implements Serializable {
 		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
 		private String registeredClientId;
 		private String principalName;
+		private OAuth2Tokens tokens;
+
+		@Deprecated
 		private OAuth2AccessToken accessToken;
+
 		private Map<String, Object> attributes = new HashMap<>();
 
 		protected Builder(String registeredClientId) {
@@ -167,12 +185,24 @@ public class OAuth2Authorization implements Serializable {
 			return this;
 		}
 
+		/**
+		 * Sets the {@link OAuth2Tokens}.
+		 *
+		 * @param tokens the {@link OAuth2Tokens}
+		 * @return the {@link Builder}
+		 */
+		public Builder tokens(OAuth2Tokens tokens) {
+			this.tokens = tokens;
+			return this;
+		}
+
 		/**
 		 * Sets the {@link OAuth2AccessToken access token} credential.
 		 *
 		 * @param accessToken the {@link OAuth2AccessToken}
 		 * @return the {@link Builder}
 		 */
+		@Deprecated
 		public Builder accessToken(OAuth2AccessToken accessToken) {
 			this.accessToken = accessToken;
 			return this;
@@ -215,7 +245,14 @@ public class OAuth2Authorization implements Serializable {
 			OAuth2Authorization authorization = new OAuth2Authorization();
 			authorization.registeredClientId = this.registeredClientId;
 			authorization.principalName = this.principalName;
-			authorization.accessToken = this.accessToken;
+			if (this.tokens == null) {
+				OAuth2Tokens.Builder builder = OAuth2Tokens.builder();
+				if (this.accessToken != null) {
+					builder.accessToken(this.accessToken);
+				}
+				this.tokens = builder.build();
+			}
+			authorization.tokens = this.tokens;
 			authorization.attributes = Collections.unmodifiableMap(this.attributes);
 			return authorization;
 		}

+ 2 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -35,6 +35,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
@@ -143,7 +144,7 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 
 		authorization = OAuth2Authorization.from(authorization)
 				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
-				.accessToken(accessToken)
+				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
 				.build();
 		this.authorizationService.save(authorization);
 

+ 2 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java

@@ -32,6 +32,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 
@@ -129,7 +130,7 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
 				.principalName(clientPrincipal.getName())
-				.accessToken(accessToken)
+				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
 				.build();
 		this.authorizationService.save(authorization);
 

+ 169 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadata.java

@@ -0,0 +1,169 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import org.springframework.security.oauth2.server.authorization.Version;
+import org.springframework.util.Assert;
+
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Consumer;
+
+/**
+ * Holds metadata associated to an OAuth 2.0 Token.
+ *
+ * @author Joe Grandja
+ * @since 0.0.3
+ * @see OAuth2Tokens
+ */
+public class OAuth2TokenMetadata implements Serializable {
+	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+	protected static final String TOKEN_METADATA_BASE = "token.metadata.";
+
+	/**
+	 * The name of the metadata that indicates if the token has been invalidated.
+	 */
+	public static final String INVALIDATED = TOKEN_METADATA_BASE.concat("invalidated");
+
+	private final Map<String, Object> metadata;
+
+	protected OAuth2TokenMetadata(Map<String, Object> metadata) {
+		this.metadata = Collections.unmodifiableMap(new HashMap<>(metadata));
+	}
+
+	/**
+	 * Returns {@code true} if the token has been invalidated (e.g. revoked).
+	 * The default is {@code false}.
+	 *
+	 * @return {@code true} if the token has been invalidated, {@code false} otherwise
+	 */
+	public boolean isInvalidated() {
+		return getMetadata(INVALIDATED);
+	}
+
+	/**
+	 * Returns the value of the metadata associated to the token.
+	 *
+	 * @param name the name of the metadata
+	 * @param <T> the type of the metadata
+	 * @return the value of the metadata, or {@code null} if not available
+	 */
+	@SuppressWarnings("unchecked")
+	public <T> T getMetadata(String name) {
+		Assert.hasText(name, "name cannot be empty");
+		return (T) this.metadata.get(name);
+	}
+
+	/**
+	 * Returns the metadata associated to the token.
+	 *
+	 * @return a {@code Map} of the metadata
+	 */
+	public Map<String, Object> getMetadata() {
+		return this.metadata;
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj) {
+			return true;
+		}
+		if (obj == null || getClass() != obj.getClass()) {
+			return false;
+		}
+		OAuth2TokenMetadata that = (OAuth2TokenMetadata) obj;
+		return Objects.equals(this.metadata, that.metadata);
+	}
+
+	@Override
+	public int hashCode() {
+		return Objects.hash(this.metadata);
+	}
+
+	/**
+	 * Returns a new {@link Builder}.
+	 *
+	 * @return the {@link Builder}
+	 */
+	public static Builder builder() {
+		return new Builder();
+	}
+
+	/**
+	 * A builder for {@link OAuth2TokenMetadata}.
+	 */
+	public static class Builder implements Serializable {
+		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+		private final Map<String, Object> metadata = defaultMetadata();
+
+		protected Builder() {
+		}
+
+		/**
+		 * Set the token as invalidated (e.g. revoked).
+		 *
+		 * @return the {@link Builder}
+		 */
+		public Builder invalidated() {
+			metadata(INVALIDATED, true);
+			return this;
+		}
+
+		/**
+		 * Adds a metadata associated to the token.
+		 *
+		 * @param name the name of the metadata
+		 * @param value the value of the metadata
+		 * @return the {@link Builder}
+		 */
+		public Builder metadata(String name, Object value) {
+			Assert.hasText(name, "name cannot be empty");
+			Assert.notNull(value, "value cannot be null");
+			this.metadata.put(name, value);
+			return this;
+		}
+
+		/**
+		 * A {@code Consumer} of the metadata {@code Map}
+		 * allowing the ability to add, replace, or remove.
+		 *
+		 * @param metadataConsumer a {@link Consumer} of the metadata {@code Map}
+		 * @return the {@link Builder}
+		 */
+		public Builder metadata(Consumer<Map<String, Object>> metadataConsumer) {
+			metadataConsumer.accept(this.metadata);
+			return this;
+		}
+
+		/**
+		 * Builds a new {@link OAuth2TokenMetadata}.
+		 *
+		 * @return the {@link OAuth2TokenMetadata}
+		 */
+		public OAuth2TokenMetadata build() {
+			return new OAuth2TokenMetadata(this.metadata);
+		}
+
+		protected static Map<String, Object> defaultMetadata() {
+			Map<String, Object> metadata = new HashMap<>();
+			metadata.put(INVALIDATED, false);
+			return metadata;
+		}
+	}
+}

+ 279 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java

@@ -0,0 +1,279 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.Version;
+import org.springframework.util.Assert;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * A container for OAuth 2.0 Tokens.
+ *
+ * @author Joe Grandja
+ * @since 0.0.3
+ * @see OAuth2Authorization
+ * @see OAuth2TokenMetadata
+ * @see AbstractOAuth2Token
+ * @see OAuth2AccessToken
+ * @see OAuth2RefreshToken
+ */
+public class OAuth2Tokens implements Serializable {
+	private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+	private final Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens;
+
+	protected OAuth2Tokens(Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens) {
+		this.tokens = new HashMap<>(tokens);
+	}
+
+	/**
+	 * Returns the {@link OAuth2AccessToken access token}.
+	 *
+	 * @return the {@link OAuth2AccessToken}, or {@code null} if not available
+	 */
+	@Nullable
+	public OAuth2AccessToken getAccessToken() {
+		return getToken(OAuth2AccessToken.class);
+	}
+
+	/**
+	 * Returns the {@link OAuth2RefreshToken refresh token}.
+	 *
+	 * @return the {@link OAuth2RefreshToken}, or {@code null} if not available
+	 */
+	@Nullable
+	public OAuth2RefreshToken getRefreshToken() {
+		return getToken(OAuth2RefreshToken.class);
+	}
+
+	/**
+	 * Returns the token specified by {@code tokenType}.
+	 *
+	 * @param tokenType the token type
+	 * @param <T> the type of the token
+	 * @return the token, or {@code null} if not available
+	 */
+	@Nullable
+	@SuppressWarnings("unchecked")
+	public <T extends AbstractOAuth2Token> T getToken(Class<T> tokenType) {
+		Assert.notNull(tokenType, "tokenType cannot be null");
+		OAuth2TokenHolder tokenHolder = this.tokens.get(tokenType);
+		return tokenHolder != null ? (T) tokenHolder.getToken() : null;
+	}
+
+	/**
+	 * Returns the token metadata associated to the provided {@code token}.
+	 *
+	 * @param token the token
+	 * @param <T> the type of the token
+	 * @return the token metadata, or {@code null} if not available
+	 */
+	@Nullable
+	public <T extends AbstractOAuth2Token> OAuth2TokenMetadata getTokenMetadata(T token) {
+		Assert.notNull(token, "token cannot be null");
+		OAuth2TokenHolder tokenHolder = this.tokens.get(token.getClass());
+		return (tokenHolder != null && tokenHolder.getToken().equals(token)) ?
+				tokenHolder.getTokenMetadata() : null;
+	}
+
+	/**
+	 * Invalidates all tokens.
+	 */
+	public void invalidate() {
+		this.tokens.values().forEach(tokenHolder -> invalidate(tokenHolder.getToken()));
+	}
+
+	/**
+	 * Invalidates the token matching the provided {@code token}.
+	 *
+	 * @param token the token
+	 * @param <T> the type of the token
+	 */
+	public <T extends AbstractOAuth2Token> void invalidate(T token) {
+		Assert.notNull(token, "token cannot be null");
+		this.tokens.computeIfPresent(token.getClass(),
+				(tokenType, tokenHolder) ->
+						new OAuth2TokenHolder(
+								tokenHolder.getToken(),
+								OAuth2TokenMetadata.builder().invalidated().build())
+		);
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj) {
+			return true;
+		}
+		if (obj == null || getClass() != obj.getClass()) {
+			return false;
+		}
+		OAuth2Tokens that = (OAuth2Tokens) obj;
+		return Objects.equals(this.tokens, that.tokens);
+	}
+
+	@Override
+	public int hashCode() {
+		return Objects.hash(this.tokens);
+	}
+
+	/**
+	 * Returns a new {@link Builder}.
+	 *
+	 * @return the {@link Builder}
+	 */
+	public static Builder builder() {
+		return new Builder();
+	}
+
+	/**
+	 * A builder for {@link OAuth2Tokens}.
+	 */
+	public static class Builder implements Serializable {
+		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+		private final Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens = new HashMap<>();
+
+		protected Builder() {
+		}
+
+		/**
+		 * Sets the {@link OAuth2AccessToken access token}.
+		 *
+		 * @param accessToken the {@link OAuth2AccessToken}
+		 * @return the {@link Builder}
+		 */
+		public Builder accessToken(OAuth2AccessToken accessToken) {
+			return addToken(accessToken, null);
+		}
+
+		/**
+		 * Sets the {@link OAuth2AccessToken access token} and associated {@link OAuth2TokenMetadata token metadata}.
+		 *
+		 * @param accessToken the {@link OAuth2AccessToken}
+		 * @param tokenMetadata the {@link OAuth2TokenMetadata}
+		 * @return the {@link Builder}
+		 */
+		public Builder accessToken(OAuth2AccessToken accessToken, OAuth2TokenMetadata tokenMetadata) {
+			return addToken(accessToken, tokenMetadata);
+		}
+
+		/**
+		 * Sets the {@link OAuth2RefreshToken refresh token}.
+		 *
+		 * @param refreshToken the {@link OAuth2RefreshToken}
+		 * @return the {@link Builder}
+		 */
+		public Builder refreshToken(OAuth2RefreshToken refreshToken) {
+			return addToken(refreshToken, null);
+		}
+
+		/**
+		 * Sets the {@link OAuth2RefreshToken refresh token} and associated {@link OAuth2TokenMetadata token metadata}.
+		 *
+		 * @param refreshToken the {@link OAuth2RefreshToken}
+		 * @param tokenMetadata the {@link OAuth2TokenMetadata}
+		 * @return the {@link Builder}
+		 */
+		public Builder refreshToken(OAuth2RefreshToken refreshToken, OAuth2TokenMetadata tokenMetadata) {
+			return addToken(refreshToken, tokenMetadata);
+		}
+
+		/**
+		 * Sets the token.
+		 *
+		 * @param token the token
+		 * @param <T> the type of the token
+		 * @return the {@link Builder}
+		 */
+		public <T extends AbstractOAuth2Token> Builder token(T token) {
+			return addToken(token, null);
+		}
+
+		/**
+		 * Sets the token and associated {@link OAuth2TokenMetadata token metadata}.
+		 *
+		 * @param token the token
+		 * @param tokenMetadata the {@link OAuth2TokenMetadata}
+		 * @param <T> the type of the token
+		 * @return the {@link Builder}
+		 */
+		public <T extends AbstractOAuth2Token> Builder token(T token, OAuth2TokenMetadata tokenMetadata) {
+			return addToken(token, tokenMetadata);
+		}
+
+		protected Builder addToken(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) {
+			Assert.notNull(token, "token cannot be null");
+			if (tokenMetadata == null) {
+				tokenMetadata = OAuth2TokenMetadata.builder().build();
+			}
+			this.tokens.put(token.getClass(), new OAuth2TokenHolder(token, tokenMetadata));
+			return this;
+		}
+
+		/**
+		 * Builds a new {@link OAuth2Tokens}.
+		 *
+		 * @return the {@link OAuth2Tokens}
+		 */
+		public OAuth2Tokens build() {
+			return new OAuth2Tokens(this.tokens);
+		}
+	}
+
+	protected static class OAuth2TokenHolder implements Serializable {
+		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
+		private final AbstractOAuth2Token token;
+		private final OAuth2TokenMetadata tokenMetadata;
+
+		protected OAuth2TokenHolder(AbstractOAuth2Token token, OAuth2TokenMetadata tokenMetadata) {
+			this.token = token;
+			this.tokenMetadata = tokenMetadata;
+		}
+
+		protected AbstractOAuth2Token getToken() {
+			return this.token;
+		}
+
+		protected OAuth2TokenMetadata getTokenMetadata() {
+			return this.tokenMetadata;
+		}
+
+		@Override
+		public boolean equals(Object obj) {
+			if (this == obj) {
+				return true;
+			}
+			if (obj == null || getClass() != obj.getClass()) {
+				return false;
+			}
+			OAuth2TokenHolder that = (OAuth2TokenHolder) obj;
+			return Objects.equals(this.token, that.token) &&
+					Objects.equals(this.tokenMetadata, that.tokenMetadata);
+		}
+
+		@Override
+		public int hashCode() {
+			return Objects.hash(this.token, this.tokenMetadata);
+		}
+	}
+}

+ 2 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -20,6 +20,7 @@ import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
 
@@ -129,7 +130,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
 				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
-				.accessToken(accessToken)
+				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
 				.build();
 		this.authorizationService.save(authorization);
 

+ 5 - 4
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java

@@ -19,6 +19,7 @@ import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
 
@@ -57,14 +58,14 @@ public class OAuth2AuthorizationTests {
 	public void fromWhenAuthorizationProvidedThenCopied() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.accessToken(ACCESS_TOKEN)
+				.tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build())
 				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
 				.build();
 		OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
 
 		assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
 		assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
-		assertThat(authorizationResult.getAccessToken()).isEqualTo(authorization.getAccessToken());
+		assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
 		assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
 	}
 
@@ -97,13 +98,13 @@ public class OAuth2AuthorizationTests {
 	public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.accessToken(ACCESS_TOKEN)
+				.tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build())
 				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
 				.build();
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
-		assertThat(authorization.getAccessToken()).isEqualTo(ACCESS_TOKEN);
+		assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN);
 		assertThat(authorization.getAttributes()).containsExactly(
 				entry(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE));
 	}

+ 2 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -19,6 +19,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
 import java.util.Collections;
@@ -52,7 +53,7 @@ public class TestOAuth2Authorizations {
 				.build();
 		return OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName("principal")
-				.accessToken(accessToken)
+				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
 				.attribute(OAuth2AuthorizationAttributeNames.CODE, "code")
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes());

+ 2 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -203,8 +203,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 
 		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
-		assertThat(updatedAuthorization.getAccessToken()).isNotNull();
-		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken());
+		assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull();
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken());
 	}
 
 	private static Jwt createJwt() {

+ 3 - 3
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java

@@ -156,10 +156,10 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(clientPrincipal.getRegisteredClient().getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(clientPrincipal.getName());
-		assertThat(authorization.getAccessToken()).isNotNull();
-		assertThat(authorization.getAccessToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes());
+		assertThat(authorization.getTokens().getAccessToken()).isNotNull();
+		assertThat(authorization.getTokens().getAccessToken().getScopes()).isEqualTo(clientPrincipal.getRegisteredClient().getScopes());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
-		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getAccessToken());
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
 	}
 
 	private static Jwt createJwt() {

+ 74 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenMetadataTests.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import org.junit.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2TokenMetadata}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2TokenMetadataTests {
+
+	@Test
+	public void metadataWhenNameNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() ->
+				OAuth2TokenMetadata.builder()
+						.metadata(null, "value"))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("name cannot be empty");
+	}
+
+	@Test
+	public void metadataWhenValueNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() ->
+				OAuth2TokenMetadata.builder()
+						.metadata("name", null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("value cannot be null");
+	}
+
+	@Test
+	public void getMetadataWhenNameNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2TokenMetadata.builder().build().getMetadata(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("name cannot be empty");
+	}
+
+	@Test
+	public void buildWhenDefaultThenDefaultsAreSet() {
+		OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder().build();
+		assertThat(tokenMetadata.getMetadata()).hasSize(1);
+		assertThat(tokenMetadata.isInvalidated()).isFalse();
+	}
+
+	@Test
+	public void buildWhenMetadataProvidedThenMetadataIsSet() {
+		OAuth2TokenMetadata tokenMetadata = OAuth2TokenMetadata.builder()
+				.invalidated()
+				.metadata("name1", "value1")
+				.metadata(metadata -> metadata.put("name2", "value2"))
+				.build();
+		assertThat(tokenMetadata.getMetadata()).hasSize(3);
+		assertThat(tokenMetadata.isInvalidated()).isTrue();
+		assertThat(tokenMetadata.<String>getMetadata("name1")).isEqualTo("value1");
+		assertThat(tokenMetadata.<String>getMetadata("name2")).isEqualTo("value2");
+	}
+}

+ 187 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java

@@ -0,0 +1,187 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.HashSet;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link OAuth2Tokens}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2TokensTests {
+	private OAuth2AccessToken accessToken;
+	private OAuth2RefreshToken refreshToken;
+	private OidcIdToken idToken;
+
+	@Before
+	public void setUp() {
+		Instant issuedAt = Instant.now();
+		this.accessToken = new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER,
+				"access-token",
+				issuedAt,
+				issuedAt.plus(Duration.ofMinutes(5)),
+				new HashSet<>(Arrays.asList("read", "write")));
+		this.refreshToken = new OAuth2RefreshToken(
+				"refresh-token",
+				issuedAt);
+		this.idToken = OidcIdToken.withTokenValue("id-token")
+				.issuer("https://provider.com")
+				.subject("subject")
+				.issuedAt(issuedAt)
+				.expiresAt(issuedAt.plus(Duration.ofMinutes(30)))
+				.build();
+	}
+
+	@Test
+	public void accessTokenWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Tokens.builder().accessToken(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("token cannot be null");
+	}
+
+	@Test
+	public void refreshTokenWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Tokens.builder().refreshToken(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("token cannot be null");
+	}
+
+	@Test
+	public void tokenWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Tokens.builder().token(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("token cannot be null");
+	}
+
+	@Test
+	public void getTokenWhenTokenTypeNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getToken(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("tokenType cannot be null");
+	}
+
+	@Test
+	public void getTokenMetadataWhenTokenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Tokens.builder().build().getTokenMetadata(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("token cannot be null");
+	}
+
+	@Test
+	public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() {
+		OAuth2Tokens tokens = OAuth2Tokens.builder()
+				.accessToken(this.accessToken)
+				.refreshToken(this.refreshToken)
+				.token(this.idToken)
+				.build();
+
+		assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
+		OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
+		assertThat(tokenMetadata.isInvalidated()).isFalse();
+
+		assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken);
+		tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken());
+		assertThat(tokenMetadata.isInvalidated()).isFalse();
+
+		assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken);
+		tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class));
+		assertThat(tokenMetadata.isInvalidated()).isFalse();
+	}
+
+	@Test
+	public void buildWhenTokenMetadataProvidedThenTokenMetadataIsSet() {
+		OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build();
+		OAuth2Tokens tokens = OAuth2Tokens.builder()
+				.accessToken(this.accessToken, expectedTokenMetadata)
+				.refreshToken(this.refreshToken, expectedTokenMetadata)
+				.token(this.idToken, expectedTokenMetadata)
+				.build();
+
+		assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
+		OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
+		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
+
+		assertThat(tokens.getRefreshToken()).isEqualTo(this.refreshToken);
+		tokenMetadata = tokens.getTokenMetadata(tokens.getRefreshToken());
+		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
+
+		assertThat(tokens.getToken(OidcIdToken.class)).isEqualTo(this.idToken);
+		tokenMetadata = tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class));
+		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
+	}
+
+	@Test
+	public void getTokenMetadataWhenTokenNotFoundThenNull() {
+		OAuth2TokenMetadata expectedTokenMetadata = OAuth2TokenMetadata.builder().build();
+		OAuth2Tokens tokens = OAuth2Tokens.builder()
+				.accessToken(this.accessToken, expectedTokenMetadata)
+				.build();
+
+		assertThat(tokens.getAccessToken()).isEqualTo(this.accessToken);
+		OAuth2TokenMetadata tokenMetadata = tokens.getTokenMetadata(tokens.getAccessToken());
+		assertThat(tokenMetadata).isEqualTo(expectedTokenMetadata);
+
+		OAuth2AccessToken otherAccessToken = new OAuth2AccessToken(
+				this.accessToken.getTokenType(),
+				"other-access-token",
+				this.accessToken.getIssuedAt(),
+				this.accessToken.getExpiresAt(),
+				this.accessToken.getScopes());
+		assertThat(tokens.getTokenMetadata(otherAccessToken)).isNull();
+	}
+
+	@Test
+	public void invalidateWhenAllTokensThenAllInvalidated() {
+		OAuth2Tokens tokens = OAuth2Tokens.builder()
+				.accessToken(this.accessToken)
+				.refreshToken(this.refreshToken)
+				.token(this.idToken)
+				.build();
+		tokens.invalidate();
+
+		assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue();
+		assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isTrue();
+		assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isTrue();
+	}
+
+	@Test
+	public void invalidateWhenTokenProvidedThenInvalidated() {
+		OAuth2Tokens tokens = OAuth2Tokens.builder()
+				.accessToken(this.accessToken)
+				.refreshToken(this.refreshToken)
+				.token(this.idToken)
+				.build();
+		tokens.invalidate(this.accessToken);
+
+		assertThat(tokens.getTokenMetadata(tokens.getAccessToken()).isInvalidated()).isTrue();
+		assertThat(tokens.getTokenMetadata(tokens.getRefreshToken()).isInvalidated()).isFalse();
+		assertThat(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)).isInvalidated()).isFalse();
+	}
+}

+ 1 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -755,7 +755,7 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
 		assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
-		assertThat(updatedAuthorization.getAccessToken()).isNotNull();
+		assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull();
 		assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
 		assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.CODE)).isNotNull();
 		assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))