Browse Source

Introduce OAuth2TokenGenerator

Closes gh-414
Joe Grandja 3 years ago
parent
commit
c799261a72
21 changed files with 1229 additions and 438 deletions
  1. 15 0
      oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java
  2. 23 2
      oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ConfigurerUtils.java
  3. 8 23
      oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenEndpointConfigurer.java
  4. 2 2
      oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationEndpointConfigurer.java
  5. 80 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenContext.java
  6. 166 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JwtGenerator.java
  7. 26 3
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenContext.java
  8. 44 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenGenerator.java
  9. 0 101
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtUtils.java
  10. 70 79
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  11. 55 32
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java
  12. 75 77
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
  13. 0 76
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/JwtUtils.java
  14. 62 22
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java
  15. 26 4
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java
  16. 77 1
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java
  17. 215 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JwtGeneratorTests.java
  18. 91 4
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
  19. 43 4
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java
  20. 89 6
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java
  21. 62 2
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java

+ 15 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java

@@ -37,8 +37,10 @@ import org.springframework.security.config.annotation.web.configurers.ExceptionH
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Transient;
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
@@ -146,6 +148,19 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 		return this;
 	}
 
+	/**
+	 * Sets the token generator.
+	 *
+	 * @param tokenGenerator the token generator
+	 * @return the {@link OAuth2AuthorizationServerConfigurer} for further configuration
+	 * @since 0.2.3
+	 */
+	public OAuth2AuthorizationServerConfigurer<B> tokenGenerator(OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator) {
+		Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
+		getBuilder().setSharedObject(OAuth2TokenGenerator.class, tokenGenerator);
+		return this;
+	}
+
 	/**
 	 * Configures OAuth 2.0 Client Authentication.
 	 *

+ 23 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ConfigurerUtils.java

@@ -26,14 +26,17 @@ import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
 import org.springframework.context.ApplicationContext;
 import org.springframework.core.ResolvableType;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
 import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
 import org.springframework.util.StringUtils;
@@ -82,7 +85,25 @@ final class OAuth2ConfigurerUtils {
 		return authorizationConsentService;
 	}
 
-	static <B extends HttpSecurityBuilder<B>> JwtEncoder getJwtEncoder(B builder) {
+	@SuppressWarnings("unchecked")
+	static <B extends HttpSecurityBuilder<B>> OAuth2TokenGenerator<? extends OAuth2Token> getTokenGenerator(B builder) {
+		OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator = builder.getSharedObject(OAuth2TokenGenerator.class);
+		if (tokenGenerator == null) {
+			tokenGenerator = getOptionalBean(builder, OAuth2TokenGenerator.class);
+			if (tokenGenerator == null) {
+				JwtGenerator jwtGenerator = new JwtGenerator(getJwtEncoder(builder));
+				OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = getJwtCustomizer(builder);
+				if (jwtCustomizer != null) {
+					jwtGenerator.setJwtCustomizer(jwtCustomizer);
+				}
+				tokenGenerator = jwtGenerator;
+			}
+			builder.setSharedObject(OAuth2TokenGenerator.class, tokenGenerator);
+		}
+		return tokenGenerator;
+	}
+
+	private static <B extends HttpSecurityBuilder<B>> JwtEncoder getJwtEncoder(B builder) {
 		JwtEncoder jwtEncoder = builder.getSharedObject(JwtEncoder.class);
 		if (jwtEncoder == null) {
 			jwtEncoder = getOptionalBean(builder, JwtEncoder.class);
@@ -107,7 +128,7 @@ final class OAuth2ConfigurerUtils {
 	}
 
 	@SuppressWarnings("unchecked")
-	static <B extends HttpSecurityBuilder<B>> OAuth2TokenCustomizer<JwtEncodingContext> getJwtCustomizer(B builder) {
+	private static <B extends HttpSecurityBuilder<B>> OAuth2TokenCustomizer<JwtEncodingContext> getJwtCustomizer(B builder) {
 		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = builder.getSharedObject(OAuth2TokenCustomizer.class);
 		if (jwtCustomizer == null) {
 			ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class);

+ 8 - 23
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2TokenEndpointConfigurer.java

@@ -28,10 +28,10 @@ import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
-import org.springframework.security.oauth2.jwt.JwtEncoder;
-import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationGrantAuthenticationToken;
@@ -160,34 +160,19 @@ public final class OAuth2TokenEndpointConfigurer extends AbstractOAuth2Configure
 	private <B extends HttpSecurityBuilder<B>> List<AuthenticationProvider> createDefaultAuthenticationProviders(B builder) {
 		List<AuthenticationProvider> authenticationProviders = new ArrayList<>();
 
-		JwtEncoder jwtEncoder = OAuth2ConfigurerUtils.getJwtEncoder(builder);
-		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = OAuth2ConfigurerUtils.getJwtCustomizer(builder);
+		OAuth2AuthorizationService authorizationService = OAuth2ConfigurerUtils.getAuthorizationService(builder);
+		OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator = OAuth2ConfigurerUtils.getTokenGenerator(builder);
 
 		OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
-				new OAuth2AuthorizationCodeAuthenticationProvider(
-						OAuth2ConfigurerUtils.getAuthorizationService(builder),
-						jwtEncoder);
-		if (jwtCustomizer != null) {
-			authorizationCodeAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
-		}
+				new OAuth2AuthorizationCodeAuthenticationProvider(authorizationService, tokenGenerator);
 		authenticationProviders.add(authorizationCodeAuthenticationProvider);
 
 		OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider =
-				new OAuth2RefreshTokenAuthenticationProvider(
-						OAuth2ConfigurerUtils.getAuthorizationService(builder),
-						jwtEncoder);
-		if (jwtCustomizer != null) {
-			refreshTokenAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
-		}
+				new OAuth2RefreshTokenAuthenticationProvider(authorizationService, tokenGenerator);
 		authenticationProviders.add(refreshTokenAuthenticationProvider);
 
 		OAuth2ClientCredentialsAuthenticationProvider clientCredentialsAuthenticationProvider =
-				new OAuth2ClientCredentialsAuthenticationProvider(
-						OAuth2ConfigurerUtils.getAuthorizationService(builder),
-						jwtEncoder);
-		if (jwtCustomizer != null) {
-			clientCredentialsAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
-		}
+				new OAuth2ClientCredentialsAuthenticationProvider(authorizationService, tokenGenerator);
 		authenticationProviders.add(clientCredentialsAuthenticationProvider);
 
 		return authenticationProviders;

+ 2 - 2
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcClientRegistrationEndpointConfigurer.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ public final class OidcClientRegistrationEndpointConfigurer extends AbstractOAut
 				new OidcClientRegistrationAuthenticationProvider(
 						OAuth2ConfigurerUtils.getRegisteredClientRepository(builder),
 						OAuth2ConfigurerUtils.getAuthorizationService(builder),
-						OAuth2ConfigurerUtils.getJwtEncoder(builder));
+						OAuth2ConfigurerUtils.getTokenGenerator(builder));
 		builder.authenticationProvider(postProcess(oidcClientRegistrationAuthenticationProvider));
 	}
 

+ 80 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/DefaultOAuth2TokenContext.java

@@ -0,0 +1,80 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+
+/**
+ * Default implementation of {@link OAuth2TokenContext}.
+ *
+ * @author Joe Grandja
+ * @since 0.2.3
+ * @see OAuth2TokenContext
+ */
+public final class DefaultOAuth2TokenContext implements OAuth2TokenContext {
+	private final Map<Object, Object> context;
+
+	private DefaultOAuth2TokenContext(Map<Object, Object> context) {
+		this.context = Collections.unmodifiableMap(new HashMap<>(context));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Nullable
+	@Override
+	public <V> V get(Object key) {
+		return hasKey(key) ? (V) this.context.get(key) : null;
+	}
+
+	@Override
+	public boolean hasKey(Object key) {
+		Assert.notNull(key, "key cannot be null");
+		return this.context.containsKey(key);
+	}
+
+	/**
+	 * Returns a new {@link Builder}.
+	 *
+	 * @return the {@link Builder}
+	 */
+	public static Builder builder() {
+		return new Builder();
+	}
+
+	/**
+	 * A builder for {@link DefaultOAuth2TokenContext}.
+	 */
+	public static final class Builder extends AbstractBuilder<DefaultOAuth2TokenContext, Builder> {
+
+		private Builder() {
+		}
+
+		/**
+		 * Builds a new {@link DefaultOAuth2TokenContext}.
+		 *
+		 * @return the {@link DefaultOAuth2TokenContext}
+		 */
+		public DefaultOAuth2TokenContext build() {
+			return new DefaultOAuth2TokenContext(getContext());
+		}
+
+	}
+
+}

+ 166 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JwtGenerator.java

@@ -0,0 +1,166 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.function.Consumer;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JoseHeader;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
+import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+
+/**
+ * An {@link OAuth2TokenGenerator} that generates a {@link Jwt}
+ * used for an {@link OAuth2AccessToken} or {@link OidcIdToken}.
+ *
+ * @author Joe Grandja
+ * @since 0.2.3
+ * @see OAuth2TokenGenerator
+ * @see Jwt
+ * @see JwtEncoder
+ * @see OAuth2TokenCustomizer
+ * @see JwtEncodingContext
+ * @see OAuth2AccessToken
+ * @see OidcIdToken
+ */
+public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> {
+	private final JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
+
+	/**
+	 * Constructs a {@code JwtGenerator} using the provided parameters.
+	 *
+	 * @param jwtEncoder the jwt encoder
+	 */
+	public JwtGenerator(JwtEncoder jwtEncoder) {
+		Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
+		this.jwtEncoder = jwtEncoder;
+	}
+
+	@Nullable
+	@Override
+	public Jwt generate(OAuth2TokenContext context) {
+		if (context.getTokenType() == null ||
+				(!OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType()) &&
+						!OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue()))) {
+			return null;
+		}
+
+		String issuer = null;
+		if (context.getProviderContext() != null) {
+			issuer = context.getProviderContext().getIssuer();
+		}
+		RegisteredClient registeredClient = context.getRegisteredClient();
+
+		Instant issuedAt = Instant.now();
+		Instant expiresAt;
+		if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
+			// TODO Allow configuration for ID Token time-to-live
+			expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);
+		} else {
+			expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive());
+		}
+
+		// @formatter:off
+		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder();
+		if (StringUtils.hasText(issuer)) {
+			claimsBuilder.issuer(issuer);
+		}
+		claimsBuilder
+				.subject(context.getPrincipal().getName())
+				.audience(Collections.singletonList(registeredClient.getClientId()))
+				.issuedAt(issuedAt)
+				.expiresAt(expiresAt);
+		if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
+			claimsBuilder.notBefore(issuedAt);
+			if (!CollectionUtils.isEmpty(context.getAuthorizedScopes())) {
+				claimsBuilder.claim(OAuth2ParameterNames.SCOPE, context.getAuthorizedScopes());
+			}
+		} else if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
+			claimsBuilder.claim(IdTokenClaimNames.AZP, registeredClient.getClientId());
+			if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType())) {
+				OAuth2AuthorizationRequest authorizationRequest = context.getAuthorization().getAttribute(
+						OAuth2AuthorizationRequest.class.getName());
+				String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
+				if (StringUtils.hasText(nonce)) {
+					claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce);
+				}
+			}
+			// TODO Add 'auth_time' claim
+		}
+		// @formatter:on
+
+		JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256);
+
+		if (this.jwtCustomizer != null) {
+			// @formatter:off
+			JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(headersBuilder, claimsBuilder)
+					.registeredClient(context.getRegisteredClient())
+					.principal(context.getPrincipal())
+					.providerContext(context.getProviderContext())
+					.authorizedScopes(context.getAuthorizedScopes())
+					.tokenType(context.getTokenType())
+					.authorizationGrantType(context.getAuthorizationGrantType());
+			if (context.getAuthorization() != null) {
+				jwtContextBuilder.authorization(context.getAuthorization());
+			}
+			if (context.getAuthorizationGrant() != null) {
+				jwtContextBuilder.authorizationGrant(context.getAuthorizationGrant());
+			}
+			// @formatter:on
+
+			JwtEncodingContext jwtContext = jwtContextBuilder.build();
+			this.jwtCustomizer.customize(jwtContext);
+		}
+
+		JoseHeader headers = headersBuilder.build();
+		JwtClaimsSet claims = claimsBuilder.build();
+
+		Jwt jwt = this.jwtEncoder.encode(headers, claims);
+
+		return jwt;
+	}
+
+	/**
+	 * Sets the {@link OAuth2TokenCustomizer} that customizes the
+	 * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or
+	 * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}.
+	 *
+	 * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt}
+	 */
+	public void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
+		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
+		this.jwtCustomizer = jwtCustomizer;
+	}
+
+}

