Browse Source

Add OAuth2TokenCustomizer

Closes gh-199
Joe Grandja 4 years ago
parent
commit
adf96b4e25
25 changed files with 1115 additions and 268 deletions
  1. 35 0
      oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java
  2. 47 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/Context.java
  3. 50 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/DefaultContext.java
  4. 14 37
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java
  5. 6 1
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java
  6. 123 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtEncodingContextUtils.java
  7. 47 16
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java
  8. 33 8
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java
  9. 45 8
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java
  10. 0 97
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java
  11. 87 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContext.java
  12. 119 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenContext.java
  13. 28 0
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenCustomizer.java
  14. 18 16
      oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java
  15. 55 12
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java
  16. 13 0
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java
  17. 54 2
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java
  18. 57 6
      oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java
  19. 0 24
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java
  20. 9 6
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java
  21. 55 6
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java
  22. 38 9
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java
  23. 47 12
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java
  24. 118 0
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContextTests.java
  25. 17 8
      oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

+ 35 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationServerConfigurer.java

@@ -33,7 +33,9 @@ import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.annotation.web.configurers.ExceptionHandlingConfigurer;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
 import org.springframework.security.oauth2.server.authorization.InMemoryOAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -152,23 +154,33 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 		builder.authenticationProvider(postProcess(clientAuthenticationProvider));
 
 		JwtEncoder jwtEncoder = getJwtEncoder(builder);
+		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = getJwtCustomizer(builder);
 
 		OAuth2AuthorizationCodeAuthenticationProvider authorizationCodeAuthenticationProvider =
 				new OAuth2AuthorizationCodeAuthenticationProvider(
 						getAuthorizationService(builder),
 						jwtEncoder);
+		if (jwtCustomizer != null) {
+			authorizationCodeAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
+		}
 		builder.authenticationProvider(postProcess(authorizationCodeAuthenticationProvider));
 
 		OAuth2RefreshTokenAuthenticationProvider refreshTokenAuthenticationProvider =
 				new OAuth2RefreshTokenAuthenticationProvider(
 						getAuthorizationService(builder),
 						jwtEncoder);
+		if (jwtCustomizer != null) {
+			refreshTokenAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
+		}
 		builder.authenticationProvider(postProcess(refreshTokenAuthenticationProvider));
 
 		OAuth2ClientCredentialsAuthenticationProvider clientCredentialsAuthenticationProvider =
 				new OAuth2ClientCredentialsAuthenticationProvider(
 						getAuthorizationService(builder),
 						jwtEncoder);
+		if (jwtCustomizer != null) {
+			clientCredentialsAuthenticationProvider.setJwtCustomizer(jwtCustomizer);
+		}
 		builder.authenticationProvider(postProcess(clientCredentialsAuthenticationProvider));
 
 		OAuth2TokenRevocationAuthenticationProvider tokenRevocationAuthenticationProvider =
@@ -314,6 +326,19 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 		return jwkSource;
 	}
 
+	@SuppressWarnings("unchecked")
+	private static <B extends HttpSecurityBuilder<B>> OAuth2TokenCustomizer<JwtEncodingContext> getJwtCustomizer(B builder) {
+		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = builder.getSharedObject(OAuth2TokenCustomizer.class);
+		if (jwtCustomizer == null) {
+			ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2TokenCustomizer.class, JwtEncodingContext.class);
+			jwtCustomizer = getOptionalBean(builder, type);
+			if (jwtCustomizer != null) {
+				builder.setSharedObject(OAuth2TokenCustomizer.class, jwtCustomizer);
+			}
+		}
+		return jwtCustomizer;
+	}
+
 	private static <B extends HttpSecurityBuilder<B>> ProviderSettings getProviderSettings(B builder) {
 		ProviderSettings providerSettings = builder.getSharedObject(ProviderSettings.class);
 		if (providerSettings == null) {
@@ -353,4 +378,14 @@ public final class OAuth2AuthorizationServerConfigurer<B extends HttpSecurityBui
 		}
 		return (!beansMap.isEmpty() ? beansMap.values().iterator().next() : null);
 	}
+
+	@SuppressWarnings("unchecked")
+	private static <B extends HttpSecurityBuilder<B>, T> T getOptionalBean(B builder, ResolvableType type) {
+		ApplicationContext context = builder.getSharedObject(ApplicationContext.class);
+		String[] names = context.getBeanNamesForType(type);
+		if (names.length > 1) {
+			throw new NoUniqueBeanDefinitionException(type, names);
+		}
+		return names.length == 1 ? (T) context.getBean(names[0]) : null;
+	}
 }

+ 47 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/Context.java

@@ -0,0 +1,47 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.core.context;
+
+import java.util.Map;
+
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+
+/**
+ * A facility for holding information associated to a specific context.
+ *
+ * @author Joe Grandja
+ * @since 0.1.0
+ */
+public interface Context {
+
+	@Nullable
+	<V> V get(Object key);
+
+	@Nullable
+	default <V> V get(Class<V> key) {
+		Assert.notNull(key, "key cannot be null");
+		V value = get((Object) key);
+		return key.isInstance(value) ? value : null;
+	}
+
+	boolean hasKey(Object key);
+
+	static Context of(Map<Object, Object> context) {
+		return new DefaultContext(context);
+	}
+
+}

+ 50 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/core/context/DefaultContext.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.core.context;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+
+/**
+ * @author Joe Grandja
+ * @since 0.1.0
+ */
+final class DefaultContext implements Context {
+	private final Map<Object, Object> context;
+
+	DefaultContext(Map<Object, Object> context) {
+		Assert.notNull(context, "context cannot be null");
+		this.context = Collections.unmodifiableMap(new HashMap<>(context));
+	}
+
+	@SuppressWarnings("unchecked")
+	@Override
+	@Nullable
+	public <V> V get(Object key) {
+		return hasKey(key) ? (V) this.context.get(key) : null;
+	}
+
+	@Override
+	public boolean hasKey(Object key) {
+		Assert.notNull(key, "key cannot be null");
+		return this.context.containsKey(key);
+	}
+
+}

+ 14 - 37
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoder.java

@@ -23,8 +23,6 @@ import java.util.Map;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.atomic.AtomicReference;
-import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 
 import com.nimbusds.jose.JOSEException;
@@ -46,7 +44,6 @@ import com.nimbusds.jwt.JWTClaimsSet;
 import com.nimbusds.jwt.SignedJWT;
 
 import org.springframework.core.convert.converter.Converter;
-import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
@@ -88,9 +85,6 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 
 	private final JWKSource<SecurityContext> jwkSource;
 
-	private BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = (headers, claims) -> {
-	};
-
 	/**
 	 * Constructs a {@code NimbusJwsEncoder} using the provided parameters.
 	 * @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource}
@@ -100,32 +94,12 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 		this.jwkSource = jwkSource;
 	}
 
-	/**
-	 * Sets the {@link Jwt} customizer to be provided the {@link JoseHeader.Builder} and
-	 * {@link JwtClaimsSet.Builder} allowing for further customizations.
-	 * @param jwtCustomizer the {@link Jwt} customizer to be provided the
-	 * {@link JoseHeader.Builder} and {@link JwtClaimsSet.Builder}
-	 */
-	public void setJwtCustomizer(BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer) {
-		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
-		this.jwtCustomizer = jwtCustomizer;
-	}
-
 	@Override
 	public Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException {
 		Assert.notNull(headers, "headers cannot be null");
 		Assert.notNull(claims, "claims cannot be null");
 
-		// @formatter:off
-		JoseHeader.Builder headersBuilder = JoseHeader.from(headers)
-				.type(JOSEObjectType.JWT.getType());
-		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.from(claims)
-				.id(UUID.randomUUID().toString());
-		// @formatter:on
-
-		this.jwtCustomizer.accept(headersBuilder, claimsBuilder);
-
-		JWK jwk = selectJwk(headersBuilder);
+		JWK jwk = selectJwk(headers);
 		if (jwk == null) {
 			throw new JwtEncodingException(
 					String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
@@ -135,8 +109,15 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 					"The \"kid\" (key ID) from the selected JWK cannot be empty"));
 		}
 
