Explorar o código

Refresh token grant may issue ID token

See https://openid.net/specs/openid-connect-core-1_0.html#RefreshTokenResponse

Closes gh-287
Anoop Garlapati %!s(int64=4) %!d(string=hai) anos
pai
achega
385fc37b1d

+ 59 - 6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@@ -19,6 +19,9 @@ import java.security.Principal;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.Base64;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 import java.util.Set;
 
 import org.springframework.beans.factory.annotation.Autowired;
@@ -35,17 +38,20 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
 import org.springframework.security.oauth2.jwt.JoseHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtClaimsSet;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.config.ProviderSettings;
 import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
-import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
 import org.springframework.util.Assert;
 
 import static org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthenticationProviderUtils.getAuthenticatedClientElseThrowInvalidClient;
@@ -55,6 +61,7 @@ import static org.springframework.security.oauth2.server.authorization.authentic
  *
  * @author Alexey Nesterov
  * @author Joe Grandja
+ * @author Anoop Garlapati
  * @since 0.0.3
  * @see OAuth2RefreshTokenAuthenticationToken
  * @see OAuth2AccessTokenAuthenticationToken
@@ -66,6 +73,7 @@ import static org.springframework.security.oauth2.server.authorization.authentic
  * @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-6">Section 6 Refreshing an Access Token</a>
  */
 public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationProvider {
+	private static final OAuth2TokenType ID_TOKEN_TOKEN_TYPE = new OAuth2TokenType(OidcParameterNames.ID_TOKEN);
 	private static final StringKeyGenerator TOKEN_GENERATOR = new Base64StringKeyGenerator(Base64.getUrlEncoder().withoutPadding(), 96);
 	private final OAuth2AuthorizationService authorizationService;
 	private final JwtEncoder jwtEncoder;
@@ -174,19 +182,64 @@ public class OAuth2RefreshTokenAuthenticationProvider implements AuthenticationP
 			currentRefreshToken = generateRefreshToken(tokenSettings.refreshTokenTimeToLive());
 		}
 
+		Jwt jwtIdToken = null;
+		if (authorizedScopes.contains(OidcScopes.OPENID)) {
+			headersBuilder = JwtUtils.headers();
+			claimsBuilder = JwtUtils.idTokenClaims(
+					registeredClient, issuer, authorization.getPrincipalName(), null);
+
+			// @formatter:off
+			context = JwtEncodingContext.with(headersBuilder, claimsBuilder)
+					.registeredClient(registeredClient)
+					.principal(authorization.getAttribute(Principal.class.getName()))
+					.authorization(authorization)
+					.authorizedScopes(authorizedScopes)
+					.tokenType(ID_TOKEN_TOKEN_TYPE)
+					.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+					.authorizationGrant(refreshTokenAuthentication)
+					.build();
+			// @formatter:on
+
+			this.jwtCustomizer.customize(context);
+
+			headers = context.getHeaders().build();
+			claims = context.getClaims().build();
+			jwtIdToken = this.jwtEncoder.encode(headers, claims);
+		}
+
+		OidcIdToken idToken;
+		if (jwtIdToken != null) {
+			idToken = new OidcIdToken(jwtIdToken.getTokenValue(), jwtIdToken.getIssuedAt(),
+					jwtIdToken.getExpiresAt(), jwtIdToken.getClaims());
+		} else {
+			idToken = null;
+		}
+
 		// @formatter:off
-		authorization = OAuth2Authorization.from(authorization)
+		OAuth2Authorization.Builder authorizationBuilder = OAuth2Authorization.from(authorization)
 				.token(accessToken,
 						(metadata) ->
 								metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, jwtAccessToken.getClaims()))
-				.refreshToken(currentRefreshToken)
-				.build();
+				.refreshToken(currentRefreshToken);
+		if (idToken != null) {
+			authorizationBuilder
+					.token(idToken,
+							(metadata) ->
+									metadata.put(OAuth2Authorization.Token.CLAIMS_METADATA_NAME, idToken.getClaims()));
+		}
+		authorization = authorizationBuilder.build();
 		// @formatter:on
 
 		this.authorizationService.save(authorization);
 
+		Map<String, Object> additionalParameters = Collections.emptyMap();
+		if (idToken != null) {
+			additionalParameters = new HashMap<>();
+			additionalParameters.put(OidcParameterNames.ID_TOKEN, idToken.getTokenValue());
+		}
+
 		return new OAuth2AccessTokenAuthenticationToken(
-				registeredClient, clientPrincipal, accessToken, currentRefreshToken);
+				registeredClient, clientPrincipal, accessToken, currentRefreshToken, additionalParameters);
 	}
 
 	@Override

+ 76 - 2
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -19,7 +19,9 @@ import java.security.Principal;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
 import java.util.Set;
 
 import org.junit.Before;
@@ -36,23 +38,28 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.core.oidc.OidcScopes;
+import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
+import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
-import org.springframework.security.oauth2.server.authorization.JwtEncodingContext;
-import org.springframework.security.oauth2.server.authorization.OAuth2TokenCustomizer;
 
+import static org.assertj.core.api.Assertions.entry;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -61,6 +68,7 @@ import static org.mockito.Mockito.when;
  *
  * @author Alexey Nesterov
  * @author Joe Grandja
+ * @author Anoop Garlapati
  * @since 0.0.3
  */
 public class OAuth2RefreshTokenAuthenticationProviderTests {
@@ -156,6 +164,72 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
 	}
 
+	@Test
+	public void authenticateWhenValidRefreshTokenThenReturnIdToken() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		when(this.authorizationService.findByToken(
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
+				eq(OAuth2TokenType.REFRESH_TOKEN)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+		verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture());
+		// Access Token context
+		JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0);
+		assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(accessTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
+		assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(accessTokenContext.getAuthorizedScopes())
+				.isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
+		assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN);
+		assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+		assertThat(accessTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(accessTokenContext.getHeaders()).isNotNull();
+		assertThat(accessTokenContext.getClaims()).isNotNull();
+		Map<String, Object> claims = new HashMap<>();
+		accessTokenContext.getClaims().claims(claims::putAll);
+		assertThat(claims).flatExtracting(OAuth2ParameterNames.SCOPE)
+				.containsExactlyInAnyOrder(OidcScopes.OPENID, "scope1");
+		// ID Token context
+		JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
+		assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
+		assertThat(idTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(idTokenContext.getAuthorizedScopes())
+				.isEqualTo(authorization.getAttribute(OAuth2Authorization.AUTHORIZED_SCOPE_ATTRIBUTE_NAME));
+		assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
+		assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+		assertThat(idTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(idTokenContext.getHeaders()).isNotNull();
+		assertThat(idTokenContext.getClaims()).isNotNull();
+
+		verify(this.jwtEncoder, times(2)).encode(any(), any());		// Access token and ID Token
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
+		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
+		assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken());
+		OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
+		assertThat(idToken).isNotNull();
+		assertThat(accessTokenAuthentication.getAdditionalParameters())
+				.containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue()));
+		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
+		// By default, refresh token is reused
+		assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
+	}
+
 	@Test
 	public void authenticateWhenReuseRefreshTokensFalseThenReturnNewRefreshToken() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()