+ 26 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenContext.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,15 +27,17 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.context.Context;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.context.ProviderContext;
 import org.springframework.util.Assert;
 
 /**
- * A context that holds information associated to an OAuth 2.0 Token
- * and is used by an {@link OAuth2TokenCustomizer} for customizing the token attributes.
+ * A context that holds information (to be) associated to an OAuth 2.0 Token
+ * and is used by an {@link OAuth2TokenGenerator} and {@link OAuth2TokenCustomizer}.
  *
  * @author Joe Grandja
  * @since 0.1.0
  * @see Context
+ * @see OAuth2TokenGenerator
  * @see OAuth2TokenCustomizer
  */
 public interface OAuth2TokenContext extends Context {
@@ -59,6 +61,16 @@ public interface OAuth2TokenContext extends Context {
 		return get(AbstractBuilder.PRINCIPAL_AUTHENTICATION_KEY);
 	}
 
+	/**
+	 * Returns the {@link ProviderContext provider context}.
+	 *
+	 * @return the {@link ProviderContext}
+	 * @since 0.2.3
+	 */
+	default ProviderContext getProviderContext() {
+		return get(ProviderContext.class);
+	}
+
 	/**
 	 * Returns the {@link OAuth2Authorization authorization}.
 	 *
@@ -141,6 +153,17 @@ public interface OAuth2TokenContext extends Context {
 			return put(PRINCIPAL_AUTHENTICATION_KEY, principal);
 		}
 
+		/**
+		 * Sets the {@link ProviderContext provider context}.
+		 *
+		 * @param providerContext the {@link ProviderContext}
+		 * @return the {@link AbstractBuilder} for further configuration
+		 * @since 0.2.3
+		 */
+		public B providerContext(ProviderContext providerContext) {
+			return put(ProviderContext.class, providerContext);
+		}
+
 		/**
 		 * Sets the {@link OAuth2Authorization authorization}.
 		 *

+ 44 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2TokenGenerator.java

@@ -0,0 +1,44 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.OAuth2Token;
+
+/**
+ * Implementations of this interface are responsible for generating an {@link OAuth2Token}
+ * using the attributes contained in the {@link OAuth2TokenContext}.
+ *
+ * @author Joe Grandja
+ * @since 0.2.3
+ * @see OAuth2Token
+ * @see OAuth2TokenContext
+ * @param <T> the type of the OAuth 2.0 Token
+ */
+@FunctionalInterface
+public interface OAuth2TokenGenerator<T extends OAuth2Token> {
+
+	/**
+	 * Generate an OAuth 2.0 Token using the attributes contained in the {@link OAuth2TokenContext},
+	 * or return {@code null} if the {@link OAuth2TokenContext#getTokenType()} is not supported.
+	 *
+	 * @param context the context containing the OAuth 2.0 Token attributes
+	 * @return an {@link OAuth2Token} or {@code null} if the {@link OAuth2TokenContext#getTokenType()} is not supported
+	 */
+	@Nullable
+	T generate(OAuth2TokenContext context);
+
+}

+ 0 - 101
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtUtils.java

@@ -1,101 +0,0 @@
-/*
- * Copyright 2020-2021 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.authentication;
-
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Collections;
-import java.util.Set;
-
-import org.springframework.security.authentication.AuthenticationProvider;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
-import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
-import org.springframework.security.oauth2.jwt.JoseHeader;
-import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
-import org.springframework.util.CollectionUtils;
-import org.springframework.util.StringUtils;
-
-/**
- * Utility methods used by the {@link AuthenticationProvider}'s when issuing {@link Jwt}'s.
- *
- * @author Joe Grandja
- * @since 0.1.0
- */
-final class JwtUtils {
-
-	private JwtUtils() {
-	}
-
-	static JoseHeader.Builder headers() {
-		return JoseHeader.withAlgorithm(SignatureAlgorithm.RS256);
-	}
-
-	static JwtClaimsSet.Builder accessTokenClaims(RegisteredClient registeredClient,
-			String issuer, String subject, Set<String> authorizedScopes) {
-
-		Instant issuedAt = Instant.now();
-		Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive());
-
-		// @formatter:off
-		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder();
-		if (StringUtils.hasText(issuer)) {
-			claimsBuilder.issuer(issuer);
-		}
-		claimsBuilder
-				.subject(subject)
-				.audience(Collections.singletonList(registeredClient.getClientId()))
-				.issuedAt(issuedAt)
-				.expiresAt(expiresAt)
-				.notBefore(issuedAt);
-		if (!CollectionUtils.isEmpty(authorizedScopes)) {
-			claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes);
-		}
-		// @formatter:on
-
-		return claimsBuilder;
-	}
-
-	static JwtClaimsSet.Builder idTokenClaims(RegisteredClient registeredClient,
-			String issuer, String subject, String nonce) {
-
-		Instant issuedAt = Instant.now();
-		// TODO Allow configuration for ID Token time-to-live
-		Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);
-
-		// @formatter:off
-		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder();
-		if (StringUtils.hasText(issuer)) {
-			claimsBuilder.issuer(issuer);
-		}
-		claimsBuilder
-				.subject(subject)
-				.audience(Collections.singletonList(registeredClient.getClientId()))
-				.issuedAt(issuedAt)
-				.expiresAt(expiresAt)
-				.claim(IdTokenClaimNames.AZP, registeredClient.getClientId());
-		if (StringUtils.hasText(nonce)) {
-			claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce);
-		}
-		// TODO Add 'auth_time' claim
-		// @formatter:on
-
-		return claimsBuilder;
-	}
-
-}