-		headers = headersBuilder.keyId(jwk.getKeyID()).build();
-		claims = claimsBuilder.build();
+		// @formatter:off
+		headers = JoseHeader.from(headers)
+				.type(JOSEObjectType.JWT.getType())
+				.keyId(jwk.getKeyID())
+				.build();
+		claims = JwtClaimsSet.from(claims)
+				.id(UUID.randomUUID().toString())
+				.build();
+		// @formatter:on
 
 		JWSHeader jwsHeader = JWS_HEADER_CONVERTER.convert(headers);
 		JWTClaimsSet jwtClaimsSet = JWT_CLAIMS_SET_CONVERTER.convert(claims);
@@ -164,13 +145,9 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 		return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims());
 	}
 
-	private JWK selectJwk(JoseHeader.Builder headersBuilder) {
-		final AtomicReference<JWSAlgorithm> jwsAlgorithm = new AtomicReference<>();
-		headersBuilder.headers((h) -> {
-			JwsAlgorithm jwsAlg = (JwsAlgorithm) h.get(JoseHeaderNames.ALG);
-			jwsAlgorithm.set(JWSAlgorithm.parse(jwsAlg.getName()));
-		});
-		JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm.get());
+	private JWK selectJwk(JoseHeader headers) {
+		JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getJwsAlgorithm().getName());
+		JWSHeader jwsHeader = new JWSHeader(jwsAlgorithm);
 		JWKSelector jwkSelector = new JWKSelector(JWKMatcher.forJWSHeader(jwsHeader));
 
 		List<JWK> jwks;
@@ -184,7 +161,7 @@ public final class NimbusJwsEncoder implements JwtEncoder {
 
 		if (jwks.size() > 1) {
 			throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
-					"Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.get().getName() + "'"));
+					"Found multiple JWK signing keys for algorithm '" + jwsAlgorithm.getName() + "'"));
 		}
 
 		return !jwks.isEmpty() ? jwks.get(0) : null;

+ 6 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/OAuth2AuthorizationAttributeNames.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -56,4 +56,9 @@ public interface OAuth2AuthorizationAttributeNames {
 	 */
 	String ACCESS_TOKEN_ATTRIBUTES = OAuth2Authorization.class.getName().concat(".ACCESS_TOKEN_ATTRIBUTES");
 
+	/**
+	 * The name of the attribute used for the resource owner {@code Principal}.
+	 */
+	String PRINCIPAL = OAuth2Authorization.class.getName().concat(".PRINCIPAL");
+
 }

+ 123 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/JwtEncodingContextUtils.java

@@ -0,0 +1,123 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.authentication;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.Set;
+
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JoseHeader;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
+import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+
+/**
+ * @author Joe Grandja
+ * @since 0.1.0
+ */
+final class JwtEncodingContextUtils {
+
+	private JwtEncodingContextUtils() {
+	}
+
+	static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization) {
+		// @formatter:off
+		return accessTokenContext(registeredClient, authorization,
+				authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES));
+		// @formatter:on
+	}
+
+	static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization,
+			Set<String> authorizedScopes) {
+		// @formatter:off
+		return accessTokenContext(registeredClient, authorization.getPrincipalName(), authorizedScopes)
+				.authorization(authorization);
+		// @formatter:on
+	}
+
+	static JwtEncodingContext.Builder accessTokenContext(RegisteredClient registeredClient,
+			String principalName, Set<String> authorizedScopes) {
+
+		JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256);
+
+		String issuer = "http://auth-server:9000";        // TODO Allow configuration for issuer claim
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(registeredClient.getTokenSettings().accessTokenTimeToLive());
+
+		// @formatter:off
+		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder()
+				.issuer(issuer)
+				.subject(principalName)
+				.audience(Collections.singletonList(registeredClient.getClientId()))
+				.issuedAt(issuedAt)
+				.expiresAt(expiresAt)
+				.notBefore(issuedAt);
+		if (!CollectionUtils.isEmpty(authorizedScopes)) {
+			claimsBuilder.claim(OAuth2ParameterNames.SCOPE, authorizedScopes);
+		}
+		// @formatter:on
+
+		// @formatter:off
+		return JwtEncodingContext.with(headersBuilder, claimsBuilder)
+				.registeredClient(registeredClient)
+				.tokenType(TokenType.ACCESS_TOKEN);
+		// @formatter:on
+	}
+
+	static JwtEncodingContext.Builder idTokenContext(RegisteredClient registeredClient, OAuth2Authorization authorization) {
+		JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256);
+
+		String issuer = "http://auth-server:9000";        // TODO Allow configuration for issuer claim
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);		// TODO Allow configuration for id token time-to-live
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
+
+		// @formatter:off
+		JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder()
+				.issuer(issuer)
+				.subject(authorization.getPrincipalName())
+				.audience(Collections.singletonList(registeredClient.getClientId()))
+				.issuedAt(issuedAt)
+				.expiresAt(expiresAt)
+				.claim(IdTokenClaimNames.AZP, registeredClient.getClientId());
+		if (StringUtils.hasText(nonce)) {
+			claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce);
+		}
+		// TODO Add 'auth_time' claim
+		// @formatter:on
+
+		// @formatter:off
+		return JwtEncodingContext.with(headersBuilder, claimsBuilder)
+				.registeredClient(registeredClient)
+				.authorization(authorization)
+				.tokenType(new TokenType(OidcParameterNames.ID_TOKEN));
+		// @formatter:on
+	}
+
+}

