Browse Source

Access token is available when customizing ID Token

Closes gh-744
Joe Grandja 3 years ago
parent
commit
fdf0a2f94c

+ 6 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProvider.java

@@ -179,7 +179,12 @@ public final class OAuth2AuthorizationCodeAuthenticationProvider implements Auth
 		// ----- ID token -----
 		OidcIdToken idToken;
 		if (authorizationRequest.getScopes().contains(OidcScopes.OPENID)) {
-			tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build();
+			// @formatter:off
+			tokenContext = tokenContextBuilder
+					.tokenType(ID_TOKEN_TOKEN_TYPE)
+					.authorization(authorizationBuilder.build())	// ID token customizer may need access to the access token and/or refresh token
+					.build();
+			// @formatter:on
 			OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext);
 			if (!(generatedIdToken instanceof Jwt)) {
 				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,

+ 6 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@@ -176,7 +176,12 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 		// ----- ID token -----
 		OidcIdToken idToken;
 		if (authorizedScopes.contains(OidcScopes.OPENID)) {
-			tokenContext = tokenContextBuilder.tokenType(ID_TOKEN_TOKEN_TYPE).build();
+			// @formatter:off
+			tokenContext = tokenContextBuilder
+					.tokenType(ID_TOKEN_TOKEN_TYPE)
+					.authorization(authorizationBuilder.build())	// ID token customizer may need access to the access token and/or refresh token
+					.build();
+			// @formatter:on
 			OAuth2Token generatedIdToken = this.tokenGenerator.generate(tokenContext);
 			if (!(generatedIdToken instanceof Jwt)) {
 				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR,

+ 27 - 15
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/TestOAuth2Authorizations.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2021 the original author or authors.
+ * Copyright 2020-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -47,25 +47,30 @@ public class TestOAuth2Authorizations {
 		return authorization(registeredClient, Collections.emptyMap());
 	}
 
-	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
-			OAuth2AccessToken accessToken, Map<String, Object> accessTokenClaims) {
-		return authorization(registeredClient, accessToken, accessTokenClaims, Collections.emptyMap());
-	}
-
 	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
 			Map<String, Object> authorizationRequestAdditionalParameters) {
+		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+				"code", Instant.now(), Instant.now().plusSeconds(120));
 		OAuth2AccessToken accessToken = new OAuth2AccessToken(
 				OAuth2AccessToken.TokenType.BEARER, "access-token", Instant.now(), Instant.now().plusSeconds(300));
-		return authorization(registeredClient, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters);
+		return authorization(registeredClient, authorizationCode, accessToken, Collections.emptyMap(), authorizationRequestAdditionalParameters);
 	}
 
-	private static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
-			OAuth2AccessToken accessToken, Map<String, Object> accessTokenClaims,
-			Map<String, Object> authorizationRequestAdditionalParameters) {
+	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
+			OAuth2AuthorizationCode authorizationCode) {
+		return authorization(registeredClient, authorizationCode, null, Collections.emptyMap(), Collections.emptyMap());
+	}
+
+	public static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
+			OAuth2AccessToken accessToken, Map<String, Object> accessTokenClaims) {
 		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 				"code", Instant.now(), Instant.now().plusSeconds(120));
-		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
-				"refresh-token", Instant.now(), Instant.now().plus(1, ChronoUnit.HOURS));
+		return authorization(registeredClient, authorizationCode, accessToken, accessTokenClaims, Collections.emptyMap());
+	}
+
+	private static OAuth2Authorization.Builder authorization(RegisteredClient registeredClient,
+			OAuth2AuthorizationCode authorizationCode, OAuth2AccessToken accessToken,
+			Map<String, Object> accessTokenClaims, Map<String, Object> authorizationRequestAdditionalParameters) {
 		OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
 				.authorizationUri("https://provider.com/oauth2/authorize")
 				.clientId(registeredClient.getClientId())
@@ -74,18 +79,25 @@ public class TestOAuth2Authorizations {
 				.additionalParameters(authorizationRequestAdditionalParameters)
 				.state("state")
 				.build();
-		return OAuth2Authorization.withRegisteredClient(registeredClient)
+		OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient)
 				.id("id")
 				.principalName("principal")
 				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
 				.token(authorizationCode)
-				.token(accessToken, (metadata) -> metadata.putAll(tokenMetadata(accessTokenClaims)))
-				.refreshToken(refreshToken)
 				.attribute(OAuth2ParameterNames.STATE, "state")
 				.attribute(OAuth2AuthorizationRequest.class.getName(), authorizationRequest)
 				.attribute(Principal.class.getName(),
 						new TestingAuthenticationToken("principal", null, "ROLE_A", "ROLE_B"))
 				.attribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME, authorizationRequest.getScopes());
+		if (accessToken != null) {
+			OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
+					"refresh-token", Instant.now(), Instant.now().plus(1, ChronoUnit.HOURS));
+			builder
+				.token(accessToken, (metadata) -> metadata.putAll(tokenMetadata(accessTokenClaims)))
+				.refreshToken(refreshToken);
+		}
+
+		return builder;
 	}
 
 	private static Map<String, Object> tokenMetadata(Map<String, Object> tokenClaims) {

+ 8 - 4
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2AuthorizationCodeAuthenticationProviderTests.java

@@ -443,7 +443,9 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 	@Test
 	public void authenticateWhenValidCodeAndAuthenticationRequestThenReturnIdToken() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
-		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+				"code", Instant.now(), Instant.now().plusSeconds(120));
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient, authorizationCode).build();
 		when(this.authorizationService.findByToken(eq(AUTHORIZATION_CODE), eq(AUTHORIZATION_CODE_TOKEN_TYPE)))
 				.thenReturn(authorization);
 
@@ -466,6 +468,7 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
 		assertThat(accessTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
 		assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(accessTokenContext.getAuthorization().getAccessToken()).isNull();
 		assertThat(accessTokenContext.getAuthorizedScopes())
 				.isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
 		assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN);
@@ -481,7 +484,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
 		assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
 		assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
-		assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization);
+		assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotNull();
 		assertThat(idTokenContext.getAuthorizedScopes())
 				.isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
 		assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
@@ -503,8 +507,8 @@ public class OAuth2AuthorizationCodeAuthenticationProviderTests {
 		assertThat(accessTokenAuthentication.getAccessToken().getScopes()).isEqualTo(accessTokenScopes);
 		assertThat(accessTokenAuthentication.getRefreshToken()).isNotNull();
 		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
-		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
-		assertThat(authorizationCode.isInvalidated()).isTrue();
+		OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCodeToken = updatedAuthorization.getToken(OAuth2AuthorizationCode.class);
+		assertThat(authorizationCodeToken.isInvalidated()).isTrue();
 		OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
 		assertThat(idToken).isNotNull();
 		assertThat(accessTokenAuthentication.getAdditionalParameters())

+ 2 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -233,7 +233,8 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
 		assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
 		assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
-		assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization);
+		assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotEqualTo(authorization.getAccessToken());
 		assertThat(idTokenContext.getAuthorizedScopes())
 				.isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
 		assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);