+ 70 - 79
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -22,7 +22,6 @@ import java.util.Base64;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
-import java.util.Set;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
@@ -36,22 +35,26 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
-import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
 import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder;
@@ -70,13 +73,12 @@ import static org.springframework.security.oauth2.server.authorization.authentic
  * @see OAuth2AccessTokenAuthenticationToken
  * @see OAuth2AuthorizationCodeRequestAuthenticationProvider
  * @see OAuth2AuthorizationService
- * @see JwtEncoder
- * @see OAuth2TokenCustomizer
- * @see JwtEncodingContext
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
+ * @see OAuth2TokenGenerator
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
  */
 public final class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
+	private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE =
 			new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE =
@@ -84,21 +86,37 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 	private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR =
 			new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 	private final OAuth2AuthorizationService authorizationService;
-	private final JwtEncoder jwtEncoder;
-	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
+	private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
 	private Supplier<String> refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey;
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
 	 *
+	 * @deprecated Use {@link #OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService, OAuth2TokenGenerator)} instead
 	 * @param authorizationService the authorization service
 	 * @param jwtEncoder the jwt encoder
 	 */
+	@Deprecated
 	public OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) {
 		Assert.notNull(authorizationService, "authorizationService cannot be null");
 		Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
 		this.authorizationService = authorizationService;
-		this.jwtEncoder = jwtEncoder;
+		this.tokenGenerator = new JwtGenerator(jwtEncoder);
+	}
+
+	/**
+	 * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
+	 *
+	 * @param authorizationService the authorization service
+	 * @param tokenGenerator the token generator
+	 * @since 0.2.3
+	 */
+	public OAuth2AuthorizationCodeAuthenticationProvider(OAuth2AuthorizationService authorizationService,
+			OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator) {
+		Assert.notNull(authorizationService, "authorizationService cannot be null");
+		Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
+		this.authorizationService = authorizationService;
+		this.tokenGenerator = tokenGenerator;
 	}
 
 	/**
@@ -106,11 +124,15 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 	 * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or
 	 * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}.
 	 *
+	 * @deprecated Use {@link JwtGenerator#setJwtCustomizer(OAuth2TokenCustomizer)} instead
 	 * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt}
 	 */
+	@Deprecated
 	public void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
 		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
-		this.jwtCustomizer = jwtCustomizer;
+		if (this.tokenGenerator instanceof JwtGenerator) {
+			((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer);
+		}
 	}
 
 	/**
@@ -165,96 +187,65 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 			throw new OAuth2AuthenticationException(OAuth2ErrorCodes.INVALID_GRANT);
 		}
 
-		String issuer = ProviderContextHolder.getProviderContext().getIssuer();
-		Set<String> authorizedScopes = authorization.getAttribute(
-				OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME);
-
-		JoseHeader.Builder headersBuilder = JwtUtils.headers();
-		JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims(
-				registeredClient, issuer, authorization.getPrincipalName(),
-				authorizedScopes);
-
 		// @formatter:off
-		JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
+		DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
 				.registeredClient(registeredClient)
 				.principal(authorization.getAttribute(Principal.class.getName()))
+				.providerContext(ProviderContextHolder.getProviderContext())
 				.authorization(authorization)
-				.authorizedScopes(authorizedScopes)
-				.tokenType(OAuth2TokenType.ACCESS_TOKEN)
+				.authorizedScopes(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME))
 				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
-				.authorizationGrant(authorizationCodeAuthentication)
-				.build();
+				.authorizationGrant(authorizationCodeAuthentication);
 		// @formatter:on
 
-		this.jwtCustomizer.customize(context);
-
-		JoseHeader headers = context.getHeaders().build();
-		JwtClaimsSet claims = context.getClaims().build();
-		Jwt jwtAccessToken = this.jwtEncoder.encode(headers, claims);
+		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization);
 
+		// ----- Access token -----
+		OAuth2TokenContext tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build();
+		OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
+		if (generatedAccessToken == null) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+					"The token generator failed to generate the access token.", ERROR_URI);
+			throw new OAuth2AuthenticationException(error);
+		}
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-				jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(),
-				jwtAccessToken.getExpiresAt(), authorizedScopes);
+				generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
+				generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());
+		if (generatedAccessToken instanceof Jwt) {
+			authorizationBuilder.token(accessToken, (metadata) ->
+					metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) generatedAccessToken).getClaims()));
+		} else {
+			authorizationBuilder.accessToken(accessToken);
+		}
 
+		// ----- Refresh token -----
 		OAuth2RefreshToken refreshToken = null;
 		if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN) &&
 				// Do not issue refresh token to public client
 				!clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
 			refreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
+			authorizationBuilder.refreshToken(refreshToken);
 		}
 
-		Jwt jwtIdToken = null;
-		if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) {
-			String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
-
-			headersBuilder = JwtUtils.headers();
-			claimsBuilder = JwtUtils.idTokenClaims(
-					registeredClient, issuer, authorization.getPrincipalName(), nonce);
-
-			// @formatter:off
-			context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
-					.registeredClient(registeredClient)
-					.principal(authorization.getAttribute(Principal.class.getName()))
-					.authorization(authorization)
-					.authorizedScopes(authorizedScopes)
-					.tokenType(ID_TOKEN_TOKEN_TYPE)
-					.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
-					.authorizationGrant(authorizationCodeAuthentication)
-					.build();
-			// @formatter:on
-
-			this.jwtCustomizer.customize(context);
-
-			headers = context.getHeaders().build();
-			claims = context.getClaims().build();
-			jwtIdToken = this.jwtEncoder.encode(headers, claims);
-		}
-
+		// ----- ID token -----
 		OidcIdToken idToken;
-		if (jwtIdToken != null) {
-			idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
-					jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
+		if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) {
+			tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build();
+			OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext);
+			if (!(generatedIdToken instanceof Jwt)) {
+				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+						"The token generator failed to generate the ID token.", ERROR_URI);
+				throw new OAuth2AuthenticationException(error);
+			}
+			idToken = new OidcIdToken(generatedIdToken.getTokenValue(), generatedIdToken.getIssuedAt(),
+					generatedIdToken.getExpiresAt(), ((Jwt) generatedIdToken).getClaims());
+			authorizationBuilder.token(idToken, (metadata) ->
+					metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
 		} else {
 			idToken = null;
 		}
 
-		// @formatter:off
-		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
-				.token(accessToken,
-						(metadata) ->
-								metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims())
-				);
-		if (refreshToken != null) {
-			authorizationBuilder.refreshToken(refreshToken);
-		}
-		if (idToken != null) {
-			authorizationBuilder
-					.token(idToken,
-							(metadata) ->
-									metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
-		}
 		authorization = authorizationBuilder.build();
-		// @formatter:on
 
 		// Invalidate the authorization code as it can only be used once
 		authorization = OAuth2AuthenticationProviderUtils.invalidate(authorization, authorizationCode.getToken());

+ 55 - 32
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java

@@ -25,16 +25,20 @@ import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
-import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
 import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder;
@@ -52,29 +56,44 @@ import static org.springframework.security.oauth2.server.authorization.authentic
  * @see OAuth2ClientCredentialsAuthenticationToken
  * @see OAuth2AccessTokenAuthenticationToken
  * @see OAuth2AuthorizationService
- * @see JwtEncoder
- * @see OAuth2TokenCustomizer
- * @see JwtEncodingContext
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4">Section 4.4 Client Credentials Grant</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request</a>
+ * @see OAuth2TokenGenerator
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc6749#section-4.4">Section 4.4 Client Credentials Grant</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request</a>
  */
 public final class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider {
+	private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
 	private final OAuth2AuthorizationService authorizationService;
-	private final JwtEncoder jwtEncoder;
-	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
+	private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
 
 	/**
 	 * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters.
 	 *
+	 * @deprecated Use {@link #OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService, OAuth2TokenGenerator)} instead
 	 * @param authorizationService the authorization service
 	 * @param jwtEncoder the jwt encoder
 	 */
+	@Deprecated
 	public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService authorizationService,
 			JwtEncoder jwtEncoder) {
 		Assert.notNull(authorizationService, "authorizationService cannot be null");
 		Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
 		this.authorizationService = authorizationService;
-		this.jwtEncoder = jwtEncoder;
+		this.tokenGenerator = new JwtGenerator(jwtEncoder);
+	}
+
+	/**
+	 * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters.
+	 *
+	 * @param authorizationService the authorization service
+	 * @param tokenGenerator the token generator
+	 * @since 0.2.3
+	 */
+	public OAuth2ClientCredentialsAuthenticationProvider(OAuth2AuthorizationService authorizationService,
+			OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator) {
+		Assert.notNull(authorizationService, "authorizationService cannot be null");
+		Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
+		this.authorizationService = authorizationService;
+		this.tokenGenerator = tokenGenerator;
 	}
 
 	/**
@@ -82,11 +101,15 @@ public final class OAuth2ClientCredentialsAuthenticationProvider implements Auth
 	 * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or
 	 * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}.
 	 *
+	 * @deprecated Use {@link JwtGenerator#setJwtCustomizer(OAuth2TokenCustomizer)} instead
 	 * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt}
 	 */