+ 47 - 16
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,10 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -25,27 +29,27 @@ import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
 import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
 
 /**
@@ -58,12 +62,15 @@ import static org.springframework.security.oauth2.server.authorization.authentic
  * @see OAuth2AccessTokenAuthenticationToken
  * @see OAuth2AuthorizationService
  * @see JwtEncoder
+ * @see OAuth2TokenCustomizer
+ * @see JwtEncodingContext
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1">Section 4.1 Authorization Code Grant</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.1.3">Section 4.1.3 Access Token Request</a>
  */
 public class OAuth2AuthorizationCodeAuthenticationProvider implements AuthenticationProvider {
 	private final OAuth2AuthorizationService authorizationService;
 	private final JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
 
 	/**
 	 * Constructs an {@code OAuth2AuthorizationCodeAuthenticationProvider} using the provided parameters.
@@ -78,6 +85,11 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 		this.jwtEncoder = jwtEncoder;
 	}
 
+	public final void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
+		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
+		this.jwtCustomizer = jwtCustomizer;
+	}
+
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
 		OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication =
@@ -116,27 +128,46 @@ public class OAuth2AuthorizationCodeAuthenticationProvider implements Authentica
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
-		Set<String> authorizedScopes = authorization.getAttribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES);
-		Jwt jwt = OAuth2TokenIssuerUtil.issueJwtAccessToken(
-				this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(),
-				authorizedScopes, registeredClient.getTokenSettings().accessTokenTimeToLive());
-		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-				jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), authorizedScopes);
+		// @formatter:off
+		JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, authorization)
+				.principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.authorizationGrant(authorizationCodeAuthentication)
+				.build();
+		// @formatter:on
+		this.jwtCustomizer.customize(context);
 
+		JoseHeader headers = context.getHeaders().build();
+		JwtClaimsSet claims = context.getClaims().build();
+		Jwt jwt = this.jwtEncoder.encode(headers, claims);
+
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
 		OAuth2Tokens.Builder tokensBuilder = OAuth2Tokens.from(authorization.getTokens())
 				.accessToken(accessToken);
 
 		OAuth2RefreshToken refreshToken = null;
 		if (registeredClient.getAuthorizationGrantTypes().contains(AuthorizationGrantType.REFRESH_TOKEN)) {
-			refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(registeredClient.getTokenSettings().refreshTokenTimeToLive());
+			refreshToken = OAuth2RefreshTokenAuthenticationProvider.generateRefreshToken(
+					registeredClient.getTokenSettings().refreshTokenTimeToLive());
 			tokensBuilder.refreshToken(refreshToken);
 		}
 
 		OidcIdToken idToken = null;
 		if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) {
-			Jwt jwtIdToken = OAuth2TokenIssuerUtil.issueIdToken(
-					this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(),
-					(String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE));
+			// @formatter:off
+			context = JwtEncodingContextUtils.idTokenContext(registeredClient, authorization)
+					.principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
+					.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+					.authorizationGrant(authorizationCodeAuthentication)
+					.build();
+			// @formatter:on
+			this.jwtCustomizer.customize(context);
+
+			headers = context.getHeaders().build();
+			claims = context.getClaims().build();
+			Jwt jwtIdToken = this.jwtEncoder.encode(headers, claims);
+
 			idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
 					jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
 			tokensBuilder.token(idToken);

+ 33 - 8
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,10 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.util.LinkedHashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -23,37 +27,42 @@ import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 
-import java.util.LinkedHashSet;
-import java.util.Set;
-import java.util.stream.Collectors;
-
 import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
 
 /**
  * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Client Credentials Grant.
  *
  * @author Alexey Nesterov
+ * @author Joe Grandja
  * @since 0.0.1
  * @see OAuth2ClientCredentialsAuthenticationToken
  * @see OAuth2AccessTokenAuthenticationToken
  * @see OAuth2AuthorizationService
  * @see JwtEncoder
+ * @see OAuth2TokenCustomizer
+ * @see JwtEncodingContext
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4">Section 4.4 Client Credentials Grant</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-4.4.2">Section 4.4.2 Access Token Request</a>
  */
 public class OAuth2ClientCredentialsAuthenticationProvider implements AuthenticationProvider {
 	private final OAuth2AuthorizationService authorizationService;
 	private final JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
 
 	/**
 	 * Constructs an {@code OAuth2ClientCredentialsAuthenticationProvider} using the provided parameters.
@@ -69,6 +78,11 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
 		this.jwtEncoder = jwtEncoder;
 	}
 
+	public final void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
+		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
+		this.jwtCustomizer = jwtCustomizer;
+	}
+
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
 		OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication =
@@ -93,10 +107,21 @@ public class OAuth2ClientCredentialsAuthenticationProvider implements Authentica
 			scopes = new LinkedHashSet<>(clientCredentialsAuthentication.getScopes());
 		}
 
-		Jwt jwt = OAuth2TokenIssuerUtil
-			.issueJwtAccessToken(this.jwtEncoder, clientPrincipal.getName(), registeredClient.getClientId(), scopes, registeredClient.getTokenSettings().accessTokenTimeToLive());
+		// @formatter:off
+		JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, clientPrincipal.getName(), scopes)
+				.principal(clientPrincipal)
+				.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+				.authorizationGrant(clientCredentialsAuthentication)
+				.build();
+		// @formatter:on
+		this.jwtCustomizer.customize(context);
+
+		JoseHeader headers = context.getHeaders().build();
+		JwtClaimsSet claims = context.getClaims().build();
+		Jwt jwt = this.jwtEncoder.encode(headers, claims);
+
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-				jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes);
+				jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
 
 		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName(clientPrincipal.getName())

+ 45 - 8
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,47 +15,62 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Base64;
+import java.util.Set;
+
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
+import org.springframework.security.crypto.keygen.StringKeyGenerator;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 import org.springframework.util.Assert;
 
-import java.time.Instant;
-import java.util.Set;
-
 import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
 
 /**
  * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant.
  *
  * @author Alexey Nesterov
+ * @author Joe Grandja
  * @since 0.0.3
  * @see OAuth2RefreshTokenAuthenticationToken
  * @see OAuth2AccessTokenAuthenticationToken
  * @see OAuth2AuthorizationService
  * @see JwtEncoder
+ * @see OAuth2TokenCustomizer
+ * @see JwtEncodingContext
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.5">Section 1.5 Refresh Token Grant</a>
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
  */
 public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
+	private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 	private final OAuth2AuthorizationService authorizationService;
 	private final JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer = (context) -> {};
 
 	/**
 	 * Constructs an {@code OAuth2RefreshTokenAuthenticationProvider} using the provided parameters.
@@ -71,6 +86,11 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
 		this.jwtEncoder = jwtEncoder;
 	}
 
+	public final void setJwtCustomizer(OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer) {
+		Assert.notNull(jwtCustomizer, "jwtCustomizer cannot be null");
+		this.jwtCustomizer = jwtCustomizer;
+	}
+
 	@Override
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
 		OAuth2RefreshTokenAuthenticationToken refreshTokenAuthentication =
@@ -121,15 +141,26 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
 			throw new OAuth2AuthenticationException(new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT));
 		}
 
-		Jwt jwt = OAuth2TokenIssuerUtil
-			.issueJwtAccessToken(this.jwtEncoder, authorization.getPrincipalName(), registeredClient.getClientId(), scopes, registeredClient.getTokenSettings().accessTokenTimeToLive());
+		// @formatter:off
+		JwtEncodingContext context = JwtEncodingContextUtils.accessTokenContext(registeredClient, authorization, scopes)
+				.principal(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
+				.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+				.authorizationGrant(refreshTokenAuthentication)
+				.build();
+		// @formatter:on
+		this.jwtCustomizer.customize(context);
+
+		JoseHeader headers = context.getHeaders().build();
+		JwtClaimsSet claims = context.getClaims().build();
+		Jwt jwt = this.jwtEncoder.encode(headers, claims);
+
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
-			jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), scopes);
+			jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaim(OAuth2ParameterNames.SCOPE));
 
 		TokenSettings tokenSettings = registeredClient.getTokenSettings();
 
 		if (!tokenSettings.reuseRefreshTokens()) {
-			refreshToken = OAuth2TokenIssuerUtil.issueRefreshToken(tokenSettings.refreshTokenTimeToLive());
+			refreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive());
 		}
 
 		authorization = OAuth2Authorization.from(authorization)
@@ -146,4 +177,10 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
 	public boolean supports(Class<?> authentication) {
 		return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
 	}
+
+	static OAuth2RefreshToken generateRefreshToken(Duration tokenTimeToLive) {
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(tokenTimeToLive);
+		return new OAuth2RefreshToken2(TOKEN_GENERATOR.generateKey(), issuedAt, expiresAt);
+	}
 }

+ 0 - 97
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2TokenIssuerUtil.java

@@ -1,97 +0,0 @@
-/*
- * Copyright 2020-2021 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.springframework.security.oauth2.server.authorization.authentication;
-
-import org.springframework.security.crypto.keygen.Base64StringKeyGenerator;
-import org.springframework.security.crypto.keygen.StringKeyGenerator;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken;
-import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
-import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
-import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
-import org.springframework.security.oauth2.jwt.JoseHeader;
-import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
-import org.springframework.security.oauth2.jwt.Jwt;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
-import org.springframework.security.oauth2.jwt.JwtEncoder;
-import org.springframework.util.StringUtils;
-
-import java.time.Duration;
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Base64;
-import java.util.Collections;
-import java.util.Set;
-
-/**
- * @author Alexey Nesterov
- * @since 0.0.3
- */
-class OAuth2TokenIssuerUtil {
-
-	private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
-
-	static Jwt issueJwtAccessToken(JwtEncoder jwtEncoder, String subject, String audience, Set<String> scopes, Duration tokenTimeToLive) {
-		JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
-
-		String issuer = "http://auth-server:9000";		// TODO Allow configuration for issuer claim
-		Instant issuedAt = Instant.now();
-		Instant expiresAt = issuedAt.plus(tokenTimeToLive);
-
-		JwtClaimsSet jwtClaimsSet = JwtClaimsSet.builder()
-											.issuer(issuer)
-											.subject(subject)
-											.audience(Collections.singletonList(audience))
-											.issuedAt(issuedAt)
-											.expiresAt(expiresAt)
-											.notBefore(issuedAt)
-											.claim(OAuth2ParameterNames.SCOPE, scopes)
-											.build();
-
-		return jwtEncoder.encode(joseHeader, jwtClaimsSet);
-	}
-
-	static Jwt issueIdToken(JwtEncoder jwtEncoder, String subject, String audience, String nonce) {
-		JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
-
-		String issuer = "http://auth-server:9000";		// TODO Allow configuration for issuer claim
-		Instant issuedAt = Instant.now();
-		Instant expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);		// TODO Allow configuration for id token time-to-live
-
-		JwtClaimsSet.Builder builder = JwtClaimsSet.builder()
-				.issuer(issuer)
-				.subject(subject)
-				.audience(Collections.singletonList(audience))
-				.issuedAt(issuedAt)
-				.expiresAt(expiresAt)
-				.claim(IdTokenClaimNames.AZP, audience);
-		if (StringUtils.hasText(nonce)) {
-			builder.claim(IdTokenClaimNames.NONCE, nonce);
-		}
-
-		// TODO Add 'auth_time' claim
-
-		JwtClaimsSet jwtClaimsSet = builder.build();
-
-		return jwtEncoder.encode(joseHeader, jwtClaimsSet);
-	}
-
-	static OAuth2RefreshToken issueRefreshToken(Duration tokenTimeToLive) {
-		Instant issuedAt = Instant.now();
-		Instant expiresAt = issuedAt.plus(tokenTimeToLive);
-
-		return new OAuth2RefreshToken2(TOKEN_GENERATOR.generateKey(), issuedAt, expiresAt);
-	}
-}

+ 87 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContext.java

@@ -0,0 +1,87 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import java.util.Map;
+import java.util.function.Consumer;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.context.Context;
+import org.springframework.security.oauth2.jwt.JoseHeader;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
+import org.springframework.util.Assert;
+
+/**
+ * @author Joe Grandja
+ * @since 0.1.0
+ * @see OAuth2TokenContext
+ * @see JoseHeader.Builder
+ * @see JwtClaimsSet.Builder
+ */
+public final class JwtEncodingContext implements OAuth2TokenContext {
+	private final Context context;
+
+	private JwtEncodingContext(Map<Object, Object> context) {
+		this.context = Context.of(context);
+	}
+
+	@Nullable
+	@Override
+	public <V> V get(Object key) {
+		return this.context.get(key);
+	}
+
+	@Override
+	public boolean hasKey(Object key) {
+		return this.context.hasKey(key);
+	}
+
+	public JoseHeader.Builder getHeaders() {
+		return get(JoseHeader.Builder.class);
+	}
+
+	public JwtClaimsSet.Builder getClaims() {
+		return get(JwtClaimsSet.Builder.class);
+	}
+
+	public static Builder with(JoseHeader.Builder headersBuilder, JwtClaimsSet.Builder claimsBuilder) {
+		return new Builder(headersBuilder, claimsBuilder);
+	}
+
+	public static final class Builder extends AbstractBuilder<JwtEncodingContext, Builder> {
+
+		private Builder(JoseHeader.Builder headersBuilder, JwtClaimsSet.Builder claimsBuilder) {
+			Assert.notNull(headersBuilder, "headersBuilder cannot be null");
+			Assert.notNull(claimsBuilder, "claimsBuilder cannot be null");
+			put(JoseHeader.Builder.class, headersBuilder);
+			put(JwtClaimsSet.Builder.class, claimsBuilder);
+		}
+
+		public Builder headers(Consumer<JoseHeader.Builder> headersConsumer) {
+			headersConsumer.accept(get(JoseHeader.Builder.class));
+			return this;
+		}
+
+		public Builder claims(Consumer<JwtClaimsSet.Builder> claimsConsumer) {
+			claimsConsumer.accept(get(JwtClaimsSet.Builder.class));
+			return this;
+		}
+
+		public JwtEncodingContext build() {
+			return new JwtEncodingContext(this.context);
+		}
+	}
+}

+ 119 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenContext.java

@@ -0,0 +1,119 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Consumer;
+
+import org.springframework.lang.Nullable;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.context.Context;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.util.Assert;
+
+/**
+ * @author Joe Grandja
+ * @since 0.1.0
+ * @see Context
+ */
+public interface OAuth2TokenContext extends Context {
+
+	default RegisteredClient getRegisteredClient() {
+		return get(RegisteredClient.class);
+	}
+
+	default <T extends Authentication> T getPrincipal() {
+		return get(AbstractBuilder.PRINCIPAL_AUTHENTICATION_KEY);
+	}
+
+	@Nullable
+	default OAuth2Authorization getAuthorization() {
+		return get(OAuth2Authorization.class);
+	}
+
+	default TokenType getTokenType() {
+		return get(TokenType.class);
+	}
+
+	default AuthorizationGrantType getAuthorizationGrantType() {
+		return get(AuthorizationGrantType.class);
+	}
+
+	default <T extends Authentication> T getAuthorizationGrant() {
+		return get(AbstractBuilder.AUTHORIZATION_GRANT_AUTHENTICATION_KEY);
+	}
+
+	abstract class AbstractBuilder<T extends OAuth2TokenContext, B extends AbstractBuilder<T, B>> {
+		private static final String PRINCIPAL_AUTHENTICATION_KEY =
+				Authentication.class.getName().concat(".PRINCIPAL");
+		private static final String AUTHORIZATION_GRANT_AUTHENTICATION_KEY =
+				Authentication.class.getName().concat(".AUTHORIZATION_GRANT");
+		protected final Map<Object, Object> context = new HashMap<>();
+
+		public B registeredClient(RegisteredClient registeredClient) {
+			return put(RegisteredClient.class, registeredClient);
+		}
+
+		public B principal(Authentication principal) {
+			return put(PRINCIPAL_AUTHENTICATION_KEY, principal);
+		}
+
+		public B authorization(OAuth2Authorization authorization) {
+			return put(OAuth2Authorization.class, authorization);
+		}
+
+		public B tokenType(TokenType tokenType) {
+			return put(TokenType.class, tokenType);
+		}
+
+		public B authorizationGrantType(AuthorizationGrantType authorizationGrantType) {
+			return put(AuthorizationGrantType.class, authorizationGrantType);
+		}
+
+		public B authorizationGrant(Authentication authorizationGrant) {
+			return put(AUTHORIZATION_GRANT_AUTHENTICATION_KEY, authorizationGrant);
+		}
+
+		public B put(Object key, Object value) {
+			Assert.notNull(key, "key cannot be null");
+			Assert.notNull(value, "value cannot be null");
+			this.context.put(key, value);
+			return getThis();
+		}
+
+		public B context(Consumer<Map<Object, Object>> contextConsumer) {
+			contextConsumer.accept(this.context);
+			return getThis();
+		}
+
+		@SuppressWarnings("unchecked")
+		protected <V> V get(Object key) {
+			return (V) this.context.get(key);
+		}
+
+		@SuppressWarnings("unchecked")
+		protected B getThis() {
+			return (B) this;
+		}
+
+		public abstract T build();
+
+	}
+}

+ 28 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/OAuth2TokenCustomizer.java

@@ -0,0 +1,28 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+/**
+ * @author Joe Grandja
+ * @since 0.1.0
+ * @see OAuth2TokenContext
+ */
+@FunctionalInterface
+public interface OAuth2TokenCustomizer<C extends OAuth2TokenContext> {
+
+	void customize(C context);
+
+}

+ 18 - 16
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,22 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
@@ -53,21 +69,6 @@ import org.springframework.util.StringUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
 import org.springframework.web.util.UriComponentsBuilder;
 
-import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.io.IOException;
-import java.nio.charset.StandardCharsets;
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Arrays;
-import java.util.Base64;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Set;
-
 /**
  * A {@code Filter} for the OAuth 2.0 Authorization Code Grant,
  * which handles the processing of the OAuth 2.0 Authorization Request.
@@ -193,6 +194,7 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		OAuth2AuthorizationRequest authorizationRequest = authorizationRequestContext.buildAuthorizationRequest();
 		OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.principalName(principal.getName())
+				.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL, principal)
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest);
 
 		if (registeredClient.getClientSettings().requireUserConsent()) {

+ 55 - 12
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2AuthorizationCodeGrantTests.java

@@ -18,7 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se
 import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
-import java.util.function.BiConsumer;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
@@ -33,19 +35,29 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Import;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.mock.http.client.MockClientHttpResponse;
+import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
 import org.springframework.security.oauth2.jose.TestJwks;
-import org.springframework.security.oauth2.jwt.JoseHeader;
-import org.springframework.security.oauth2.jwt.JwtClaimsSet;
+import org.springframework.security.oauth2.jose.TestKeys;
+import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
 import org.springframework.security.oauth2.jwt.NimbusJwsEncoder;
+import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.TokenType;
@@ -53,7 +65,9 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
 import org.springframework.test.web.servlet.MockMvc;
@@ -90,13 +104,16 @@ public class OAuth2AuthorizationCodeGrantTests {
 	// https://tools.ietf.org/html/rfc7636#appendix-B
 	private static final String S256_CODE_VERIFIER = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk";
 	private static final String S256_CODE_CHALLENGE = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM";
+	private static final String AUTHORITIES_CLAIM = "authorities";
 
 	private static RegisteredClientRepository registeredClientRepository;
 	private static OAuth2AuthorizationService authorizationService;
 	private static JWKSource<SecurityContext> jwkSource;
 	private static NimbusJwsEncoder jwtEncoder;
-	private static BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer;
+	private static NimbusJwtDecoder jwtDecoder;
 	private static ProviderSettings providerSettings;
+	private static HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
+			new OAuth2AccessTokenResponseHttpMessageConverter();
 
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
@@ -111,8 +128,7 @@ public class OAuth2AuthorizationCodeGrantTests {
 		JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
 		jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
 		jwtEncoder = new NimbusJwsEncoder(jwkSource);
-		jwtCustomizer = mock(BiConsumer.class);
-		jwtEncoder.setJwtCustomizer(jwtCustomizer);
+		jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build();
 		providerSettings = new ProviderSettings()
 				.authorizationEndpoint("/test/authorize")
 				.tokenEndpoint("/test/token");
@@ -186,8 +202,17 @@ public class OAuth2AuthorizationCodeGrantTests {
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
-		assertTokenRequestReturnsAccessTokenResponse(
+		OAuth2AccessTokenResponse accessTokenResponse = assertTokenRequestReturnsAccessTokenResponse(
 				registeredClient, authorization, OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI);
+
+		// Assert user authorities was propagated as claim in JWT
+		Jwt jwt = jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue());
+		List<String> authoritiesClaim = jwt.getClaim(AUTHORITIES_CLAIM);
+		Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL);
+		Set<String> userAuthorities = principal.getAuthorities().stream()
+				.map(GrantedAuthority::getAuthority)
+				.collect(Collectors.toSet());
+		assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities);
 	}
 
 	@Test
@@ -208,10 +233,10 @@ public class OAuth2AuthorizationCodeGrantTests {
 				registeredClient, authorization, providerSettings.tokenEndpoint());
 	}
 
-	private void assertTokenRequestReturnsAccessTokenResponse(RegisteredClient registeredClient,
+	private OAuth2AccessTokenResponse assertTokenRequestReturnsAccessTokenResponse(RegisteredClient registeredClient,
 			OAuth2Authorization authorization, String tokenEndpointUri) throws Exception {
 
-		this.mvc.perform(post(tokenEndpointUri)
+		MvcResult mvcResult = this.mvc.perform(post(tokenEndpointUri)
 				.params(getTokenRequestParameters(registeredClient, authorization))
 				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
 						registeredClient.getClientId(), registeredClient.getClientSecret())))
@@ -222,13 +247,19 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.andExpect(jsonPath("$.token_type").isNotEmpty())
 				.andExpect(jsonPath("$.expires_in").isNotEmpty())
 				.andExpect(jsonPath("$.refresh_token").isNotEmpty())
-				.andExpect(jsonPath("$.scope").isNotEmpty());
+				.andExpect(jsonPath("$.scope").isNotEmpty())
+				.andReturn();
 
 		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).findByToken(
 				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE));
 		verify(authorizationService).save(any());
+
+		MockHttpServletResponse servletResponse = mvcResult.getResponse();
+		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
+				servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
+		return accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
 	}
 
 	@Test
@@ -295,8 +326,6 @@ public class OAuth2AuthorizationCodeGrantTests {
 				.params(getTokenRequestParameters(registeredClient, authorization))
 				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
 						registeredClient.getClientId(), registeredClient.getClientSecret())));
-
-		verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
 	}
 
 	private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
@@ -345,6 +374,20 @@ public class OAuth2AuthorizationCodeGrantTests {
 		JWKSource<SecurityContext> jwkSource() {
 			return jwkSource;
 		}
+
+		@Bean
+		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
+			return context -> {
+				if (AuthorizationGrantType.AUTHORIZATION_CODE.equals(context.getAuthorizationGrantType()) &&
+						TokenType.ACCESS_TOKEN.equals(context.getTokenType())) {
+					Authentication principal = context.getPrincipal();
+					Set<String> authorities = principal.getAuthorities().stream()
+							.map(GrantedAuthority::getAuthority)
+							.collect(Collectors.toSet());
+					context.getClaims().claim(AUTHORITIES_CLAIM, authorities);
+				}
+			};
+		}
 	}
 
 	@EnableWebSecurity

+ 13 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2ClientCredentialsGrantTests.java

@@ -41,6 +41,8 @@ import org.springframework.security.oauth2.server.authorization.OAuth2Authorizat
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
@@ -65,6 +67,7 @@ public class OAuth2ClientCredentialsGrantTests {
 	private static RegisteredClientRepository registeredClientRepository;
 	private static OAuth2AuthorizationService authorizationService;
 	private static JWKSource<SecurityContext> jwkSource;
+	private static OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
 
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
@@ -78,10 +81,13 @@ public class OAuth2ClientCredentialsGrantTests {
 		authorizationService = mock(OAuth2AuthorizationService.class);
 		JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
 		jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
+		jwtCustomizer = mock(OAuth2TokenCustomizer.class);
 	}
 
+	@SuppressWarnings("unchecked")
 	@Before
 	public void setup() {
+		reset(jwtCustomizer);
 		reset(registeredClientRepository);
 		reset(authorizationService);
 	}
@@ -115,6 +121,7 @@ public class OAuth2ClientCredentialsGrantTests {
 				.andExpect(jsonPath("$.access_token").isNotEmpty())
 				.andExpect(jsonPath("$.scope").value("scope1 scope2"));
 
+		verify(jwtCustomizer).customize(any());
 		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).save(any());
 	}
@@ -136,6 +143,7 @@ public class OAuth2ClientCredentialsGrantTests {
 				.andExpect(jsonPath("$.access_token").isNotEmpty())
 				.andExpect(jsonPath("$.scope").value("scope1 scope2"));
 
+		verify(jwtCustomizer).customize(any());
 		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).save(any());
 	}
@@ -166,5 +174,10 @@ public class OAuth2ClientCredentialsGrantTests {
 		JWKSource<SecurityContext> jwkSource() {
 			return jwkSource;
 		}
+
+		@Bean
+		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
+			return jwtCustomizer;
+		}
 	}
 }

+ 54 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OAuth2RefreshTokenGrantTests.java

@@ -18,6 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se
 import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
@@ -31,24 +34,40 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Import;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.mock.http.client.MockClientHttpResponse;
+import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
 import org.springframework.security.oauth2.jose.TestJwks;
+import org.springframework.security.oauth2.jose.TestKeys;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
 import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.hamcrest.CoreMatchers.containsString;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
@@ -68,9 +87,13 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  * @since 0.0.3
  */
 public class OAuth2RefreshTokenGrantTests {
+	private static final String AUTHORITIES_CLAIM = "authorities";
 	private static RegisteredClientRepository registeredClientRepository;
 	private static OAuth2AuthorizationService authorizationService;
 	private static JWKSource<SecurityContext> jwkSource;
+	private static NimbusJwtDecoder jwtDecoder;
+	private static HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
+			new OAuth2AccessTokenResponseHttpMessageConverter();
 
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
@@ -84,6 +107,7 @@ public class OAuth2RefreshTokenGrantTests {
 		authorizationService = mock(OAuth2AuthorizationService.class);
 		JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
 		jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
+		jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build();
 	}
 
 	@Before
@@ -106,7 +130,7 @@ public class OAuth2RefreshTokenGrantTests {
 				eq(TokenType.REFRESH_TOKEN)))
 				.thenReturn(authorization);
 
-		this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
+		MvcResult mvcResult = this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
 				.params(getRefreshTokenRequestParameters(authorization))
 				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
 						registeredClient.getClientId(), registeredClient.getClientSecret())))
