瀏覽代碼

Fix generating ID token with null sid when refresh_token grant

Closes gh-1283
cbilodeau 2 年之前
父節點
當前提交
b6f3b5cc45

+ 3 - 1
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java

@@ -134,7 +134,9 @@ public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> {
 				}
 			} else if (AuthorizationGrantType.REFRESH_TOKEN.equals(context.getAuthorizationGrantType())) {
 				OidcIdToken currentIdToken = context.getAuthorization().getToken(OidcIdToken.class).getToken();
-				claimsBuilder.claim("sid", currentIdToken.getClaim("sid"));
+				if (currentIdToken.hasClaim("sid")) {
+					claimsBuilder.claim("sid", currentIdToken.getClaim("sid"));
+				}
 				claimsBuilder.claim(IdTokenClaimNames.AUTH_TIME, currentIdToken.<Date>getClaim(IdTokenClaimNames.AUTH_TIME));
 			}
 		}

+ 81 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProviderTests.java

@@ -47,6 +47,7 @@ import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.JoseHeaderNames;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtEncoder;
+import org.springframework.security.oauth2.jwt.JwtEncoderParameters;
 import org.springframework.security.oauth2.server.authorization.OAuth2Authorization;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationService;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenType;
@@ -250,7 +251,86 @@ public class OAuth2RefreshTokenAuthenticationProviderTests {
 		assertThat(idTokenContext.getJwsHeader()).isNotNull();
 		assertThat(idTokenContext.getClaims()).isNotNull();
 
-		verify(this.jwtEncoder, times(2)).encode(any());		// Access token and ID Token
+		ArgumentCaptor<JwtEncoderParameters> jwtEncoderParametersArgumentCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class);
+		verify(this.jwtEncoder, times(2)).encode(jwtEncoderParametersArgumentCaptor.capture());		// Access token and ID Token
+		JwtEncoderParameters jwtEncoderParameters = jwtEncoderParametersArgumentCaptor.getValue();
+		assertThat(jwtEncoderParameters.getClaims().getClaims().get("sid")).isNotNull();
+
+		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
+		verify(this.authorizationService).save(authorizationCaptor.capture());
+		OAuth2Authorization updatedAuthorization = authorizationCaptor.getValue();
+
+		assertThat(accessTokenAuthentication.getRegisteredClient().getId()).isEqualTo(updatedAuthorization.getRegisteredClientId());
+		assertThat(accessTokenAuthentication.getPrincipal()).isEqualTo(clientPrincipal);
+		assertThat(accessTokenAuthentication.getAccessToken()).isEqualTo(updatedAuthorization.getAccessToken().getToken());
+		assertThat(updatedAuthorization.getAccessToken()).isNotEqualTo(authorization.getAccessToken());
+		OAuth2Authorization.Token<OidcIdToken> idToken = updatedAuthorization.getToken(OidcIdToken.class);
+		assertThat(idToken).isNotNull();
+		assertThat(accessTokenAuthentication.getAdditionalParameters())
+				.containsExactly(entry(OidcParameterNames.ID_TOKEN, idToken.getToken().getTokenValue()));
+		assertThat(accessTokenAuthentication.getRefreshToken()).isEqualTo(updatedAuthorization.getRefreshToken().getToken());
+		// By default, refresh token is reused
+		assertThat(updatedAuthorization.getRefreshToken()).isEqualTo(authorization.getRefreshToken());
+	}
+
+	@Test
+	public void authenticateWhenValidRefreshTokenThenReturnIdTokenWithoutSid() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		OidcIdToken authorizedIdToken =  OidcIdToken.withTokenValue("id-token")
+				.issuer("https://provider.com")
+				.subject("subject")
+				.issuedAt(Instant.now())
+				.expiresAt(Instant.now().plusSeconds(60))
+				.claim(IdTokenClaimNames.AUTH_TIME, Date.from(Instant.now()))
+				.build();
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).token(authorizedIdToken).build();
+		when(this.authorizationService.findByToken(
+				eq(authorization.getRefreshToken().getToken().getTokenValue()),
+				eq(OAuth2TokenType.REFRESH_TOKEN)))
+				.thenReturn(authorization);
+
+		OAuth2ClientAuthenticationToken clientPrincipal = new OAuth2ClientAuthenticationToken(
+				registeredClient, ClientAuthenticationMethod.CLIENT_SECRET_BASIC, registeredClient.getClientSecret());
+		OAuth2RefreshTokenAuthenticationToken authentication = new OAuth2RefreshTokenAuthenticationToken(
+				authorization.getRefreshToken().getToken().getTokenValue(), clientPrincipal, null, null);
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) this.authenticationProvider.authenticate(authentication);
+
+		ArgumentCaptor<JwtEncodingContext> jwtEncodingContextCaptor = ArgumentCaptor.forClass(JwtEncodingContext.class);
+		verify(this.jwtCustomizer, times(2)).customize(jwtEncodingContextCaptor.capture());
+		// Access Token context
+		JwtEncodingContext accessTokenContext = jwtEncodingContextCaptor.getAllValues().get(0);
+		assertThat(accessTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(accessTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
+		assertThat(accessTokenContext.getAuthorization()).isEqualTo(authorization);
+		assertThat(accessTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes());
+		assertThat(accessTokenContext.getTokenType()).isEqualTo(OAuth2TokenType.ACCESS_TOKEN);
+		assertThat(accessTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+		assertThat(accessTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(accessTokenContext.getJwsHeader()).isNotNull();
+		assertThat(accessTokenContext.getClaims()).isNotNull();
+		Map<String, Object> claims = new HashMap<>();
+		accessTokenContext.getClaims().claims(claims::putAll);
+		assertThat(claims).flatExtracting(OAuth2ParameterNames.SCOPE)
+				.containsExactlyInAnyOrder(OidcScopes.OPENID, "scope1");
+		// ID Token context
+		JwtEncodingContext idTokenContext = jwtEncodingContextCaptor.getAllValues().get(1);
+		assertThat(idTokenContext.getRegisteredClient()).isEqualTo(registeredClient);
+		assertThat(idTokenContext.<Authentication>getPrincipal()).isEqualTo(authorization.getAttribute(Principal.class.getName()));
+		assertThat(idTokenContext.getAuthorization()).isNotEqualTo(authorization);
+		assertThat(idTokenContext.getAuthorization().getAccessToken()).isNotEqualTo(authorization.getAccessToken());
+		assertThat(idTokenContext.getAuthorizedScopes()).isEqualTo(authorization.getAuthorizedScopes());
+		assertThat(idTokenContext.getTokenType().getValue()).isEqualTo(OidcParameterNames.ID_TOKEN);
+		assertThat(idTokenContext.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+		assertThat(idTokenContext.<OAuth2AuthorizationGrantAuthenticationToken>getAuthorizationGrant()).isEqualTo(authentication);
+		assertThat(idTokenContext.getJwsHeader()).isNotNull();
+		assertThat(idTokenContext.getClaims()).isNotNull();
+
+		ArgumentCaptor<JwtEncoderParameters> jwtEncoderParametersArgumentCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class);
+		verify(this.jwtEncoder, times(2)).encode(jwtEncoderParametersArgumentCaptor.capture());		// Access token and ID Token
+		JwtEncoderParameters jwtEncoderParameters = jwtEncoderParametersArgumentCaptor.getValue();
+		assertThat(jwtEncoderParameters.getClaims().getClaims().get("sid")).isNull();
 
 		ArgumentCaptor<OAuth2Authorization> authorizationCaptor = ArgumentCaptor.forClass(OAuth2Authorization.class);
 		verify(this.authorizationService).save(authorizationCaptor.capture());