+	@Deprecated
 	public void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
 		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
-		this.jwtCustomizer = jwtCustomizer;
+		if (this.tokenGenerator instanceof JwtGenerator) {
+			((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer);
+		}
 	}
 
 	@Deprecated
@@ -116,16 +139,11 @@ public final class OAuth2ClientCredentialsAuthenticationProvider implements Auth
 			authorizedScopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes());
 		}
 
-		String issuer = ProviderContextHolder.getProviderContext().getIssuer();
-
-		JoseHeader.Builder headersBuilder = JwtUtils.headers();
-		JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims(
-				registeredClient, issuer, clientPrincipal.getName(), authorizedScopes);
-
 		// @formatter:off
-		JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
+		OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
 				.registeredClient(registeredClient)
 				.principal(clientPrincipal)
+				.providerContext(ProviderContextHolder.getProviderContext())
 				.authorizedScopes(authorizedScopes)
 				.tokenType(OAuth2TokenType.ACCESS_TOKEN)
 				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
@@ -133,26 +151,30 @@ public final class OAuth2ClientCredentialsAuthenticationProvider implements Auth
 				.build();
 		// @formatter:on
 
-		this.jwtCustomizer.customize(context);
-
-		JoseHeader headers = context.getHeaders().build();
-		JwtClaimsSet claims = context.getClaims().build();
-		Jwt jwtAccessToken = this.jwtEncoder.encode(headers, claims);
-
+		OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
+		if (generatedAccessToken == null) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+					"The token generator failed to generate the access token.", ERROR_URI);
+			throw new OAuth2AuthenticationException(error);
+		}
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-				jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(),
-				jwtAccessToken.getExpiresAt(), authorizedScopes);
+				generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
+				generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());
 
 		// @formatter:off
-		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
+		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName(clientPrincipal.getName())
 				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
-				.token(accessToken,
-						(metadata) ->
-								metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims()))
-				.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes)
-				.build();
+				.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes);
 		// @formatter:on
+		if (generatedAccessToken instanceof Jwt) {
+			authorizationBuilder.token(accessToken, (metadata) ->
+					metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) generatedAccessToken).getClaims()));
+		} else {
+			authorizationBuilder.accessToken(accessToken);
+		}
+
+		OAuth2Authorization authorization = authorizationBuilder.build();
 
 		this.authorizationService.save(authorization);
 
@@ -163,4 +185,5 @@ public final class OAuth2ClientCredentialsAuthenticationProvider implements Auth
 	public boolean supports(Class<?> authentication) {
 		return OAuth2ClientCredentialsAuthenticationToken.class.isAssignableFrom(authentication);
 	}
+
 }

+ 75 - 77
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@@ -34,23 +34,26 @@ import org.springframework.security.crypto.keygen.StringKeyGenerator;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
-import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
-import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
 import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder;
 import org.springframework.util.Assert;
 
@@ -66,33 +69,48 @@ import static org.springframework.security.oauth2.server.authorization.authentic
  * @see OAuth2RefreshTokenAuthenticationToken
  * @see OAuth2AccessTokenAuthenticationToken
  * @see OAuth2AuthorizationService
- * @see JwtEncoder
- * @see OAuth2TokenCustomizer
- * @see JwtEncodingContext
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.5">Section 1.5 Refresh Token Grant</a>
- * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
+ * @see OAuth2TokenGenerator
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc6749#section-1.5">Section 1.5 Refresh Token Grant</a>
+ * @see <a target="_blank" href="https://datatracker.ietf.org/doc/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
  */
 public final class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
+	private static final String ERROR_URI = "https://datatracker.ietf.org/doc/html/rfc6749#section-5.2";
 	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
 	private static final StringKeyGenerator DEFAULT_REFRESH_TOKEN_GENERATOR =
 			new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 	private final OAuth2AuthorizationService authorizationService;
-	private final JwtEncoder jwtEncoder;
-	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
+	private final OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
 	private Supplier<String> refreshTokenGenerator = DEFAULT_REFRESH_TOKEN_GENERATOR::generateKey;
 
 	/**
 	 * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters.
 	 *
+	 * @deprecated Use {@link #OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService, OAuth2TokenGenerator)} instead
 	 * @param authorizationService the authorization service
 	 * @param jwtEncoder the jwt encoder
 	 */
+	@Deprecated
 	public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService authorizationService,
 			JwtEncoder jwtEncoder) {
 		Assert.notNull(authorizationService, "authorizationService cannot be null");
 		Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
 		this.authorizationService = authorizationService;
-		this.jwtEncoder = jwtEncoder;
+		this.tokenGenerator = new JwtGenerator(jwtEncoder);
+	}
+
+	/**
+	 * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters.
+	 *
+	 * @param authorizationService the authorization service
+	 * @param tokenGenerator the token generator
+	 * @since 0.2.3
+	 */
+	public OAuth2RefreshTokenAuthenticationProvider(OAuth2AuthorizationService authorizationService,
+			OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator) {
+		Assert.notNull(authorizationService, "authorizationService cannot be null");
+		Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
+		this.authorizationService = authorizationService;
+		this.tokenGenerator = tokenGenerator;
 	}
 
 	/**
@@ -100,11 +118,15 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 	 * {@link JwtEncodingContext.Builder#headers(Consumer) headers} and/or
 	 * {@link JwtEncodingContext.Builder#claims(Consumer) claims} for the generated {@link Jwt}.
 	 *
+	 * @deprecated Use {@link JwtGenerator#setJwtCustomizer(OAuth2TokenCustomizer)} instead
 	 * @param jwtCustomizer the {@link OAuth2TokenCustomizer} that customizes the headers and/or claims for the generated {@code Jwt}
 	 */
+	@Deprecated
 	public void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
 		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
