浏览代码

ID Token contains sid claim after refresh_token grant

Closes gh-1224
Joe Grandja 2 年之前
父节点
当前提交
00c114cc12

+ 10 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java

@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.token;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.Collections;
+import java.util.Date;
 
 import org.springframework.lang.Nullable;
 import org.springframework.security.core.session.SessionInformation;
@@ -126,11 +127,15 @@ public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> {
 				if (StringUtils.hasText(nonce)) {
 					claimsBuilder.claim(IdTokenClaimNames.NONCE, nonce);
 				}
-			}
-			SessionInformation sessionInformation = context.get(SessionInformation.class);
-			if (sessionInformation != null) {
-				claimsBuilder.claim("sid", sessionInformation.getSessionId());
-				claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, sessionInformation.getLastRequest());
+				SessionInformation sessionInformation = context.get(SessionInformation.class);
+				if (sessionInformation != null) {
+					claimsBuilder.claim("sid", sessionInformation.getSessionId());
+					claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, sessionInformation.getLastRequest());
+				}
+			} else if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) {
+				OidcIdToken currentIdToken = context.getAuthorization().getToken(OidcIdToken.class).getToken();
+				claimsBuilder.claim("sid", currentIdToken.getClaim("sid"));
+				claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, currentIdToken.<Date>getClaim(IdTokenClaimNames.AUTH_TIME));
 			}
 		}
 		// @formatter:on

+ 11 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -19,6 +19,7 @@ import java.security.Principal;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
 import java.util.Collections;
+import java.util.Date;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.Map;
@@ -38,6 +39,7 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2Token;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
@@ -196,7 +198,15 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 	@Test
 	public void authenticateWhenValidRefreshTokenThenReturnIdToken() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
-		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		OidcIdToken authorizedIdToken =  OidcIdToken.withTokenValue("id-token")
+				.issuer("https://provider.com")
+				.subject("subject")
+				.issuedAt(Instant.now())
+				.expiresAt(Instant.now().plusSeconds(60))
+				.claim("sid", "sessionId-1234")
+				.claim(IdTokenClaimNames.AUTH_TIME, Date.from(Instant.now()))
+				.build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).token(authorizedIdToken).build();
 		when(this.authorizationService.findByToken(
 				eq(authorization.getRefreshToken().getToken().getTokenValue()),
 				eq(OAuth2TokenType.REFRESH_TOKEN)))

+ 57 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OidcTests.java

@@ -237,6 +237,63 @@ public class OidcTests {
 		assertThat(idToken.<String>getClaim("sid")).isNotNull();
 	}
 
+	// gh-1224
+	@Test
+	public void requestWhenRefreshTokenRequestThenIdTokenContainsSidClaim() throws Exception {
+		this.spring.register(AuthorizationServerConfiguration.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		this.registeredClientRepository.save(registeredClient);
+
+		MultiValueMap<String, String> authorizationRequestParameters = getAuthorizationRequestParameters(registeredClient);
+		MvcResult mvcResult = this.mvc.perform(get(DEFAULT_AUTHORIZATION_ENDPOINT_URI)
+						.params(authorizationRequestParameters)
+						.with(user("user").roles("A", "B")))
+				.andExpect(status().is3xxRedirection())
+				.andReturn();
+		String redirectedUrl = mvcResult.getResponse().getRedirectedUrl();
+		String expectedRedirectUri = authorizationRequestParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI);
+		assertThat(redirectedUrl).matches(expectedRedirectUri + "\\?code=.{15,}&state=state");
+
+		String authorizationCode = extractParameterFromRedirectUri(redirectedUrl, "code");
+		OAuth2Authorization authorization = this.authorizationService.findByToken(authorizationCode, AUTHORIZATION_CODE_TOKEN_TYPE);
+
+		mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+						.params(getTokenRequestParameters(registeredClient, authorization))
+						.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
+								registeredClient.getClientId(), registeredClient.getClientSecret())))
+				.andExpect(status().isOk())
+				.andReturn();
+
+		MockHttpServletResponse servletResponse = mvcResult.getResponse();
+		MockClientHttpResponse httpResponse = new MockClientHttpResponse(
+				servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
+		OAuth2AccessTokenResponse accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
+
+		Jwt idToken = this.jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
+
+		String sidClaim = idToken.getClaim("sid");
+		assertThat(sidClaim).isNotNull();
+
+		// Refresh access token
+		mvcResult = this.mvc.perform(post(DEFAULT_TOKEN_ENDPOINT_URI)
+						.param(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue())
+						.param(OAuth2ParameterNames.REFRESH_TOKEN, accessTokenResponse.getRefreshToken().getTokenValue())
+						.header(HttpHeaders.AUTHORIZATION, "Basic " + encodeBasicAuth(
+								registeredClient.getClientId(), registeredClient.getClientSecret())))
+				.andExpect(status().isOk())
+				.andReturn();
+
+		servletResponse = mvcResult.getResponse();
+		httpResponse = new MockClientHttpResponse(
+				servletResponse.getContentAsByteArray(), HttpStatus.valueOf(servletResponse.getStatus()));
+		accessTokenResponse = accessTokenHttpResponseConverter.read(OAuth2AccessTokenResponse.class, httpResponse);
+
+		idToken = this.jwtDecoder.decode((String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN));
+
+		assertThat(idToken.<String>getClaim("sid")).isEqualTo(sidClaim);
+	}
+
 	@Test
 	public void requestWhenLogoutRequestThenLogout() throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();