@@ -117,7 +141,8 @@ public class OAuth2RefreshTokenGrantTests {
 				.andExpect(jsonPath("$.token_type").isNotEmpty())
 				.andExpect(jsonPath("$.expires_in").isNotEmpty())
 				.andExpect(jsonPath("$.refresh_token").isNotEmpty())
-				.andExpect(jsonPath("$.scope").isNotEmpty());
+				.andExpect(jsonPath("$.scope").isNotEmpty())
+				.andReturn();
 
 		verify(registeredClientRepository).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).findByToken(
@@ -125,6 +150,20 @@ public class OAuth2RefreshTokenGrantTests {
 				eq(TokenType.REFRESH_TOKEN));
 		verify(authorizationService).save(any());
 
+		MockHttpServletResponse servletResponse = mvcResult.getResponse();
+		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
+				servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
+		OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(
+				OAuth2AccessTokenResponse.class, httpResponse);
+
+		// Assert user authorities was propagated as claim in JWT
+		Jwt jwt = jwtDecoder.decode(accessTokenResponse.getAccessToken().getTokenValue());
+		List<String> authoritiesClaim = jwt.getClaim(AUTHORITIES_CLAIM);
+		Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL);
+		Set<String> userAuthorities = principal.getAuthorities().stream()
+				.map(GrantedAuthority::getAuthority)
+				.collect(Collectors.toSet());
+		assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities);
 	}
 
 	private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
