Browse Source

Polish gh-787

Joe Grandja 3 years ago
parent
commit
502fa24cfb

+ 9 - 5
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/token/JwtGenerator.java

@@ -27,6 +27,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.IdTokenClaimNames;
 import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.core.oidc.endpoint.OidcParameterNames;
+import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.JwsHeader;
 import org.springframework.security.oauth2.jwt.Jwt;
@@ -89,14 +90,15 @@ public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> {
 
 		Instant issuedAt = Instant.now();
 		Instant expiresAt;
-		JwsHeader.Builder headersBuilder;
+		JwsAlgorithm jwsAlgorithm = SignatureAlgorithm.RS256;
 		if (OidcParameterNames.ID_TOKEN.equals(context.getTokenType().getValue())) {
 			// TODO Allow configuration for ID Token time-to-live
 			expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);
-			headersBuilder = JwsHeader.with(registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm());
+			if (registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm() != null) {
+				jwsAlgorithm = registeredClient.getTokenSettings().getIdTokenSignatureAlgorithm();
+			}
 		} else {
 			expiresAt = issuedAt.plus(registeredClient.getTokenSettings().getAccessTokenTimeToLive());
-			headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256);
 		}
 
 		// @formatter:off
@@ -128,9 +130,11 @@ public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> {
 		}
 		// @formatter:on
 
+		JwsHeader.Builder jwsHeaderBuilder = JwsHeader.with(jwsAlgorithm);
+
 		if (this.jwtCustomizer != null) {
 			// @formatter:off
-			JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(headersBuilder, claimsBuilder)
+			JwtEncodingContext.Builder jwtContextBuilder = JwtEncodingContext.with(jwsHeaderBuilder, claimsBuilder)
 					.registeredClient(context.getRegisteredClient())
 					.principal(context.getPrincipal())
 					.authorizationServerContext(context.getAuthorizationServerContext())
@@ -149,7 +153,7 @@ public final class JwtGenerator implements OAuth2TokenGenerator<Jwt> {
 			this.jwtCustomizer.customize(jwtContext);
 		}
 
-		JwsHeader jwsHeader = headersBuilder.build();
+		JwsHeader jwsHeader = jwsHeaderBuilder.build();
 		JwtClaimsSet claims = claimsBuilder.build();
 
 		Jwt jwt = this.jwtEncoder.encode(JwtEncoderParameters.from(jwsHeader, claims));

+ 11 - 7
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/token/JwtGeneratorTests.java

@@ -152,7 +152,10 @@ public class JwtGeneratorTests {
 
 	@Test
 	public void generateWhenIdTokenTypeThenReturnJwt() {
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().scope(OidcScopes.OPENID).build();
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient()
+				.scope(OidcScopes.OPENID)
+				.tokenSettings(TokenSettings.builder().idTokenSignatureAlgorithm(SignatureAlgorithm.ES256).build())
+				.build();
 		Map<String, Object> authenticationRequestAdditionalParameters = new HashMap<>();
 		authenticationRequestAdditionalParameters.put(OidcParameterNames.NONCE, "nonce");
 		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(
@@ -201,6 +204,13 @@ public class JwtGeneratorTests {
 		ArgumentCaptor<JwtEncoderParameters> jwtEncoderParametersCaptor = ArgumentCaptor.forClass(JwtEncoderParameters.class);
 		verify(this.jwtEncoder).encode(jwtEncoderParametersCaptor.capture());
 
+		JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader();
+		if (OidcParameterNames.ID_TOKEN.equals(tokenContext.getTokenType().getValue())) {
+			assertThat(jwsHeader.getAlgorithm()).isEqualTo(tokenContext.getRegisteredClient().getTokenSettings().getIdTokenSignatureAlgorithm());
+		} else {
+			assertThat(jwsHeader.getAlgorithm()).isEqualTo(SignatureAlgorithm.RS256);
+		}
+
 		JwtClaimsSet jwtClaimsSet = jwtEncoderParametersCaptor.getValue().getClaims();
 		assertThat(jwtClaimsSet.getIssuer().toExternalForm()).isEqualTo(tokenContext.getAuthorizationServerContext().getIssuer());
 		assertThat(jwtClaimsSet.getSubject()).isEqualTo(tokenContext.getAuthorization().getPrincipalName());
@@ -208,20 +218,14 @@ public class JwtGeneratorTests {
 
 		Instant issuedAt = Instant.now();
 		Instant expiresAt;
-		JwsHeader.Builder headersBuilder;
 		if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) {
 			expiresAt = issuedAt.plus(tokenContext.getRegisteredClient().getTokenSettings().getAccessTokenTimeToLive());
-			headersBuilder = JwsHeader.with(SignatureAlgorithm.RS256);
 		} else {
 			expiresAt = issuedAt.plus(30, ChronoUnit.MINUTES);
-			headersBuilder = JwsHeader.with(tokenContext.getRegisteredClient().getTokenSettings().getIdTokenSignatureAlgorithm());
 		}
 		assertThat(jwtClaimsSet.getIssuedAt()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));
 		assertThat(jwtClaimsSet.getExpiresAt()).isBetween(expiresAt.minusSeconds(1), expiresAt.plusSeconds(1));
 
-		JwsHeader jwsHeader = jwtEncoderParametersCaptor.getValue().getJwsHeader();
-		assertThat(jwsHeader.getAlgorithm()).isEqualTo(headersBuilder.build().getAlgorithm());
-
 		if (tokenContext.getTokenType().equals(OAuth2TokenType.ACCESS_TOKEN)) {
 			assertThat(jwtClaimsSet.getNotBefore()).isBetween(issuedAt.minusSeconds(1), issuedAt.plusSeconds(1));