+ 61 - 10
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java

@@ -31,9 +31,11 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.session.SessionInformation;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
@@ -46,6 +48,7 @@ import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
 import org.springframework.security.oauth2.server.authorization.TestOAuth2Authorizations;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.context.TestAuthorizationServerContext;
@@ -152,7 +155,7 @@ public class JwtGeneratorTests {
 	}
 
 	@Test
-	public void generateWhenIdTokenTypeThenReturnJwt() {
+	public void generateWhenIdTokenTypeAndAuthorizationCodeGrantThenReturnJwt() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
 				.scope(OidcScopes.OPENID)
 				.tokenSettings(TokenSettings.builder().idTokenSignatureAlgorithm(SignatureAlgorithm.ES256).build())
@@ -190,6 +193,49 @@ public class JwtGeneratorTests {
 		assertGeneratedTokenType(tokenContext);
 	}
 
+	// gh-1224
+	@Test
+	public void generateWhenIdTokenTypeAndRefreshTokenGrantThenReturnJwt() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.scope(OidcScopes.OPENID)
+				.build();
+		OidcIdToken idToken =  OidcIdToken.withTokenValue("id-token")
+				.issuer("https://provider.com")
+				.subject("subject")
+				.issuedAt(Instant.now())
+				.expiresAt(Instant.now().plusSeconds(60))
+				.claim("sid", "sessionId-1234")
+				.claim(IdTokenClaimNames.AUTH_TIME, Date.from(Instant.now()))
+				.build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient)
+				.token(idToken)
+				.build();
+
+		OAuth2RefreshToken refreshToken = authorization.getRefreshToken().getToken();
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+
+		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
+				refreshToken.getTokenValue(), clientPrincipal, null, null);
+
+		Authentication principal = authorization.getAttribute(Principal.class.getName());
+
+		// @formatter:off
+		OAuth2TokenContext tokenContext = DefaultOAuth2TokenContext.builder()
+				.registeredClient(registeredClient)
+				.principal(principal)
+				.authorizationServerContext(this.authorizationServerContext)
+				.authorization(authorization)
+				.authorizedScopes(authorization.getAuthorizedScopes())
+				.tokenType(ID_TOKEN_TOKEN_TYPE)
+				.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+				.authorizationGrant(authentication)
+				.build();
+		// @formatter:on
+
+		assertGeneratedTokenType(tokenContext);
+	}
+
 	private void assertGeneratedTokenType(OAuth2TokenContext tokenContext) {
 		this.jwtGenerator.generate(tokenContext);
 
@@ -239,15 +285,20 @@ public class JwtGeneratorTests {
 			assertThat(scopes).isEqualTo(tokenContext.getAuthorizedScopes());
 		} else {
 			assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.AZP)).isEqualTo(tokenContext.getRegisteredClient().getClientId());
-
-			OAuth2AuthorizationRequest authorizationRequest = tokenContext.getAuthorization().getAttribute(
-					OAuth2AuthorizationRequest.class.getName());
-			String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
-			assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce);
-
-			SessionInformation sessionInformation = tokenContext.get(SessionInformation.class);
-			assertThat(jwtClaimsSet.<String>getClaim("sid")).isEqualTo(sessionInformation.getSessionId());
-			assertThat(jwtClaimsSet.<Date>getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(sessionInformation.getLastRequest());
+			if (tokenContext.getAuthorizationGrantType().equals(AuthorizationGrantType.AUTHORIZATION_CODE)) {
+				OAuth2AuthorizationRequest authorizationRequest = tokenContext.getAuthorization().getAttribute(
+						OAuth2AuthorizationRequest.class.getName());
+				String nonce = (String) authorizationRequest.getAdditionalParameters().get(OidcParameterNames.NONCE);
+				assertThat(jwtClaimsSet.<String>getClaim(IdTokenClaimNames.NONCE)).isEqualTo(nonce);
+
+				SessionInformation sessionInformation = tokenContext.get(SessionInformation.class);
+				assertThat(jwtClaimsSet.<String>getClaim("sid")).isEqualTo(sessionInformation.getSessionId());
+				assertThat(jwtClaimsSet.<Date>getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(sessionInformation.getLastRequest());
+			} else if (tokenContext.getAuthorizationGrantType().equals(AuthorizationGrantType.REFRESH_TOKEN)) {
+				OidcIdToken currentIdToken = tokenContext.getAuthorization().getToken(OidcIdToken.class).getToken();
+				assertThat(jwtClaimsSet.<String>getClaim("sid")).isEqualTo(currentIdToken.getClaim("sid"));
+				assertThat(jwtClaimsSet.<Date>getClaim(IdTokenClaimNames.AUTH_TIME)).isEqualTo(currentIdToken.<Date>getClaim(IdTokenClaimNames.AUTH_TIME));
+			}
 		}
 	}