-		this.jwtCustomizer = jwtCustomizer;
+		if (this.tokenGenerator instanceof JwtGenerator) {
+			((JwtGenerator) this.tokenGenerator).setJwtCustomizer(jwtCustomizer);
+		}
 	}
 
 	/**
@@ -164,90 +186,65 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 			scopes = authorizedScopes;
 		}
 
-		String issuer = ProviderContextHolder.getProviderContext().getIssuer();
-
-		JoseHeader.Builder headersBuilder = JwtUtils.headers();
-		JwtClaimsSet.Builder claimsBuilder = JwtUtils.accessTokenClaims(
-				registeredClient, issuer, authorization.getPrincipalName(), scopes);
-
 		// @formatter:off
-		JwtEncodingContext context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
+		DefaultOAuth2TokenContext.Builder tokenContextBuilder = DefaultOAuth2TokenContext.builder()
 				.registeredClient(registeredClient)
 				.principal(authorization.getAttribute(Principal.class.getName()))
+				.providerContext(ProviderContextHolder.getProviderContext())
 				.authorization(authorization)
-				.authorizedScopes(authorizedScopes)
-				.tokenType(OAuth2TokenType.ACCESS_TOKEN)
+				.authorizedScopes(scopes)
 				.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
-				.authorizationGrant(refreshTokenAuthentication)
-				.build();
+				.authorizationGrant(refreshTokenAuthentication);
 		// @formatter:on
 
-		this.jwtCustomizer.customize(context);
-
-		JoseHeader headers = context.getHeaders().build();
-		JwtClaimsSet claims = context.getClaims().build();
-		Jwt jwtAccessToken = this.jwtEncoder.encode(headers, claims);
+		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization);
 
+		// ----- Access token -----
+		OAuth2TokenContext tokenContext = tokenContextBuilder.tokenType(OAuth2TokenType.ACCESS_TOKEN).build();
+		OAuth2Token generatedAccessToken = this.tokenGenerator.generate(tokenContext);
+		if (generatedAccessToken == null) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+					"The token generator failed to generate the access token.", ERROR_URI);
+			throw new OAuth2AuthenticationException(error);
+		}
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-				jwtAccessToken.getTokenValue(), jwtAccessToken.getIssuedAt(),
-				jwtAccessToken.getExpiresAt(), scopes);
-
-		TokenSettings tokenSettings = registeredClient.getTokenSettings();
-
-		OAuth2RefreshToken currentRefreshToken = refreshToken.getToken();
-		if (!tokenSettings.isReuseRefreshTokens()) {
-			currentRefreshToken = generateRefreshToken(tokenSettings.getRefreshTokenTimeToLive());
+				generatedAccessToken.getTokenValue(), generatedAccessToken.getIssuedAt(),
+				generatedAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());
+		if (generatedAccessToken instanceof Jwt) {
+			authorizationBuilder.token(accessToken, (metadata) -> {
+				metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) generatedAccessToken).getClaims());
+				metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false);
+			});
+		} else {
+			authorizationBuilder.accessToken(accessToken);
 		}
 
-		Jwt jwtIdToken = null;
-		if (authorizedScopes.contains(OidcScopes.OPENID)) {
-			headersBuilder = JwtUtils.headers();
-			claimsBuilder = JwtUtils.idTokenClaims(
-					registeredClient, issuer, authorization.getPrincipalName(), null);
-
-			// @formatter:off
-			context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
-					.registeredClient(registeredClient)
-					.principal(authorization.getAttribute(Principal.class.getName()))
-					.authorization(authorization)
-					.authorizedScopes(authorizedScopes)
-					.tokenType(ID_TOKEN_TOKEN_TYPE)
-					.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
-					.authorizationGrant(refreshTokenAuthentication)
-					.build();
-			// @formatter:on
-
-			this.jwtCustomizer.customize(context);
-
-			headers = context.getHeaders().build();
-			claims = context.getClaims().build();
-			jwtIdToken = this.jwtEncoder.encode(headers, claims);
+		// ----- Refresh token -----
+		OAuth2RefreshToken currentRefreshToken = refreshToken.getToken();
+		if (!registeredClient.getTokenSettings().isReuseRefreshTokens()) {
+			currentRefreshToken = generateRefreshToken(registeredClient.getTokenSettings().getRefreshTokenTimeToLive());
+			authorizationBuilder.refreshToken(currentRefreshToken);
 		}
 
+		// ----- ID token -----
 		OidcIdToken idToken;
-		if (jwtIdToken != null) {
-			idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
-					jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
+		if (authorizedScopes.contains(OidcScopes.OPENID)) {
+			tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build();
+			OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext);
+			if (!(generatedIdToken instanceof Jwt)) {
+				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+						"The token generator failed to generate the ID token.", ERROR_URI);
+				throw new OAuth2AuthenticationException(error);
+			}
+			idToken = new OidcIdToken(generatedIdToken.getTokenValue(), generatedIdToken.getIssuedAt(),
+					generatedIdToken.getExpiresAt(), ((Jwt) generatedIdToken).getClaims());
+			authorizationBuilder.token(idToken, (metadata) ->
+					metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
 		} else {
 			idToken = null;
 		}
 
-		// @formatter:off
-		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
-				.token(accessToken,
-						(metadata) -> {
-								metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims());
-								metadata.put(OAuth2Authorization.Token.INVALIDATED_METADATA_NAME, false);
-						})
-				.refreshToken(currentRefreshToken);
-		if (idToken != null) {
-			authorizationBuilder
-					.token(idToken,
-							(metadata) ->
-									metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
-		}
 		authorization = authorizationBuilder.build();
-		// @formatter:on
 
 		this.authorizationService.save(authorization);
 
@@ -271,4 +268,5 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 		Instant expiresAt = issuedAt.plus(tokenTimeToLive);
 		return new OAuth2RefreshToken(this.refreshTokenGenerator.get(), issuedAt, expiresAt);
 	}
+
 }

+ 0 - 76
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/JwtUtils.java

@@ -1,76 +0,0 @@
-/*
- * Copyright 2020-2021 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.oidc.authentication;
-
-import java.time.Instant;
-import java.util.Collections;
-import java.util.Set;
-
-import org.springframework.security.authentication.AuthenticationProvider;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
-import org.springframework.security.oauth2.jwt.JoseHeader;
-import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
-import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
-import org.springframework.util.CollectionUtils;
-import org.springframework.util.StringUtils;
-
-/**
- * TODO
- * This class is mostly a straight copy from {@code org.springframework.security.oauth2.server.authorization.authentication.JwtUtils}.
- * It should be consolidated when we introduce a token generator abstraction.
- *
- * Utility methods used by the {@link AuthenticationProvider}'s when issuing {@link Jwt}'s.
- *
- * @author Ovidiu Popa
- * @since 0.2.1
- */
-final class JwtUtils {
-
-	private JwtUtils() {
-	}
-
-	static JoseHeader.Builder headers() {
-		return JoseHeader.withAlgorithm(SignatureAlgorithm.RS256);
-	}
-
-	static JwtClaimsSet.Builder accessTokenClaims(RegisteredClient registeredClient,
-			String issuer, String subject, Set<String> authorizedScopes) {
-
-		Instant issuedAt = Instant.now();
-		Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive());
-
-		// @formatter:off
-		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder();
-		if (StringUtils.hasText(issuer)) {
-			claimsBuilder.issuer(issuer);
-		}
-		claimsBuilder
-				.subject(subject)
-				.audience(Collections.singletonList(registeredClient.getClientId()))
-				.issuedAt(issuedAt)
-				.expiresAt(expiresAt)
-				.notBefore(issuedAt);
-		if (!CollectionUtils.isEmpty(authorizedScopes)) {
-			claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes);
-		}
-		// @formatter:on
-
-		return claimsBuilder;
-	}
-
-}

+ 62 - 22
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProvider.java

@@ -38,6 +38,7 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -45,12 +46,15 @@ import org.springframework.security.oauth2.core.oidc.OidcClientMetadataClaimName
 import org.springframework.security.oauth2.core.oidc.OidcClientRegistration;
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
-import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.DefaultOAuth2TokenContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
@@ -73,11 +77,12 @@ import org.springframework.web.util.UriComponentsBuilder;
  * @since 0.1.1
  * @see RegisteredClientRepository
  * @see OAuth2AuthorizationService
- * @see JwtEncoder
+ * @see OAuth2TokenGenerator
  * @see <a href="https://openid.net/specs/openid-connect-registration-1_0.html#ClientRegistration">3. Client Registration Endpoint</a>
  * @see <a href="https://openid.net/specs/openid-connect-registration-1_0.html#ClientConfigurationEndpoint">4. Client Configuration Endpoint</a>
  */
 public final class OidcClientRegistrationAuthenticationProvider implements AuthenticationProvider {
+	private static final String ERROR_URI = "https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError";
 	private static final StringKeyGenerator CLIENT_ID_GENERATOR = new Base64StringKeyGenerator(
 			Base64.getUrlEncoder().withoutPadding(), 32);
 	private static final StringKeyGenerator CLIENT_SECRET_GENERATOR = new Base64StringKeyGenerator(
@@ -86,7 +91,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 	private static final String DEFAULT_CLIENT_CONFIGURATION_AUTHORIZED_SCOPE = "client.read";
 	private final RegisteredClientRepository registeredClientRepository;
 	private final OAuth2AuthorizationService authorizationService;
-	private JwtEncoder jwtEncoder;
+	private OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator;
 
 	/**
 	 * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters.
@@ -110,7 +115,9 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 	 * @param registeredClientRepository the repository of registered clients
 	 * @param authorizationService the authorization service
 	 * @param jwtEncoder the jwt encoder
+	 * @deprecated Use {@link #OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository, OAuth2AuthorizationService, OAuth2TokenGenerator)} instead
 	 */
+	@Deprecated
 	public OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository registeredClientRepository,
 			OAuth2AuthorizationService authorizationService, JwtEncoder jwtEncoder) {
 		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
@@ -118,13 +125,31 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 		Assert.notNull(jwtEncoder, "jwtEncoder cannot be null");
 		this.registeredClientRepository = registeredClientRepository;
 		this.authorizationService = authorizationService;
-		this.jwtEncoder = jwtEncoder;
+		this.tokenGenerator = new JwtGenerator(jwtEncoder);
+	}
+
+	/**
+	 * Constructs an {@code OidcClientRegistrationAuthenticationProvider} using the provided parameters.
+	 *
+	 * @param registeredClientRepository the repository of registered clients
+	 * @param authorizationService the authorization service
+	 * @param tokenGenerator the token generator
+	 * @since 0.2.3
+	 */
+	public OidcClientRegistrationAuthenticationProvider(RegisteredClientRepository registeredClientRepository,
+			OAuth2AuthorizationService authorizationService, OAuth2TokenGenerator<? extends OAuth2Token> tokenGenerator) {
+		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+		Assert.notNull(authorizationService, "authorizationService cannot be null");
+		Assert.notNull(tokenGenerator, "tokenGenerator cannot be null");
+		this.registeredClientRepository = registeredClientRepository;
+		this.authorizationService = authorizationService;
+		this.tokenGenerator = tokenGenerator;
 	}
 
 	@Deprecated
 	@Autowired(required = false)
 	protected void setJwtEncoder(JwtEncoder jwtEncoder) {
-		this.jwtEncoder = jwtEncoder;
+		this.tokenGenerator = new JwtGenerator(jwtEncoder);
 	}
 
 	@Deprecated
@@ -227,37 +252,52 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 	}
 
 	private OAuth2Authorization registerAccessToken(RegisteredClient registeredClient) {
-		JoseHeader headers = JwtUtils.headers().build();
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient,
+				registeredClient.getClientAuthenticationMethods().iterator().next(), registeredClient.getClientSecret());
 
 		Set<String> authorizedScopes = new HashSet<>();
 		authorizedScopes.add(DEFAULT_CLIENT_CONFIGURATION_AUTHORIZED_SCOPE);
 		authorizedScopes = Collections.unmodifiableSet(authorizedScopes);
 
-		String issuer = ProviderContextHolder.getProviderContext().getIssuer();
-		JwtClaimsSet claims = JwtUtils.accessTokenClaims(
-				registeredClient, issuer, registeredClient.getClientId(), authorizedScopes)
+		// @formatter:off
+		OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
+				.registeredClient(registeredClient)
+				.principal(clientPrincipal)
+				.providerContext(ProviderContextHolder.getProviderContext())
+				.authorizedScopes(authorizedScopes)
+				.tokenType(OAuth2TokenType.ACCESS_TOKEN)
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
 				.build();
