浏览代码

Enforce one-time use for authorization code

Closes gh-138
Joe Grandja 4 年之前
父节点
当前提交
18f8b3afaa
共有 14 个文件被更改,包括 201 次插入45 次删除
  1. 3 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java
  2. 1 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java
  3. 1 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java
  4. 18 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  5. 43 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2AuthorizationCode.java
  6. 17 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java
  7. 18 8
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
  8. 6 6
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java
  9. 12 10
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java
  10. 9 8
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java
  11. 4 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
  12. 34 1
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
  13. 29 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java
  14. 6 6
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

+ 3 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationService.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.server.authorization;
 
 import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.util.Assert;
 
 import java.io.Serializable;
@@ -63,7 +64,8 @@ public final class InMemoryOAuth2AuthorizationService implements OAuth2Authoriza
 		if (OAuth2AuthorizationAttributeNames.STATE.equals(tokenType.getValue())) {
 			return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE));
 		} else if (TokenType.AUTHORIZATION_CODE.equals(tokenType)) {
-			return token.equals(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE));
+			OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+			return authorizationCode != null && authorizationCode.getTokenValue().equals(token);
 		} else if (TokenType.ACCESS_TOKEN.equals(tokenType)) {
 			return authorization.getTokens().getAccessToken() != null &&
 					authorization.getTokens().getAccessToken().getTokenValue().equals(token);

+ 1 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2Authorization.java

@@ -152,7 +152,7 @@ public class OAuth2Authorization implements Serializable {
 		Assert.notNull(authorization, "authorization cannot be null");
 		return new Builder(authorization.getRegisteredClientId())
 				.principalName(authorization.getPrincipalName())
-				.tokens(authorization.getTokens())
+				.tokens(OAuth2Tokens.from(authorization.getTokens()).build())
 				.attributes(attrs -> attrs.putAll(authorization.getAttributes()));
 	}
 

+ 1 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java

@@ -38,6 +38,7 @@ public interface OAuth2AuthorizationAttributeNames {
 	/**
 	 * The name of the attribute used for the {@link OAuth2ParameterNames#CODE} parameter.
 	 */
+	@Deprecated
 	String CODE = OAuth2Authorization.class.getName().concat(".CODE");
 
 	/**

+ 18 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -35,6 +35,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
@@ -102,11 +104,15 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 		if (authorization == null) {
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
+		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
 				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 
 		if (!registeredClient.getClientId().equals(authorizationRequest.getClientId())) {
+			// Invalidate the authorization code given that a different client is attempting to use it
+			authorization.getTokens().invalidate(authorizationCode);
+			this.authorizationService.save(authorization);
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
@@ -115,6 +121,12 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
+		OAuth2TokenMetadata authorizationCodeMetadata = authorization.getTokens().getTokenMetadata(authorizationCode);
+		if (authorizationCodeMetadata.isInvalidated()) {
+			// Prevent the same client from using the authorization code more than once
+			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
+		}
+
 		JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
 
 		// TODO Allow configuration for issuer claim
@@ -142,9 +154,14 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
 				jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
 
+		OAuth2Tokens tokens = OAuth2Tokens.from(authorization.getTokens())
+				.accessToken(accessToken)
+				.build();
+		tokens.invalidate(authorizationCode);		// Invalidate the authorization code as it can only be used once
+
 		authorization = OAuth2Authorization.from(authorization)
+				.tokens(tokens)
 				.attribute(OAuth2AuthorizationAttributeNames.ACCESS_TOKEN_ATTRIBUTES, jwt)
-				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
 				.build();
 		this.authorizationService.save(authorization);
 

+ 43 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2AuthorizationCode.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
+
+import java.time.Instant;
+
+/**
+ * An implementation of an {@link AbstractOAuth2Token}
+ * representing an OAuth 2.0 Authorization Code Grant.
+ *
+ * @author Joe Grandja
+ * @since 0.0.3
+ * @see AbstractOAuth2Token
+ * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
+ */
+public class OAuth2AuthorizationCode extends AbstractOAuth2Token {
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationCode} using the provided parameters.
+	 * @param tokenValue the token value
+	 * @param issuedAt the time at which the token was issued
+	 * @param expiresAt the time at which the token expires
+	 */
+	public OAuth2AuthorizationCode(String tokenValue, Instant issuedAt, Instant expiresAt) {
+		super(tokenValue, issuedAt, expiresAt);
+	}
+
+}

+ 17 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2Tokens.java

@@ -146,14 +146,30 @@ public class OAuth2Tokens implements Serializable {
 		return new Builder();
 	}
 
+	/**
+	 * Returns a new {@link Builder}, initialized with the values from the provided {@code tokens}.
+	 *
+	 * @param tokens the tokens used for initializing the {@link Builder}
+	 * @return the {@link Builder}
+	 */
+	public static Builder from(OAuth2Tokens tokens) {
+		Assert.notNull(tokens, "tokens cannot be null");
+		return new Builder(tokens.tokens);
+	}
+
 	/**
 	 * A builder for {@link OAuth2Tokens}.
 	 */
 	public static class Builder implements Serializable {
 		private static final long serialVersionUID = Version.SERIAL_VERSION_UID;
-		private final Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens = new HashMap<>();
+		private Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens;
 
 		protected Builder() {
+			this.tokens = new HashMap<>();
+		}
+
+		protected Builder(Map<Class<? extends AbstractOAuth2Token>, OAuth2TokenHolder> tokens) {
+			this.tokens = new HashMap<>(tokens);
 		}
 
 		/**

+ 18 - 8
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -36,6 +36,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -53,6 +55,8 @@ import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 import java.util.Arrays;
 import java.util.Base64;
 import java.util.Collections;
@@ -184,9 +188,12 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 
 			UserConsentPage.displayConsent(request, response, registeredClient, authorization);
 		} else {
-			String code = this.codeGenerator.generateKey();
+			Instant issuedAt = Instant.now();
+			Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES);		// TODO Allow configuration for authorization code time-to-live
+			OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+					this.codeGenerator.generateKey(), issuedAt, expiresAt);
 			OAuth2Authorization authorization = builder
-					.attribute(OAuth2AuthorizationAttributeNames.CODE, code)
+					.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
 					.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes())
 					.build();
 			this.authorizationService.save(authorization);
@@ -200,7 +207,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 //			The authorization code is bound to the client identifier and redirection URI.
 
 			sendAuthorizationResponse(request, response,
-					authorizationRequestContext.resolveRedirectUri(), code, authorizationRequest.getState());
+					authorizationRequestContext.resolveRedirectUri(), authorizationCode, authorizationRequest.getState());
 		}
 	}
 
@@ -232,18 +239,21 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 			return;
 		}
 
-		String code = this.codeGenerator.generateKey();
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES);		// TODO Allow configuration for authorization code time-to-live
+		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+				this.codeGenerator.generateKey(), issuedAt, expiresAt);
 		OAuth2Authorization authorization = OAuth2Authorization.from(userConsentRequestContext.getAuthorization())
+				.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
 				.attributes(attrs -> {
 					attrs.remove(OAuth2AuthorizationAttributeNames.STATE);
-					attrs.put(OAuth2AuthorizationAttributeNames.CODE, code);
 					attrs.put(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, userConsentRequestContext.getScopes());
 				})
 				.build();
 		this.authorizationService.save(authorization);
 
 		sendAuthorizationResponse(request, response, userConsentRequestContext.resolveRedirectUri(),
-				code, userConsentRequestContext.getAuthorizationRequest().getState());
+				authorizationCode, userConsentRequestContext.getAuthorizationRequest().getState());
 	}
 
 	private void validateAuthorizationRequest(OAuth2AuthorizationRequestContext authorizationRequestContext) {
@@ -389,11 +399,11 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 	}
 
 	private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response,
-			String redirectUri, String code, String state) throws IOException {
+			String redirectUri, OAuth2AuthorizationCode authorizationCode, String state) throws IOException {
 
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 				.fromUriString(redirectUri)
-				.queryParam(OAuth2ParameterNames.CODE, code);
+				.queryParam(OAuth2ParameterNames.CODE, authorizationCode.getTokenValue());
 		if (StringUtils.hasText(state)) {
 			uriBuilder.queryParam(OAuth2ParameterNames.STATE, state);
 		}

+ 6 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -34,13 +34,13 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResp
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
-import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
 import org.springframework.test.web.servlet.MockMvc;
@@ -153,7 +153,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
 		when(authorizationService.findByToken(
-				eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)),
+				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
@@ -167,7 +167,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).findByToken(
-				eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)),
+				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE));
 		verify(authorizationService).save(any());
 	}
@@ -199,7 +199,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 
 		when(authorizationService.findByToken(
-				eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)),
+				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
@@ -212,7 +212,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService, times(2)).findByToken(
-				eq(authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE)),
+				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE));
 		verify(authorizationService, times(2)).save(any());
 	}
@@ -232,7 +232,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 			OAuth2Authorization authorization) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
-		parameters.set(OAuth2ParameterNames.CODE, authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE));
+		parameters.set(OAuth2ParameterNames.CODE, authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue());
 		parameters.set(OAuth2ParameterNames.REDIRECT_URI, registeredClient.getRedirectUris().iterator().next());
 		return parameters;
 	}

+ 12 - 10
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/InMemoryOAuth2AuthorizationServiceTests.java

@@ -20,9 +20,11 @@ import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -36,7 +38,8 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 public class InMemoryOAuth2AuthorizationServiceTests {
 	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
 	private static final String PRINCIPAL_NAME = "principal";
-	private static final String AUTHORIZATION_CODE = "code";
+	private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode(
+			"code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES));
 	private InMemoryOAuth2AuthorizationService authorizationService;
 
 	@Before
@@ -55,12 +58,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void saveWhenAuthorizationProvidedThenSaved() {
 		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
+				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
 				.build();
 		this.authorizationService.save(expectedAuthorization);
 
 		OAuth2Authorization authorization = this.authorizationService.findByToken(
-				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);
+				AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
 		assertThat(authorization).isEqualTo(expectedAuthorization);
 	}
 
@@ -75,17 +78,17 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void removeWhenAuthorizationProvidedThenRemoved() {
 		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
+				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
 				.build();
 
 		this.authorizationService.save(expectedAuthorization);
 		OAuth2Authorization authorization = this.authorizationService.findByToken(
-				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);
+				AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
 		assertThat(authorization).isEqualTo(expectedAuthorization);
 
 		this.authorizationService.remove(expectedAuthorization);
 		authorization = this.authorizationService.findByToken(
-				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);
+				AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
 		assertThat(authorization).isNull();
 	}
 
@@ -114,12 +117,12 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 	public void findByTokenWhenTokenTypeAuthorizationCodeThenFound() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
+				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).build())
 				.build();
 		this.authorizationService.save(authorization);
 
 		OAuth2Authorization result = this.authorizationService.findByToken(
-				AUTHORIZATION_CODE, TokenType.AUTHORIZATION_CODE);
+				AUTHORIZATION_CODE.getTokenValue(), TokenType.AUTHORIZATION_CODE);
 		assertThat(authorization).isEqualTo(result);
 	}
 
@@ -129,8 +132,7 @@ public class InMemoryOAuth2AuthorizationServiceTests {
 				"access-token", Instant.now().minusSeconds(60), Instant.now());
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
-				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
+				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(accessToken).build())
 				.build();
 		this.authorizationService.save(authorization);
 

+ 9 - 8
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationTests.java

@@ -19,13 +19,14 @@ import org.junit.Test;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
-import static org.assertj.core.data.MapEntry.entry;
 
 /**
  * Tests for {@link OAuth2Authorization}.
@@ -38,7 +39,8 @@ public class OAuth2AuthorizationTests {
 	private static final String PRINCIPAL_NAME = "principal";
 	private static final OAuth2AccessToken ACCESS_TOKEN = new OAuth2AccessToken(
 			OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
-	private static final String AUTHORIZATION_CODE = "code";
+	private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode(
+			"code", Instant.now(), Instant.now().plus(5, ChronoUnit.MINUTES));
 
 	@Test
 	public void withRegisteredClientWhenRegisteredClientNullThenThrowIllegalArgumentException() {
@@ -58,14 +60,15 @@ public class OAuth2AuthorizationTests {
 	public void fromWhenAuthorizationProvidedThenCopied() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build())
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
+				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build())
 				.build();
 		OAuth2Authorization authorizationResult = OAuth2Authorization.from(authorization).build();
 
 		assertThat(authorizationResult.getRegisteredClientId()).isEqualTo(authorization.getRegisteredClientId());
 		assertThat(authorizationResult.getPrincipalName()).isEqualTo(authorization.getPrincipalName());
 		assertThat(authorizationResult.getTokens().getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
+		assertThat(authorizationResult.getTokens().getToken(OAuth2AuthorizationCode.class))
+				.isEqualTo(authorization.getTokens().getToken(OAuth2AuthorizationCode.class));
 		assertThat(authorizationResult.getAttributes()).isEqualTo(authorization.getAttributes());
 	}
 
@@ -98,14 +101,12 @@ public class OAuth2AuthorizationTests {
 	public void buildWhenAllAttributesAreProvidedThenAllAttributesAreSet() {
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
 				.principalName(PRINCIPAL_NAME)
-				.tokens(OAuth2Tokens.builder().accessToken(ACCESS_TOKEN).build())
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE)
+				.tokens(OAuth2Tokens.builder().token(AUTHORIZATION_CODE).accessToken(ACCESS_TOKEN).build())
 				.build();
 
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(REGISTERED_CLIENT.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(PRINCIPAL_NAME);
+		assertThat(authorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isEqualTo(AUTHORIZATION_CODE);
 		assertThat(authorization.getTokens().getAccessToken()).isEqualTo(ACCESS_TOKEN);
-		assertThat(authorization.getAttributes()).containsExactly(
-				entry(OAuth2AuthorizationAttributeNames.CODE, AUTHORIZATION_CODE));
 	}
 }

+ 4 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -19,6 +19,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
@@ -41,6 +42,8 @@ public class TestOAuth2Authorizations {
 
 	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
 			Map<String, Object> authorizationRequestAdditionalParameters) {
+		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+				"code", Instant.now(), Instant.now().plusSeconds(120));
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(
 				OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
@@ -53,8 +56,7 @@ public class TestOAuth2Authorizations {
 				.build();
 		return OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName("principal")
-				.tokens(OAuth2Tokens.builder().accessToken(accessToken).build())
-				.attribute(OAuth2AuthorizationAttributeNames.CODE, "code")
+				.tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).build())
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes());
 	}

+ 34 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -37,6 +37,8 @@ import org.springframework.security.oauth2.server.authorization.client.InMemoryR
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
@@ -153,6 +155,12 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
 				.extracting("errorCode")
 				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+		OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue();
 	}
 
 	@Test
@@ -173,6 +181,30 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
 	}
 
+	@Test
+	public void authenticateWhenInvalidatedCodeThenThrowOAuth2AuthenticationException() {
+		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+				AUTHORIZATION_CODE, Instant.now(), Instant.now().plusSeconds(120));
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization()
+				.tokens(OAuth2Tokens.builder().token(authorizationCode).build())
+				.build();
+		authorization.getTokens().invalidate(authorizationCode);
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(TokenType.AUTHORIZATION_CODE)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(this.registeredClient);
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.extracting("errorCode")
+				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
+	}
+
 	@Test
 	public void authenticateWhenValidCodeThenReturnAccessToken() {
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
@@ -203,8 +235,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 
 		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
 		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
-		assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull();
 		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getTokens().getAccessToken());
+		OAuth2AuthorizationCode authorizationCode = updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		assertThat(updatedAuthorization.getTokens().getTokenMetadata(authorizationCode).isInvalidated()).isTrue();
 	}
 
 	private static Jwt createJwt() {

+ 29 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokensTests.java

@@ -94,6 +94,35 @@ public class OAuth2TokensTests {
 				.hasMessage("token cannot be null");
 	}
 
+	@Test
+	public void fromWhenTokensNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> OAuth2Tokens.from(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("tokens cannot be null");
+	}
+
+	@Test
+	public void fromWhenTokensProvidedThenCopied() {
+		OAuth2Tokens tokens = OAuth2Tokens.builder()
+				.accessToken(this.accessToken)
+				.refreshToken(this.refreshToken)
+				.token(this.idToken)
+				.build();
+		OAuth2Tokens tokensResult = OAuth2Tokens.from(tokens).build();
+
+		assertThat(tokensResult.getAccessToken()).isEqualTo(tokens.getAccessToken());
+		assertThat(tokensResult.getTokenMetadata(tokensResult.getAccessToken()))
+				.isEqualTo(tokens.getTokenMetadata(tokens.getAccessToken()));
+
+		assertThat(tokensResult.getRefreshToken()).isEqualTo(tokens.getRefreshToken());
+		assertThat(tokensResult.getTokenMetadata(tokensResult.getRefreshToken()))
+				.isEqualTo(tokens.getTokenMetadata(tokens.getRefreshToken()));
+
+		assertThat(tokensResult.getToken(OidcIdToken.class)).isEqualTo(tokens.getToken(OidcIdToken.class));
+		assertThat(tokensResult.getTokenMetadata(tokensResult.getToken(OidcIdToken.class)))
+				.isEqualTo(tokens.getTokenMetadata(tokens.getToken(OidcIdToken.class)));
+	}
+
 	@Test
 	public void buildWhenTokenMetadataNotProvidedThenDefaultsAreSet() {
 		OAuth2Tokens tokens = OAuth2Tokens.builder()

+ 6 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -40,6 +40,7 @@ import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.util.StringUtils;
 
 import javax.servlet.FilterChain;
@@ -434,8 +435,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
 
-		String code = authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE);
-		assertThat(code).isNotNull();
+		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		assertThat(authorizationCode).isNotNull();
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 		assertThat(authorizationRequest).isNotNull();
@@ -481,8 +482,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
 
-		String code = authorization.getAttribute(OAuth2AuthorizationAttributeNames.CODE);
-		assertThat(code).isNotNull();
+		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
+		assertThat(authorizationCode).isNotNull();
 
 		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
 		assertThat(authorizationRequest).isNotNull();
@@ -755,9 +756,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
 		assertThat(updatedAuthorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(updatedAuthorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
-		assertThat(updatedAuthorization.getTokens().getAccessToken()).isNotNull();
+		assertThat(updatedAuthorization.getTokens().getToken(OAuth2AuthorizationCode.class)).isNotNull();
 		assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.STATE)).isNull();
-		assertThat(updatedAuthorization.<String>getAttribute(OAuth2AuthorizationAttributeNames.CODE)).isNotNull();
 		assertThat(updatedAuthorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST))
 				.isEqualTo(authorization.<OAuth2AuthorizationRequest>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST));
 		assertThat(updatedAuthorization.<Set<String>>getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES))