فهرست منبع

Use OAuth2TokenGenerator for OAuth2AuthorizationCode

Closes gh-639
Joe Grandja 3 سال پیش
والد
کامیت
a661e1cdb7

+ 97 - 13
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProvider.java

@@ -28,6 +28,7 @@ import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.function.Supplier;
 
+import org.springframework.lang.Nullable;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
@@ -46,12 +47,16 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequ
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.util.UriComponents;
@@ -72,10 +77,9 @@ import org.springframework.web.util.UriComponentsBuilder;
  * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.1">Section 4.1.1 Authorization Request</a>
  */
 public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implements AuthenticationProvider {
-	private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
+	private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1";
 	private static final String PKCE_ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc7636#section-4.4.1";
-	private static final StringKeyGenerator DEFAULT_AUTHORIZATION_CODE_GENERATOR =
-			new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
+	private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
 	private static final StringKeyGenerator DEFAULT_STATE_GENERATOR =
 			new Base64StringKeyGenerator(Base64.getUrlEncoder());
 	private static final Function<String, OAuth2AuthenticationValidator> DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER =
@@ -83,7 +87,11 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 	private final RegisteredClientRepository registeredClientRepository;
 	private final OAuth2AuthorizationService authorizationService;
 	private final OAuth2AuthorizationConsentService authorizationConsentService;
-	private Supplier<String> authorizationCodeGenerator = DEFAULT_AUTHORIZATION_CODE_GENERATOR::generateKey;
+
+	@Deprecated
+	private Supplier<String> authorizationCodeSupplier;
+
+	private OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = new OAuth2AuthorizationCodeGenerator();
 	private Function<String, OAuth2AuthenticationValidator> authenticationValidatorResolver = DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER;
 	private Consumer<OAuth2AuthorizationConsentAuthenticationContext> authorizationConsentCustomizer;
 
@@ -122,9 +130,22 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 	/**
 	 * Sets the {@code Supplier<String>} that generates the value for the {@link OAuth2AuthorizationCode}.
 	 *
+	 * @deprecated Use {@link #setAuthorizationCodeGenerator(OAuth2TokenGenerator)} instead
 	 * @param authorizationCodeGenerator the {@code Supplier<String>} that generates the value for the {@link OAuth2AuthorizationCode}
 	 */
+	@Deprecated
 	public void setAuthorizationCodeGenerator(Supplier<String> authorizationCodeGenerator) {
+		Assert.notNull(authorizationCodeGenerator, "authorizationCodeGenerator cannot be null");
+		this.authorizationCodeSupplier = authorizationCodeGenerator;
+	}
+
+	/**
+	 * Sets the {@link OAuth2TokenGenerator} that generates the {@link OAuth2AuthorizationCode}.
+	 *
+	 * @param authorizationCodeGenerator the {@link OAuth2TokenGenerator} that generates the {@link OAuth2AuthorizationCode}
+	 * @since 0.2.3
+	 */
+	public void setAuthorizationCodeGenerator(OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator) {
 		Assert.notNull(authorizationCodeGenerator, "authorizationCodeGenerator cannot be null");
 		this.authorizationCodeGenerator = authorizationCodeGenerator;
 	}
@@ -258,7 +279,22 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 					.build();
 		}
 
-		OAuth2AuthorizationCode authorizationCode = generateAuthorizationCode();
+		OAuth2AuthorizationCode authorizationCode;
+		if (this.authorizationCodeSupplier != null) {
+			Instant issuedAt = Instant.now();
+			Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES);		// TODO Allow configuration for authorization code time-to-live
+			authorizationCode = new OAuth2AuthorizationCode(this.authorizationCodeSupplier.get(), issuedAt, expiresAt);
+		} else {
+			OAuth2TokenContext tokenContext = createAuthorizationCodeTokenContext(
+					authorizationCodeRequestAuthentication, registeredClient, null, authorizationRequest.getScopes());
+			authorizationCode = this.authorizationCodeGenerator.generate(tokenContext);
+			if (authorizationCode == null) {
+				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+						"The token generator failed to generate the authorization code.", ERROR_URI);
+				throw new OAuth2AuthorizationCodeRequestAuthenticationException(error, null);
+			}
+		}
+
 		OAuth2Authorization authorization = authorizationBuilder(registeredClient, principal, authorizationRequest)
 				.token(authorizationCode)
 				.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizationRequest.getScopes())
@@ -286,12 +322,6 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 				DEFAULT_AUTHENTICATION_VALIDATOR_RESOLVER.apply(parameterName);
 	}
 
-	private OAuth2AuthorizationCode generateAuthorizationCode() {
-		Instant issuedAt = Instant.now();
-		Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES);		// TODO Allow configuration for authorization code time-to-live
-		return new OAuth2AuthorizationCode(this.authorizationCodeGenerator.get(), issuedAt, expiresAt);
-	}
-
 	private Authentication authenticateAuthorizationConsent(Authentication authentication) throws AuthenticationException {
 		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
 				(OAuth2AuthorizationCodeRequestAuthenticationToken) authentication;
@@ -383,7 +413,21 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 			this.authorizationConsentService.save(authorizationConsent);
 		}
 
-		OAuth2AuthorizationCode authorizationCode = generateAuthorizationCode();
+		OAuth2AuthorizationCode authorizationCode;
+		if (this.authorizationCodeSupplier != null) {
+			Instant issuedAt = Instant.now();
+			Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES);		// TODO Allow configuration for authorization code time-to-live
+			authorizationCode = new OAuth2AuthorizationCode(this.authorizationCodeSupplier.get(), issuedAt, expiresAt);
+		} else {
+			OAuth2TokenContext tokenContext = createAuthorizationCodeTokenContext(
+					authorizationCodeRequestAuthentication, registeredClient, authorization, authorizedScopes);
+			authorizationCode = this.authorizationCodeGenerator.generate(tokenContext);
+			if (authorizationCode == null) {
+				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+						"The token generator failed to generate the authorization code.", ERROR_URI);
+				throw new OAuth2AuthorizationCodeRequestAuthenticationException(error, null);
+			}
+		}
 
 		OAuth2Authorization updatedAuthorization = OAuth2Authorization.from(authorization)
 				.token(authorizationCode)
@@ -424,6 +468,28 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 				.attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest);
 	}
 
+	private static OAuth2TokenContext createAuthorizationCodeTokenContext(
+			OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
+			RegisteredClient registeredClient, OAuth2Authorization authorization, Set<String> authorizedScopes) {
+
+		// @formatter:off
+		DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
+				.registeredClient(registeredClient)
+				.principal((Authentication) authorizationCodeRequestAuthentication.getPrincipal())
+				.providerContext(ProviderContextHolder.getProviderContext())
+				.tokenType(new OAuth2TokenType(OAuth2ParameterNames.CODE))
+				.authorizedScopes(authorizedScopes)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.authorizationGrant(authorizationCodeRequestAuthentication);
+		// @formatter:on
+
+		if (authorization != null) {
+			tokenContextBuilder.authorization(authorization);
+		}
+
+		return tokenContextBuilder.build();
+	}
+
 	private static boolean requireAuthorizationConsent(RegisteredClient registeredClient,
 			OAuth2AuthorizationRequest authorizationRequest, OAuth2AuthorizationConsent authorizationConsent) {
 
@@ -522,7 +588,7 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 	private static void throwError(String errorCode, String parameterName,
 			OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication,
 			RegisteredClient registeredClient, OAuth2AuthorizationRequest authorizationRequest) {
-		throwError(errorCode, parameterName, "https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.2.1",
+		throwError(errorCode, parameterName, ERROR_URI,
 				authorizationCodeRequestAuthentication, registeredClient, authorizationRequest);
 	}
 
@@ -580,6 +646,24 @@ public final class OAuth2AuthorizationCodeRequestAuthenticationProvider implemen
 				.authorizationCode(authorizationCodeRequestAuthentication.getAuthorizationCode());
 	}
 
+	private static class OAuth2AuthorizationCodeGenerator implements OAuth2TokenGenerator<OAuth2AuthorizationCode> {
+		private final StringKeyGenerator authorizationCodeGenerator =
+				new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
+
+		@Nullable
+		@Override
+		public OAuth2AuthorizationCode generate(OAuth2TokenContext context) {
+			if (context.getTokenType() == null ||
+					!OAuth2ParameterNames.CODE.equals(context.getTokenType().getValue())) {
+				return null;
+			}
+			Instant issuedAt = Instant.now();
+			Instant expiresAt = issuedAt.plus(5, ChronoUnit.MINUTES);		// TODO Allow configuration for authorization code time-to-live
+			return new OAuth2AuthorizationCode(this.authorizationCodeGenerator.generateKey(), issuedAt, expiresAt);
+		}
+
+	}
+
 	private static class DefaultRedirectUriOAuth2AuthenticationValidator implements OAuth2AuthenticationValidator {
 
 		@Override

+ 33 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeRequestAuthenticationProviderTests.java

@@ -46,11 +46,15 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
+import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
+import org.springframework.security.oauth2.server.authorization.context.ProviderContext;
+import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -86,6 +90,8 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 				this.registeredClientRepository, this.authorizationService, this.authorizationConsentService);
 		this.principal = new TestingAuthenticationToken("principalName", "password");
 		this.principal.setAuthenticated(true);
+		ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build();
+		ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null));
 	}
 
 	@Test
@@ -119,7 +125,10 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 
 	@Test
 	public void setAuthorizationCodeGeneratorWhenNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> this.authenticationProvider.setAuthorizationCodeGenerator(null))
+		assertThatThrownBy(() -> this.authenticationProvider.setAuthorizationCodeGenerator((Supplier<String>) null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationCodeGenerator cannot be null");
+		assertThatThrownBy(() -> this.authenticationProvider.setAuthorizationCodeGenerator((OAuth2TokenGenerator<OAuth2AuthorizationCode>) null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("authorizationCodeGenerator cannot be null");
 	}
@@ -533,6 +542,29 @@ public class OAuth2AuthorizationCodeRequestAuthenticationProviderTests {
 		assertThat(authenticationResult.getAuthorizationCode().getTokenValue()).isEqualTo(authorizationCodeGenerator.get());
 	}
 
+	@Test
+	public void authenticateWhenAuthorizationCodeNotGeneratedThenThrowOAuth2AuthorizationCodeRequestAuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		when(this.registeredClientRepository.findByClientId(eq(registeredClient.getClientId())))
+				.thenReturn(registeredClient);
+
+		@SuppressWarnings("unchecked")
+		OAuth2TokenGenerator<OAuth2AuthorizationCode> authorizationCodeGenerator = mock(OAuth2TokenGenerator.class);
+		this.authenticationProvider.setAuthorizationCodeGenerator(authorizationCodeGenerator);
+
+		OAuth2AuthorizationCodeRequestAuthenticationToken authentication =
+				authorizationCodeRequestAuthentication(registeredClient, this.principal)
+						.build();
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthorizationCodeRequestAuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthorizationCodeRequestAuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+					assertThat(error.getDescription()).contains("The token generator failed to generate the authorization code.");
+				});
+	}
+
 	@Test
 	public void authenticateWhenCustomAuthenticationValidatorResolverThenUsed() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();