+		// @formatter:on
 
-		Jwt registrationAccessToken = this.jwtEncoder.encode(headers, claims);
-
+		OAuth2Token registrationAccessToken = this.tokenGenerator.generate(tokenContext);
+		if (registrationAccessToken == null) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,
+					"The token generator failed to generate the registration access token.", ERROR_URI);
+			throw new OAuth2AuthenticationException(error);
+		}
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
 				registrationAccessToken.getTokenValue(), registrationAccessToken.getIssuedAt(),
-				registrationAccessToken.getExpiresAt(), authorizedScopes);
+				registrationAccessToken.getExpiresAt(), tokenContext.getAuthorizedScopes());
 
 		// @formatter:off
-		OAuth2Authorization registeredClientAuthorization = OAuth2Authorization.withRegisteredClient(registeredClient)
+		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName(registeredClient.getClientId())
 				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
-				.token(accessToken,
-						(metadata) ->
-								metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, registrationAccessToken.getClaims()))
-				.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes)
-				.build();
+				.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizedScopes);
 		// @formatter:on
+		if (registrationAccessToken instanceof Jwt) {
+			authorizationBuilder.token(accessToken, (metadata) ->
+					metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, ((Jwt) registrationAccessToken).getClaims()));
+		} else {
+			authorizationBuilder.accessToken(accessToken);
+		}
+
+		OAuth2Authorization authorization = authorizationBuilder.build();
 
-		this.authorizationService.save(registeredClientAuthorization);
+		this.authorizationService.save(authorization);
 
-		return registeredClientAuthorization;
+		return authorization;
 	}
 
 	private OidcClientRegistration.Builder buildRegistration(RegisteredClient registeredClient) {
@@ -445,7 +485,7 @@ public final class OidcClientRegistrationAuthenticationProvider implements Authe
 		OAuth2Error error = new OAuth2Error(
 				errorCode,
 				"Invalid Client Registration: " + fieldName,
-				"https://openid.net/specs/openid-connect-registration-1_0.html#RegistrationError");
+				ERROR_URI);
 		throw new OAuth2AuthenticationException(error);
 	}
 

+ 26 - 4
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -84,11 +84,14 @@ import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
 import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsent;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationConsentService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeRequestAuthenticationToken;
@@ -184,6 +187,9 @@ public class OAuth2AuthorizationCodeGrantTests {
 	@Autowired
 	private JwtDecoder jwtDecoder;
 
+	@Autowired(required = false)
+	private OAuth2TokenGenerator<?> tokenGenerator;
+
 	@BeforeClass
 	public static void init() {
 		JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
@@ -425,8 +431,8 @@ public class OAuth2AuthorizationCodeGrantTests {
 	}
 
 	@Test
-	public void requestWhenCustomJwtEncoderThenUsed() throws Exception {
-		this.spring.register(AuthorizationServerConfigurationWithJwtEncoder.class).autowire();
+	public void requestWhenCustomTokenGeneratorThenUsed() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithTokenGenerator.class).autowire();
 
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
 		this.registeredClientRepository.save(registeredClient);
@@ -436,7 +442,10 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 		this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
 				.params(getTokenRequestParameters(registeredClient, authorization))
-				.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)));
+				.header(HttpHeaders.AUTHORIZATION, getAuthorizationHeader(registeredClient)))
+				.andExpect(status().isOk());
+
+		verify(this.tokenGenerator).generate(any());
 	}
 
 	@Test
@@ -822,12 +831,25 @@ public class OAuth2AuthorizationCodeGrantTests {
 
 	@EnableWebSecurity
 	@Import(OAuth2AuthorizationServerConfiguration.class)
-	static class AuthorizationServerConfigurationWithJwtEncoder extends AuthorizationServerConfiguration {
+	static class AuthorizationServerConfigurationWithTokenGenerator extends AuthorizationServerConfiguration {
 
 		@Bean
 		JwtEncoder jwtEncoder() {
 			return jwtEncoder;
 		}
+
+		@Bean
+		OAuth2TokenGenerator<?> tokenGenerator() {
+			JwtGenerator jwtGenerator = new JwtGenerator(jwtEncoder());
+			jwtGenerator.setJwtCustomizer(jwtCustomizer());
+			return spy(new OAuth2TokenGenerator<Jwt>() {
+				@Override
+				public Jwt generate(OAuth2TokenContext context) {
+					return jwtGenerator.generate(context);
+				}
+			});
+		}
+
 	}
 
 	@EnableWebSecurity

+ 77 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -48,6 +48,7 @@ import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
 import org.springframework.mock.http.client.MockClientHttpResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.config.test.SpringTestRule;
@@ -67,11 +68,16 @@ import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames
 import org.springframework.security.oauth2.jose.TestJwks;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtDecoder;
+import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
 import org.springframework.security.oauth2.server.authorization.JdbcOAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
+import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.JdbcRegisteredClientRepository.RegisteredClientParametersMapper;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -79,6 +85,8 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
 import org.springframework.security.oauth2.server.authorization.jackson2.TestingAuthenticationTokenMixin;
+import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.util.LinkedMultiValueMap;
@@ -90,6 +98,10 @@ import org.springframework.web.util.UriComponentsBuilder;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.hamcrest.CoreMatchers.containsString;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
@@ -101,6 +113,7 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  * Integration tests for OpenID Connect 1.0.
  *
  * @author Daniel Garnier-Moiroux
+ * @author Joe Grandja
  */
 public class OidcTests {
 	private static final String DEFAULT_AUTHORIZATION_ENDPOINT_URI = "/oauth2/authorize";
@@ -132,6 +145,9 @@ public class OidcTests {
 	@Autowired
 	private JwtDecoder jwtDecoder;
 
+	@Autowired(required = false)
+	private OAuth2TokenGenerator<?> tokenGenerator;
+
 	@BeforeClass
 	public static void init() {
 		JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
@@ -230,6 +246,25 @@ public class OidcTests {
 		assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities);
 	}
 
+	@Test
+	public void requestWhenCustomTokenGeneratorThenUsed() throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithTokenGenerator.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		this.registeredClientRepository.save(registeredClient);
+
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		this.authorizationService.save(authorization);
+
+		this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+				.params(getTokenRequestParameters(registeredClient, authorization))
+				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
+						registeredClient.getClientId(), registeredClient.getClientSecret())))
+				.andExpect(status().isOk());
+
+		verify(this.tokenGenerator, times(2)).generate(any());
+	}
+
 	private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.RESPONSE_TYPE, OAuth2AuthorizationResponseType.CODE.getValue());
@@ -339,6 +374,46 @@ public class OidcTests {
 
 	}
 
+	@EnableWebSecurity
+	static class AuthorizationServerConfigurationWithTokenGenerator extends AuthorizationServerConfiguration {
+
+		// @formatter:off
+		@Bean
+		public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity http) throws Exception {
+			OAuth2AuthorizationServerConfigurer<HttpSecurity> authorizationServerConfigurer =
+					new OAuth2AuthorizationServerConfigurer<>();
+			http.apply(authorizationServerConfigurer);
+
+			authorizationServerConfigurer
+					.tokenGenerator(tokenGenerator());
+
+			RequestMatcher endpointsMatcher = authorizationServerConfigurer.getEndpointsMatcher();
+
+			http
+					.requestMatcher(endpointsMatcher)
+					.authorizeRequests(authorizeRequests ->
+							authorizeRequests.anyRequest().authenticated()
+					)
+					.csrf(csrf -> csrf.ignoringRequestMatchers(endpointsMatcher));
+
+			return http.build();
+		}
+		// @formatter:on
+
+		@Bean
+		OAuth2TokenGenerator<?> tokenGenerator() {
+			JwtGenerator jwtGenerator = new JwtGenerator(new NimbusJwsEncoder(jwkSource()));
+			jwtGenerator.setJwtCustomizer(jwtCustomizer());
+			return spy(new OAuth2TokenGenerator<Jwt>() {
+				@Override
+				public Jwt generate(OAuth2TokenContext context) {
+					return jwtGenerator.generate(context);
+				}
+			});
+		}
+
+	}
+
 	@EnableWebSecurity
 	@Import(OAuth2AuthorizationServerConfiguration.class)
 	static class AuthorizationServerConfigurationWithIssuer extends AuthorizationServerConfiguration {
@@ -368,4 +443,5 @@ public class OidcTests {
 			return ProviderSettings.builder().issuer("https://not a valid uri").build();
 		}
 	}
+
 }

+ 215 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JwtGeneratorTests.java