@@ -160,5 +199,18 @@ public class OAuth2RefreshTokenGrantTests {
 		JWKSource<SecurityContext> jwkSource() {
 			return jwkSource;
 		}
+
+		@Bean
+		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
+			return context -> {
+				if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) {
+					Authentication principal = context.getPrincipal();
+					Set<String> authorities = principal.getAuthorities().stream()
+							.map(GrantedAuthority::getAuthority)
+							.collect(Collectors.toSet());
+					context.getClaims().claim(AUTHORITIES_CLAIM, authorities);
+				}
+			};
+		}
 	}
 }

+ 57 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/server/authorization/OidcTests.java

@@ -18,6 +18,9 @@ package org.springframework.security.config.annotation.web.configurers.oauth2.se
 import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
 import java.util.Base64;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 import com.nimbusds.jose.jwk.JWKSet;
 import com.nimbusds.jose.jwk.source.JWKSource;
@@ -32,15 +35,28 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Import;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.mock.http.client.MockClientHttpResponse;
+import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.OAuth2AuthorizationServerConfiguration;
 import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponseType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
 import org.springframework.security.oauth2.jose.TestJwks;
+import org.springframework.security.oauth2.jose.TestKeys;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
@@ -48,7 +64,9 @@ import org.springframework.security.oauth2.server.authorization.client.Registere
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
 import org.springframework.security.oauth2.server.authorization.oidc.web.OidcProviderConfigurationEndpointFilter;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2AuthorizationEndpointFilter;
 import org.springframework.security.oauth2.server.authorization.web.OAuth2TokenEndpointFilter;
 import org.springframework.test.web.servlet.MockMvc;
