Forráskód Böngészése

Fix DPoP jkt claim to be JWK SHA-256 thumbprint

Closes gh-2007
Joe Grandja 3 hónapja
szülő
commit
07f9621b02

+ 4 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/DefaultOAuth2TokenCustomizers.java

@@ -16,7 +16,6 @@
 package org.springframework.security.oauth2.server.authorization.config.annotation.web.configurers;
 
 import java.security.MessageDigest;
-import java.security.PublicKey;
 import java.security.cert.X509Certificate;
 import java.util.Base64;
 import java.util.Collections;
@@ -24,7 +23,6 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.Map;
 
-import com.nimbusds.jose.jwk.AsymmetricJWK;
 import com.nimbusds.jose.jwk.JWK;
 
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@@ -91,25 +89,22 @@ final class DefaultOAuth2TokenCustomizers {
 		// Add 'cnf' claim for OAuth 2.0 Demonstrating Proof of Possession (DPoP)
 		Jwt dPoPProofJwt = tokenContext.get(OAuth2TokenContext.DPOP_PROOF_KEY);
 		if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenContext.getTokenType()) && dPoPProofJwt != null) {
-			PublicKey publicKey = null;
+			JWK jwk = null;
 			@SuppressWarnings("unchecked")
 			Map<String, Object> jwkJson = (Map<String, Object>) dPoPProofJwt.getHeaders().get("jwk");
 			try {
-				JWK jwk = JWK.parse(jwkJson);
-				if (jwk instanceof AsymmetricJWK asymmetricJWK) {
-					publicKey = asymmetricJWK.toPublicKey();
-				}
+				jwk = JWK.parse(jwkJson);
 			}
 			catch (Exception ignored) {
 			}
-			if (publicKey == null) {
+			if (jwk == null) {
 				OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
 						"jwk header is missing or invalid.", null);
 				throw new OAuth2AuthenticationException(error);
 			}
 
 			try {
-				String sha256Thumbprint = computeSHA256Thumbprint(publicKey);
+				String sha256Thumbprint = jwk.computeThumbprint().toString();
 				if (cnfClaims == null) {
 					cnfClaims = new HashMap<>();
 				}
@@ -149,10 +144,4 @@ final class DefaultOAuth2TokenCustomizers {
 		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
 	}
 
-	private static String computeSHA256Thumbprint(PublicKey publicKey) throws Exception {
-		MessageDigest md = MessageDigest.getInstance("SHA-256");
-		byte[] digest = md.digest(publicKey.getEncoded());
-		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
-	}
-
 }

+ 2 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2AuthorizationCodeGrantTests.java

@@ -1011,6 +1011,8 @@ public class OAuth2AuthorizationCodeGrantTests {
 		@SuppressWarnings("unchecked")
 		Map<String, Object> cnfClaims = (Map<String, Object>) authorization.getAccessToken().getClaims().get("cnf");
 		assertThat(cnfClaims).containsKey("jkt");
+		String jwkThumbprintClaim = (String) cnfClaims.get("jkt");
+		assertThat(jwkThumbprintClaim).isEqualTo(TestJwks.DEFAULT_EC_JWK.toPublicJWK().computeThumbprint().toString());
 	}
 
 	@Test

+ 2 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2DeviceCodeGrantTests.java

@@ -605,6 +605,8 @@ public class OAuth2DeviceCodeGrantTests {
 		@SuppressWarnings("unchecked")
 		Map<String, Object> cnfClaims = (Map<String, Object>) authorization.getAccessToken().getClaims().get("cnf");
 		assertThat(cnfClaims).containsKey("jkt");
+		String jwkThumbprintClaim = (String) cnfClaims.get("jkt");
+		assertThat(jwkThumbprintClaim).isEqualTo(TestJwks.DEFAULT_EC_JWK.toPublicJWK().computeThumbprint().toString());
 	}
 
 	private static String generateDPoPProof(String tokenEndpointUri) {

+ 2 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2RefreshTokenGrantTests.java

@@ -319,6 +319,8 @@ public class OAuth2RefreshTokenGrantTests {
 		@SuppressWarnings("unchecked")
 		Map<String, Object> cnfClaims = (Map<String, Object>) authorization.getAccessToken().getClaims().get("cnf");
 		assertThat(cnfClaims).containsKey("jkt");
+		String jwkThumbprintClaim = (String) cnfClaims.get("jkt");
+		assertThat(jwkThumbprintClaim).isEqualTo(TestJwks.DEFAULT_EC_JWK.toPublicJWK().computeThumbprint().toString());
 	}
 
 	@Test