@@ -0,0 +1,215 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import java.security.Principal;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.ArgumentCaptor;
+
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JoseHeader;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
+import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
+import org.springframework.security.oauth2.server.authorization.context.ProviderContext;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link JwtGenerator}.
+ *
+ * @author Joe Grandja
+ */
+public class JwtGeneratorTests {
+	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
+	private JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
+	private JwtGenerator jwtGenerator;
+	private ProviderContext providerContext;
+
+	@Before
+	public void setUp() {
+		this.jwtEncoder = mock(JwtEncoder.class);
+		this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
+		this.jwtGenerator = new JwtGenerator(this.jwtEncoder);
+		this.jwtGenerator.setJwtCustomizer(this.jwtCustomizer);
+		ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build();
+		this.providerContext = new ProviderContext(providerSettings, null);
+	}
+
+	@Test
+	public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new JwtGenerator(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtEncoder cannot be null");
+	}
+
+	@Test
+	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.jwtGenerator.setJwtCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtCustomizer cannot be null");
+	}
+
+	@Test
+	public void generateWhenUnsupportedTokenTypeThenReturnNull() {
+		// @formatter:off
+		OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
+				.tokenType(new OAuth2TokenType("unsupported_token_type"))
+				.build();
+		// @formatter:on
+
+		assertThat(this.jwtGenerator.generate(tokenContext)).isNull();
+	}
+
+	@Test
+	public void generateWhenAccessTokenTypeThenReturnJwt() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null);
+
+		// @formatter:off
+		OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
+				.registeredClient(registeredClient)
+				.principal(authorization.getAttribute(Principal.class.getName()))
+				.providerContext(this.providerContext)
+				.authorization(authorization)
+				.authorizedScopes(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME))
+				.tokenType(OAuth2TokenType.ACCESS_TOKEN)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.authorizationGrant(authentication)
+				.build();
+		// @formatter:on
+
+		assertGeneratedTokenType(tokenContext);
+	}
+
+	@Test
+	public void generateWhenIdTokenTypeThenReturnJwt() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		Map<String, Object> authenticationRequestAdditionalParameters = new HashMap<>();
+		authenticationRequestAdditionalParameters.put(OidcParameterNames.NONCE, "nonce");
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
+				registeredClient, authenticationRequestAdditionalParameters).build();
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken("code", clientPrincipal, authorizationRequest.getRedirectUri(), null);
+
+		// @formatter:off
+		OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
+				.registeredClient(registeredClient)
+				.principal(authorization.getAttribute(Principal.class.getName()))
+				.providerContext(this.providerContext)
+				.authorization(authorization)
+				.authorizedScopes(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME))
+				.tokenType(ID_TOKEN_TOKEN_TYPE)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.authorizationGrant(authentication)
+				.build();
+		// @formatter:on
+
+		assertGeneratedTokenType(tokenContext);
+	}
+
+	private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) {
+		this.jwtGenerator.generate(tokenContext);
+
+		ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+		verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture());
+
+		JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue();
+		assertThat(jwtEncodingContext.getHeaders()).isNotNull();
+		assertThat(jwtEncodingContext.getClaims()).isNotNull();
+		assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(tokenContext.getRegisteredClient());
+		assertThat(jwtEncodingContext.<Authentication>getPrincipal()).isEqualTo(tokenContext.getPrincipal());
+		assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(tokenContext.getAuthorization());
+		assertThat(jwtEncodingContext.getAuthorizedScopes()).isEqualTo(tokenContext.getAuthorizedScopes());
+		assertThat(jwtEncodingContext.getTokenType()).isEqualTo(tokenContext.getTokenType());
+		assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(tokenContext.getAuthorizationGrantType());
+		assertThat(jwtEncodingContext.<Authentication>getAuthorizationGrant()).isEqualTo(tokenContext.getAuthorizationGrant());
+
+		ArgumentCaptor<JoseHeader> joseHeaderCaptor = ArgumentCaptor.forClass(JoseHeader.class);
+		ArgumentCaptor<JwtClaimsSet> jwtClaimsSetCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class);
+		verify(this.jwtEncoder).encode(joseHeaderCaptor.capture(), jwtClaimsSetCaptor.capture());
+
+		JoseHeader joseHeader = joseHeaderCaptor.getValue();
+		assertThat(joseHeader.<JwsAlgorithm>getAlgorithm()).isEqualTo(SignatureAlgorithm.RS256);
+
+		JwtClaimsSet jwtClaimsSet = jwtClaimsSetCaptor.getValue();
+		assertThat(jwtClaimsSet.getIssuer().toExternalForm()).isEqualTo(tokenContext.getProviderContext().getIssuer());
+		assertThat(jwtClaimsSet.getSubject()).isEqualTo(tokenContext.getAuthorization().getPrincipalName());
+		assertThat(jwtClaimsSet.getAudience()).containsExactly(tokenContext.getRegisteredClient().getClientId());
+
+		Instant issuedAt = Instant.now();
+		Instant expiresAt;
+		if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) {
+			expiresAt = issuedAt.plus(tokenContext.getRegisteredClient().getTokenSettings().getAccessTokenTimeToLive());
+		} else {
+			expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);
+		}
+		assertThat(jwtClaimsSet.getIssuedAt()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));
+		assertThat(jwtClaimsSet.getExpiresAt()).isBetween(expiresAt.minusSeconds(1), expiresAt.plusSeconds(1));
+
+		if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) {
+			assertThat(jwtClaimsSet.getNotBefore()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));
+
+			Set<String> scopes = jwtClaimsSet.getClaim(OAuth2ParameterNames.SCOPE);
+			assertThat(scopes).isEqualTo(tokenContext.getAuthorizedScopes());
+		} else {
+			assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.AZP)).isEqualTo(tokenContext.getRegisteredClient().getClientId());
+
+			OAuth2AuthorizationRequest authorizationRequest = tokenContext.getAuthorization().getAttribute(
+					OAuth2AuthorizationRequest.class.getName());
+			String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
+			assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce);
+		}
+	}
+
+}

+ 91 - 4
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -49,9 +49,12 @@ import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
@@ -65,6 +68,7 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.entry;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
@@ -83,16 +87,24 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
 	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
+	private OAuth2TokenGenerator<?> tokenGenerator;
 	private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider;
 
 	@Before
 	public void setUp() {
 		this.authorizationService = mock(OAuth2AuthorizationService.class);
 		this.jwtEncoder = mock(JwtEncoder.class);
-		this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
-				this.authorizationService, this.jwtEncoder);
 		this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
-		this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
+		JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder);
+		jwtGenerator.setJwtCustomizer(this.jwtCustomizer);
+		this.tokenGenerator = spy(new OAuth2TokenGenerator<Jwt>() {
+			@Override
+			public Jwt generate(OAuth2TokenContext context) {
+				return jwtGenerator.generate(context);
+			}
+		});
+		this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
+				this.authorizationService, this.tokenGenerator);
 		ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build();
 		ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null));
 	}
@@ -111,11 +123,18 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 
 	@Test
 	public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.authorizationService, null))
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.authorizationService, (JwtEncoder) null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("jwtEncoder cannot be null");
 	}
 
+	@Test
+	public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2AuthorizationCodeAuthenticationProvider(this.authorizationService, (OAuth2TokenGenerator<?>) null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("tokenGenerator cannot be null");
+	}
+
 	@Test
 	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
@@ -273,6 +292,74 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
 	}
 
+	@Test
+	public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
+
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+
+		doAnswer(answer -> {
+			OAuth2TokenContext context = answer.getArgument(0);
+			if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
+				return null;
+			} else {
+				return answer.callRealMethod();
+			}
+		}).when(this.tokenGenerator).generate(any());
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+					assertThat(error.getDescription()).contains("The token generator failed to generate the access token.");
+				});
+	}
+
+	@Test
+	public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationRequest.class.getName());
+		OAuth2AuthorizationCodeAuthenticationToken authentication =
+				new OAuth2AuthorizationCodeAuthenticationToken(AUTHORIZATION_CODE, clientPrincipal, authorizationRequest.getRedirectUri(), null);
+
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+
+		doAnswer(answer -> {
+			OAuth2TokenContext context = answer.getArgument(0);
+			if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
+				return null;
+			} else {
+				return answer.callRealMethod();
+			}
+		}).when(this.tokenGenerator).generate(any());
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+					assertThat(error.getDescription()).contains("The token generator failed to generate the ID token.");
+				});
+	}
+
 	@Test
 	public void authenticateWhenValidCodeThenReturnAccessToken() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();

+ 43 - 4
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java

@@ -38,9 +38,12 @@ import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
@@ -50,7 +53,9 @@ import org.springframework.security.oauth2.server.authorization.context.Provider
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -64,16 +69,24 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
 	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
+	private OAuth2TokenGenerator<?> tokenGenerator;
 	private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider;
 
 	@Before
 	public void setUp() {
 		this.authorizationService = mock(OAuth2AuthorizationService.class);
 		this.jwtEncoder = mock(JwtEncoder.class);
-		this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider(
-				this.authorizationService, this.jwtEncoder);
 		this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
-		this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
+		JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder);
+		jwtGenerator.setJwtCustomizer(this.jwtCustomizer);
+		this.tokenGenerator = spy(new OAuth2TokenGenerator<Jwt>() {
+			@Override
+			public Jwt generate(OAuth2TokenContext context) {
+				return jwtGenerator.generate(context);
+			}
+		});
+		this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider(
+				this.authorizationService, this.tokenGenerator);
 		ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build();
 		ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null));
 	}