@@ -80,10 +98,14 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
  * @author Daniel Garnier-Moiroux
  */
 public class OidcTests {
-	private static final String issuerUrl = "https://example.com/issuer1";
+	private static final String ISSUER_URL = "https://example.com/issuer1";
+	private static final String AUTHORITIES_CLAIM = "authorities";
 	private static RegisteredClientRepository registeredClientRepository;
 	private static OAuth2AuthorizationService authorizationService;
 	private static JWKSource<SecurityContext> jwkSource;
+	private static NimbusJwtDecoder jwtDecoder;
+	private static HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
+			new OAuth2AccessTokenResponseHttpMessageConverter();
 
 	@Rule
 	public final SpringTestRule spring = new SpringTestRule();
@@ -97,6 +119,7 @@ public class OidcTests {
 		authorizationService = mock(OAuth2AuthorizationService.class);
 		JWKSet jwkSet = new JWKSet(TestJwks.DEFAULT_RSA_JWK);
 		jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(jwkSet);
+		jwtDecoder = NimbusJwtDecoder.withPublicKey(TestKeys.DEFAULT_PUBLIC_KEY).build();
 	}
 
 	@Before
@@ -111,7 +134,7 @@ public class OidcTests {
 
 		this.mvc.perform(get(OidcProviderConfigurationEndpointFilter.DEFAULT_OIDC_PROVIDER_CONFIGURATION_ENDPOINT_URI))
 				.andExpect(status().is2xxSuccessful())
-				.andExpect(jsonPath("issuer").value(issuerUrl));
+				.andExpect(jsonPath("issuer").value(ISSUER_URL));
 	}
 
 	@Test
@@ -148,7 +171,7 @@ public class OidcTests {
 
 		MvcResult mvcResult = this.mvc.perform(get(OAuth2AuthorizationEndpointFilter.DEFAULT_AUTHORIZATION_ENDPOINT_URI)
 				.params(getAuthorizationRequestParameters(registeredClient))
-				.with(user("user")))
+				.with(user("user").roles("A", "B")))
 				.andExpect(status().is3xxRedirection())
 				.andReturn();
 		assertThat(mvcResult.getResponse().getRedirectedUrl()).matches("https://example.com\\?code=.{15,}&state=state");
@@ -164,7 +187,7 @@ public class OidcTests {
 				eq(TokenType.AUTHORIZATION_CODE)))
 				.thenReturn(authorization);
 
-		this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
+		mvcResult = this.mvc.perform(post(OAuth2TokenEndpointFilter.DEFAULT_TOKEN_ENDPOINT_URI)
 				.params(getTokenRequestParameters(registeredClient, authorization))
 				.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
 						registeredClient.getClientId(), registeredClient.getClientSecret())))