@@ -92,11 +105,18 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 
 	@Test
 	public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService, null))
+		assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService, (JwtEncoder) null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.hasMessage("jwtEncoder cannot be null");
 	}
 
+	@Test
+	public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2ClientCredentialsAuthenticationProvider(this.authorizationService, (OAuth2TokenGenerator<?>) null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("tokenGenerator cannot be null");
+	}
+
 	@Test
 	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
@@ -193,6 +213,25 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(requestedScope);
 	}
 
+	@Test
+	public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2ClientCredentialsAuthenticationToken authentication =
+				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, null, null);
+
+		doReturn(null).when(this.tokenGenerator).generate(any());
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+					assertThat(error.getDescription()).contains("The token generator failed to generate the access token.");
+				});
+	}
+
 	@Test
 	public void authenticateWhenValidAuthenticationThenReturnAccessToken() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient2().build();

+ 89 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -47,9 +47,12 @@ import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
@@ -58,11 +61,12 @@ import org.springframework.security.oauth2.server.authorization.config.TokenSett
 import org.springframework.security.oauth2.server.authorization.context.ProviderContext;
 import org.springframework.security.oauth2.server.authorization.context.ProviderContextHolder;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.entry;
-import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
-import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
@@ -81,6 +85,7 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
 	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
+	private OAuth2TokenGenerator<?> tokenGenerator;
 	private OAuth2RefreshTokenAuthenticationProvider authenticationProvider;
 
 	@Before
@@ -88,10 +93,17 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		this.authorizationService = mock(OAuth2AuthorizationService.class);
 		this.jwtEncoder = mock(JwtEncoder.class);
 		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(Collections.singleton("scope1")));
-		this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider(
-				this.authorizationService, this.jwtEncoder);
 		this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
-		this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
+		JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder);
+		jwtGenerator.setJwtCustomizer(this.jwtCustomizer);
+		this.tokenGenerator = spy(new OAuth2TokenGenerator<Jwt>() {
+			@Override
+			public Jwt generate(OAuth2TokenContext context) {
+				return jwtGenerator.generate(context);
+			}
+		});
+		this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider(
+				this.authorizationService, this.tokenGenerator);
 		ProviderSettings providerSettings = ProviderSettings.builder().issuer("https://provider.com").build();
 		ProviderContextHolder.setProviderContext(new ProviderContext(providerSettings, null));
 	}
@@ -111,12 +123,19 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 
 	@Test
 	public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() {
-		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, null))
+		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, (JwtEncoder) null))
 				.isInstanceOf(IllegalArgumentException.class)
 				.extracting(Throwable::getMessage)
 				.isEqualTo("jwtEncoder cannot be null");
 	}
 
+	@Test
+	public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new OAuth2RefreshTokenAuthenticationProvider(this.authorizationService, (OAuth2TokenGenerator<?>) null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("tokenGenerator cannot be null");
+	}
+
 	@Test
 	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
 		assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
@@ -500,6 +519,70 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
 	}
 
+	@Test
+	public void authenticateWhenAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(this.authorizationService.findByToken(
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
+				eq(OAuth2TokenType.REFRESH_TOKEN)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
+
+		doAnswer(answer -> {
+			OAuth2TokenContext context = answer.getArgument(0);
+			if (OAuth2TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
+				return null;
+			} else {
+				return answer.callRealMethod();
+			}
+		}).when(this.tokenGenerator).generate(any());
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+					assertThat(error.getDescription()).contains("The token generator failed to generate the access token.");
+				});
+	}
+
+	@Test
+	public void authenticateWhenIdTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(this.authorizationService.findByToken(
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
+				eq(OAuth2TokenType.REFRESH_TOKEN)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
+
+		doAnswer(answer -> {
+			OAuth2TokenContext context = answer.getArgument(0);
+			if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
+				return null;
+			} else {
+				return answer.callRealMethod();
+			}
+		}).when(this.tokenGenerator).generate(any());
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+					assertThat(error.getDescription()).contains("The token generator failed to generate the ID token.");
+				});
+	}
+
 	private static Jwt createJwt(Set<String> scope) {
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);

+ 62 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/oidc/authentication/OidcClientRegistrationAuthenticationProviderTests.java

@@ -48,8 +48,11 @@ import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.jwt.TestJoseHeaders;
 import org.springframework.security.oauth2.jwt.TestJwtClaimsSets;
+import org.springframework.security.oauth2.server.authorization.JwtGenerator;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenContext;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenGenerator;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
@@ -66,8 +69,10 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
@@ -82,6 +87,7 @@ public class OidcClientRegistrationAuthenticationProviderTests {
 	private RegisteredClientRepository registeredClientRepository;
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
+	private OAuth2TokenGenerator<?> tokenGenerator;
 	private ProviderSettings providerSettings;
 	private OidcClientRegistrationAuthenticationProvider authenticationProvider;
 
@@ -90,10 +96,17 @@ public class OidcClientRegistrationAuthenticationProviderTests {
 		this.registeredClientRepository = mock(RegisteredClientRepository.class);
 		this.authorizationService = mock(OAuth2AuthorizationService.class);
 		this.jwtEncoder = mock(JwtEncoder.class);
+		JwtGenerator jwtGenerator = new JwtGenerator(this.jwtEncoder);
+		this.tokenGenerator = spy(new OAuth2TokenGenerator<Jwt>() {
+			@Override
+			public Jwt generate(OAuth2TokenContext context) {
+				return jwtGenerator.generate(context);
+			}
+		});
 		this.providerSettings = ProviderSettings.builder().issuer("https://provider.com").build();
 		ProviderContextHolder.setProviderContext(new ProviderContext(this.providerSettings, null));
 		this.authenticationProvider = new OidcClientRegistrationAuthenticationProvider(
-				this.registeredClientRepository, this.authorizationService, this.jwtEncoder);
+				this.registeredClientRepository, this.authorizationService, this.tokenGenerator);
 	}
 
 	@After
@@ -118,10 +131,17 @@ public class OidcClientRegistrationAuthenticationProviderTests {
 	@Test
 	public void constructorWhenJwtEncoderNullThenThrowIllegalArgumentException() {
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(this.registeredClientRepository, this.authorizationService, null))
+				.isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(this.registeredClientRepository, this.authorizationService, (JwtEncoder) null))
 				.withMessage("jwtEncoder cannot be null");
 	}
 
+	@Test
+	public void constructorWhenTokenGeneratorNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new OidcClientRegistrationAuthenticationProvider(this.registeredClientRepository, this.authorizationService, (OAuth2TokenGenerator<?>) null))
+				.withMessage("tokenGenerator cannot be null");
+	}
+
 	@Test
 	public void supportsWhenTypeOidcClientRegistrationAuthenticationTokenThenReturnTrue() {
 		assertThat(this.authenticationProvider.supports(OidcClientRegistrationAuthenticationToken.class)).isTrue();
@@ -464,6 +484,46 @@ public class OidcClientRegistrationAuthenticationProviderTests {
 				.isEqualTo(SignatureAlgorithm.RS256.getName());
 	}
 
+	@Test
+	public void authenticateWhenRegistrationAccessTokenNotGeneratedThenThrowOAuth2AuthenticationException() {
+		Jwt jwt = createJwtClientRegistration();
+		OAuth2AccessToken jwtAccessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				jwt.getTokenValue(), jwt.getIssuedAt(),
+				jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
+				registeredClient, jwtAccessToken, jwt.getClaims()).build();
+		when(this.authorizationService.findByToken(
+				eq(jwtAccessToken.getTokenValue()), eq(OAuth2TokenType.ACCESS_TOKEN)))
+				.thenReturn(authorization);
+
+		doReturn(null).when(this.tokenGenerator).generate(any());
+
+		JwtAuthenticationToken principal = new JwtAuthenticationToken(
+				jwt, AuthorityUtils.createAuthorityList("SCOPE_client.create"));
+		// @formatter:off
+		OidcClientRegistration clientRegistration = OidcClientRegistration.builder()
+				.clientName("client-name")
+				.redirectUri("https://client.example.com")
+				.grantType(AuthorizationGrantType.AUTHORIZATION_CODE.getValue())
+				.grantType(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue())
+				.scope("scope1")
+				.scope("scope2")
+				.build();
+		// @formatter:on
+
+		OidcClientRegistrationAuthenticationToken authentication = new OidcClientRegistrationAuthenticationToken(
+				principal, clientRegistration);
+
+		assertThatThrownBy(() -> this.authenticationProvider.authenticate(authentication))
+				.isInstanceOf(OAuth2AuthenticationException.class)
+				.extracting(ex -> ((OAuth2AuthenticationException) ex).getError())
+				.satisfies(error -> {
+					assertThat(error.getErrorCode()).isEqualTo(OAuth2ErrorCodes.SERVER_ERROR);
+					assertThat(error.getDescription()).contains("The token generator failed to generate the registration access token.");
+				});
+	}
+
 	@Test
 	public void authenticateWhenClientRegistrationRequestAndValidAccessTokenThenReturnClientRegistration() {
 		Jwt jwt = createJwtClientRegistration();