@@ -176,13 +199,28 @@ public class OidcTests {
 				.andExpect(jsonPath("$.expires_in").isNotEmpty())
 				.andExpect(jsonPath("$.refresh_token").isNotEmpty())
 				.andExpect(jsonPath("$.scope").isNotEmpty())
-				.andExpect(jsonPath("$.id_token").isNotEmpty());
+				.andExpect(jsonPath("$.id_token").isNotEmpty())
+				.andReturn();
 
 		verify(registeredClientRepository, times(2)).findByClientId(eq(registeredClient.getClientId()));
 		verify(authorizationService).findByToken(
 				eq(authorization.getTokens().getToken(OAuth2AuthorizationCode.class).getTokenValue()),
 				eq(TokenType.AUTHORIZATION_CODE));
 		verify(authorizationService, times(2)).save(any());
+
+		MockHttpServletResponse servletResponse = mvcResult.getResponse();
+		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
+				servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
+		OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
+
+		// Assert user authorities was propagated as claim in ID Token
+		Jwt idToken = jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
+		List<String> authoritiesClaim = idToken.getClaim(AUTHORITIES_CLAIM);
+		Authentication principal = authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL);
+		Set<String> userAuthorities = principal.getAuthorities().stream()
+				.map(GrantedAuthority::getAuthority)
+				.collect(Collectors.toSet());
+		assertThat(authoritiesClaim).containsExactlyInAnyOrderElementsOf(userAuthorities);
 	}
 
 	private static MultiValueMap<String, String> getAuthorizationRequestParameters(RegisteredClient registeredClient) {
@@ -231,6 +269,19 @@ public class OidcTests {
 		JWKSource<SecurityContext> jwkSource() {
 			return jwkSource;
 		}
+
+		@Bean
+		OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer() {
+			return context -> {
+				if (context.getTokenType().getValue().equals(OidcParameterNames.ID_TOKEN)) {
+					Authentication principal = context.getPrincipal();
+					Set<String> authorities = principal.getAuthorities().stream()
+							.map(GrantedAuthority::getAuthority)
+							.collect(Collectors.toSet());
+					context.getClaims().claim(AUTHORITIES_CLAIM, authorities);
+				}
+			};
+		}
 	}
 
 	@EnableWebSecurity
@@ -239,7 +290,7 @@ public class OidcTests {
 
 		@Bean
 		ProviderSettings providerSettings() {
-			return new ProviderSettings().issuer(issuerUrl);
+			return new ProviderSettings().issuer(ISSUER_URL);
 		}
 	}
 

+ 0 - 24
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/jwt/NimbusJwsEncoderTests.java

@@ -21,7 +21,6 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.function.BiConsumer;
 
 import com.nimbusds.jose.KeySourceException;
 import com.nimbusds.jose.jwk.ECKey;
@@ -49,7 +48,6 @@ import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.willAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.verify;
 
 /**
  * Tests for {@link NimbusJwsEncoder}.
@@ -77,12 +75,6 @@ public class NimbusJwsEncoderTests {
 				.withMessage("jwkSource cannot be null");
 	}
 
-	@Test
-	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.setJwtCustomizer(null))
-				.withMessage("jwtCustomizer cannot be null");
-	}
-
 	@Test
 	public void encodeWhenHeadersNullThenThrowIllegalArgumentException() {
 		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
@@ -99,22 +91,6 @@ public class NimbusJwsEncoderTests {
 				.withMessage("claims cannot be null");
 	}
 
-	@Test
-	public void encodeWhenCustomizerSetThenCalled() {
-		RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
-		this.jwkList.add(rsaJwk);
-
-		BiConsumer<JoseHeader.Builder, JwtClaimsSet.Builder> jwtCustomizer = mock(BiConsumer.class);
-		this.jwsEncoder.setJwtCustomizer(jwtCustomizer);
-
-		JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
-		JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
-
-		this.jwsEncoder.encode(joseHeader, jwtClaimsSet);
-
-		verify(jwtCustomizer).accept(any(JoseHeader.Builder.class), any(JwtClaimsSet.Builder.class));
-	}
-
 	@Test
 	public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception {
 		this.jwkSource = mock(JWKSource.class);

+ 9 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,12 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.Map;
+
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
@@ -24,11 +30,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Collections;
-import java.util.Map;
-
 /**
  * @author Joe Grandja
  * @author Daniel Garnier-Moiroux
@@ -63,6 +64,8 @@ public class TestOAuth2Authorizations {
 				.principalName("principal")
 				.tokens(OAuth2Tokens.builder().token(authorizationCode).accessToken(accessToken).refreshToken(refreshToken).build())
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST, authorizationRequest)
+				.attribute(OAuth2AuthorizationAttributeNames.PRINCIPAL,
+						new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"))
 				.attribute(OAuth2AuthorizationAttributeNames.AUTHORIZED_SCOPES, authorizationRequest.getScopes());
 	}
 }

+ 55 - 6
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -15,10 +15,17 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.time.Duration;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Set;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -28,11 +35,12 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
-import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -41,14 +49,10 @@ import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
-import java.time.Duration;
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Set;
-
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.assertj.core.api.Assertions.entry;
@@ -69,6 +73,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 	private static final String AUTHORIZATION_CODE = "code";
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
 	private OAuth2AuthorizationCodeAuthenticationProvider authenticationProvider;
 
 	@Before
@@ -77,6 +82,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		this.jwtEncoder = mock(JwtEncoder.class);
 		this.authenticationProvider = new OAuth2AuthorizationCodeAuthenticationProvider(
 				this.authorizationService, this.jwtEncoder);
+		this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
+		this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
 	}
 
 	@Test
@@ -93,6 +100,13 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 				.hasMessage("jwtEncoder cannot be null");
 	}
 
+	@Test
+	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtCustomizer cannot be null");
+	}
+
 	@Test
 	public void supportsWhenTypeOAuth2AuthorizationCodeAuthenticationTokenThenReturnTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2AuthorizationCodeAuthenticationToken.class)).isTrue();
@@ -225,6 +239,18 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
 
+		ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+		verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture());
+		JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue();
+		assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(jwtEncodingContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
+		assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
+		assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(jwtEncodingContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(jwtEncodingContext.getHeaders()).isNotNull();
+		assertThat(jwtEncodingContext.getClaims()).isNotNull();
+
 		ArgumentCaptor<JwtClaimsSet> jwtClaimsSetCaptor = ArgumentCaptor.forClass(JwtClaimsSet.class);
 		verify(this.jwtEncoder).encode(any(), jwtClaimsSetCaptor.capture());
 		JwtClaimsSet jwtClaimsSet = jwtClaimsSetCaptor.getValue();
@@ -264,6 +290,29 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
 
+		ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+		verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture());
+		// Access Token context
+		JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0);
+		assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(accessTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
+		assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(accessTokenContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
+		assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(accessTokenContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(accessTokenContext.getHeaders()).isNotNull();
+		assertThat(accessTokenContext.getClaims()).isNotNull();
+		// ID Token context
+		JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
+		assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
+		assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
+		assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(idTokenContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(idTokenContext.getHeaders()).isNotNull();
+		assertThat(idTokenContext.getClaims()).isNotNull();
+
 		verify(this.jwtEncoder, times(2)).encode(any(), any());		// Access token and ID Token
 
 		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);

+ 38 - 9
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2ClientCredentialsAuthenticationProviderTests.java

@@ -15,27 +15,33 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.Set;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
-import org.springframework.security.oauth2.jwt.JoseHeaderNames;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
-
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.Collections;
-import java.util.Set;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -53,6 +59,7 @@ import static org.mockito.Mockito.when;
 public class OAuth2ClientCredentialsAuthenticationProviderTests {
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
 	private OAuth2ClientCredentialsAuthenticationProvider authenticationProvider;
 
 	@Before
@@ -61,6 +68,8 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		this.jwtEncoder = mock(JwtEncoder.class);
 		this.authenticationProvider = new OAuth2ClientCredentialsAuthenticationProvider(
 				this.authorizationService, this.jwtEncoder);
+		this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
+		this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
 	}
 
 	@Test
@@ -77,6 +86,13 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 				.hasMessage("jwtEncoder cannot be null");
 	}
 
+	@Test
+	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtCustomizer cannot be null");
+	}
+
 	@Test
 	public void supportsWhenSupportedAuthenticationThenTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2ClientCredentialsAuthenticationToken.class)).isTrue();
@@ -152,7 +168,7 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		OAuth2ClientCredentialsAuthenticationToken authentication =
 				new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal, requestedScope);
 
-		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(requestedScope));
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
@@ -165,11 +181,23 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
 		OAuth2ClientCredentialsAuthenticationToken authentication = new OAuth2ClientCredentialsAuthenticationToken(clientPrincipal);
 
-		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt());
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(registeredClient.getScopes()));
 
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
 
+		ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+		verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture());
+		JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue();
+		assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(jwtEncodingContext.<Authentication>getPrincipal()).isEqualTo(clientPrincipal);
+		assertThat(jwtEncodingContext.getAuthorization()).isNull();
+		assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
+		assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
+		assertThat(jwtEncodingContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(jwtEncodingContext.getHeaders()).isNotNull();
+		assertThat(jwtEncodingContext.getClaims()).isNotNull();
+
 		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
 		verify(this.authorizationService).save(authorizationCaptor.capture());
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
@@ -182,13 +210,14 @@ public class OAuth2ClientCredentialsAuthenticationProviderTests {
 		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(authorization.getTokens().getAccessToken());
 	}
 
-	private static Jwt createJwt() {
+	private static Jwt createJwt(Set<String> scope) {
 		Instant issuedAt = Instant.now();
 		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
 		return Jwt.withTokenValue("token")
 				.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
 				.issuedAt(issuedAt)
 				.expiresAt(expiresAt)
+				.claim(OAuth2ParameterNames.SCOPE, scope)
 				.build();
 	}
 }

+ 47 - 12
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -15,20 +15,30 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
-import org.springframework.security.oauth2.jwt.JoseHeaderNames;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
+import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.token.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
@@ -36,14 +46,10 @@ import org.springframework.security.oauth2.server.authorization.TestOAuth2Author
 import org.springframework.security.oauth2.server.authorization.TokenType;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenMetadata;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2Tokens;
 
-import java.time.Instant;
-import java.time.temporal.ChronoUnit;
-import java.util.HashSet;
-import java.util.Set;
-
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
 import static org.mockito.ArgumentMatchers.any;
@@ -56,25 +62,24 @@ import static org.mockito.Mockito.when;
  * Tests for {@link OAuth2RefreshTokenAuthenticationProvider}.
  *
  * @author Alexey Nesterov
+ * @author Joe Grandja
  * @since 0.0.3
  */
 public class OAuth2RefreshTokenAuthenticationProviderTests {
 	private OAuth2AuthorizationService authorizationService;
 	private JwtEncoder jwtEncoder;
+	private OAuth2TokenCustomizer<JwtEncodingContext> jwtCustomizer;
 	private OAuth2RefreshTokenAuthenticationProvider authenticationProvider;
 
 	@Before
 	public void setUp() {
 		this.authorizationService = mock(OAuth2AuthorizationService.class);
 		this.jwtEncoder = mock(JwtEncoder.class);
-		Jwt jwt = Jwt.withTokenValue("refreshed-access-token")
-				.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
-				.issuedAt(Instant.now())
-				.expiresAt(Instant.now().plus(1, ChronoUnit.HOURS))
-				.build();
-		when(this.jwtEncoder.encode(any(), any())).thenReturn(jwt);
+		when(this.jwtEncoder.encode(any(), any())).thenReturn(createJwt(Collections.singleton("scope1")));
 		this.authenticationProvider = new OAuth2RefreshTokenAuthenticationProvider(
 				this.authorizationService, this.jwtEncoder);
+		this.jwtCustomizer = mock(OAuth2TokenCustomizer.class);
+		this.authenticationProvider.setJwtCustomizer(this.jwtCustomizer);
 	}
 
 	@Test
@@ -93,6 +98,13 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 				.isEqualTo("jwtEncoder cannot be null");
 	}
 
+	@Test
+	public void setJwtCustomizerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authenticationProvider.setJwtCustomizer(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jwtCustomizer cannot be null");
+	}
+
 	@Test
 	public void supportsWhenSupportedAuthenticationThenTrue() {
 		assertThat(this.authenticationProvider.supports(OAuth2RefreshTokenAuthenticationToken.class)).isTrue();
@@ -119,6 +131,18 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
 
+		ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+		verify(this.jwtCustomizer).customize(jwtEncodingContextCaptor.capture());
+		JwtEncodingContext jwtEncodingContext = jwtEncodingContextCaptor.getValue();
+		assertThat(jwtEncodingContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(jwtEncodingContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL));
+		assertThat(jwtEncodingContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(jwtEncodingContext.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
+		assertThat(jwtEncodingContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+		assertThat(jwtEncodingContext.<Authentication>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(jwtEncodingContext.getHeaders()).isNotNull();
+		assertThat(jwtEncodingContext.getClaims()).isNotNull();
+
 		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
 		verify(this.authorizationService).save(authorizationCaptor.capture());
 		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
@@ -340,4 +364,15 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 				.extracting("errorCode")
 				.isEqualTo(OAuth2ErrorCodes.INVALID_GRANT);
 	}
+
+	private static Jwt createJwt(Set<String> scope) {
+		Instant issuedAt = Instant.now();
+		Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
+		return Jwt.withTokenValue("refreshed-access-token")
+				.header(JoseHeaderNames.ALG, SignatureAlgorithm.RS256.getName())
+				.issuedAt(issuedAt)
+				.expiresAt(expiresAt)
+				.claim(OAuth2ParameterNames.SCOPE, scope)
+				.build();
+	}
 }

+ 118 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtEncodingContextTests.java

@@ -0,0 +1,118 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.token;
+
+import org.junit.Test;
+
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
+import org.springframework.security.oauth2.jwt.JoseHeader;
+import org.springframework.security.oauth2.jwt.JwtClaimsSet;
+import org.springframework.security.oauth2.jwt.TestJoseHeaders;
+import org.springframework.security.oauth2.jwt.TestJwtClaimsSets;
+import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
+import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationAttributeNames;
+import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
+import org.springframework.security.oauth2.server.authorization.TokenType;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for {@link JwtEncodingContext}.
+ *
+ * @author Joe Grandja
+ */
+public class JwtEncodingContextTests {
+
+	@Test
+	public void withWhenHeadersNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> JwtEncodingContext.with(null, TestJwtClaimsSets.jwtClaimsSet()))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("headersBuilder cannot be null");
+	}
+
+	@Test
+	public void withWhenClaimsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> JwtEncodingContext.with(TestJoseHeaders.joseHeader(), null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("claimsBuilder cannot be null");
+	}
+
+	@Test
+	public void setWhenValueNullThenThrowIllegalArgumentException() {
+		JwtEncodingContext.Builder builder = JwtEncodingContext
+				.with(TestJoseHeaders.joseHeader(), TestJwtClaimsSets.jwtClaimsSet());
+		assertThatThrownBy(() -> builder.registeredClient(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.principal(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.authorization(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.tokenType(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.authorizationGrantType(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.authorizationGrant(null))
+				.isInstanceOf(IllegalArgumentException.class);
+		assertThatThrownBy(() -> builder.put(null, ""))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void buildWhenAllValuesProvidedThenAllValuesAreSet() {
+		JoseHeader.Builder headers = TestJoseHeaders.joseHeader();
+		JwtClaimsSet.Builder claims = TestJwtClaimsSets.jwtClaimsSet();
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		TestingAuthenticationToken principal = new TestingAuthenticationToken("principal", "password");
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization().build();
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		OAuth2AuthorizationRequest authorizationRequest = authorization.getAttribute(
+				OAuth2AuthorizationAttributeNames.AUTHORIZATION_REQUEST);
+		OAuth2AuthorizationCodeAuthenticationToken authorizationGrant =
+				new OAuth2AuthorizationCodeAuthenticationToken(
+						"code", clientPrincipal, authorizationRequest.getRedirectUri(), null);
+
+		JwtEncodingContext context = JwtEncodingContext.with(headers, claims)
+				.registeredClient(registeredClient)
+				.principal(principal)
+				.authorization(authorization)
+				.tokenType(TokenType.ACCESS_TOKEN)
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.authorizationGrant(authorizationGrant)
+				.put("custom-key-1", "custom-value-1")
+				.context(ctx -> ctx.put("custom-key-2", "custom-value-2"))
+				.build();
+
+		assertThat(context.getHeaders()).isEqualTo(headers);
+		assertThat(context.getClaims()).isEqualTo(claims);
+		assertThat(context.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(context.<Authentication>getPrincipal()).isEqualTo(principal);
+		assertThat(context.getAuthorization()).isEqualTo(authorization);
+		assertThat(context.getTokenType()).isEqualTo(TokenType.ACCESS_TOKEN);
+		assertThat(context.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(context.<Authentication>getAuthorizationGrant()).isEqualTo(authorizationGrant);
+		assertThat(context.<String>get("custom-key-1")).isEqualTo("custom-value-1");
+		assertThat(context.<String>get("custom-key-2")).isEqualTo("custom-value-2");
+	}
+
+}

+ 17 - 8
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,15 +15,25 @@
  */
 package org.springframework.security.oauth2.server.authorization.web;
 
+import java.nio.charset.StandardCharsets;
+import java.util.Set;
+import java.util.function.Consumer;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.ArgumentCaptor;
+
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
@@ -44,13 +54,6 @@ import org.springframework.security.oauth2.server.authorization.client.TestRegis
 import org.springframework.security.oauth2.server.authorization.token.OAuth2AuthorizationCode;
 import org.springframework.util.StringUtils;
 
-import javax.servlet.FilterChain;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.nio.charset.StandardCharsets;
-import java.util.Set;
-import java.util.function.Consumer;
-
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
@@ -464,6 +467,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
+				.isEqualTo(this.authentication);
 
 		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
 		assertThat(authorizationCode).isNotNull();
@@ -511,6 +516,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
+				.isEqualTo(this.authentication);
 
 		OAuth2AuthorizationCode authorizationCode = authorization.getTokens().getToken(OAuth2AuthorizationCode.class);
 		assertThat(authorizationCode).isNotNull();
@@ -556,6 +563,8 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		OAuth2Authorization authorization = authorizationCaptor.getValue();
 		assertThat(authorization.getRegisteredClientId()).isEqualTo(registeredClient.getId());
 		assertThat(authorization.getPrincipalName()).isEqualTo(this.authentication.getPrincipal().toString());
+		assertThat(authorization.<Authentication>getAttribute(OAuth2AuthorizationAttributeNames.PRINCIPAL))
+				.isEqualTo(this.authentication);
 
 		String state = authorization.getAttribute(OAuth2AuthorizationAttributeNames.STATE);
 		assertThat(state).isNotNull();