浏览代码

Verify DPoP Proof public key during refresh_token grant for public clients

Issue gh-1813

Closes gh-1949
Joe Grandja 5 月之前
父节点
当前提交
48fd6ab60f

+ 66 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/authentication/OAuth2RefreshTokenAuthenticationProvider.java

@@ -15,12 +15,17 @@
  */
 package org.springframework.security.oauth2.server.authorization.authentication;
 
+import java.security.MessageDigest;
 import java.security.Principal;
+import java.security.PublicKey;
+import java.util.Base64;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Set;
 
+import com.nimbusds.jose.jwk.AsymmetricJWK;
+import com.nimbusds.jose.jwk.JWK;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
@@ -29,6 +34,8 @@ import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClaimAccessor;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -48,6 +55,7 @@ import org.springframework.security.oauth2.server.authorization.token.DefaultOAu
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenContext;
 import org.springframework.security.oauth2.server.authorization.token.OAuth2TokenGenerator;
 import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
 
 /**
  * An {@link AuthenticationProvider} implementation for the OAuth 2.0 Refresh Token Grant.
@@ -160,6 +168,14 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 		// Verify the DPoP Proof (if available)
 		Jwt dPoPProof = DPoPProofVerifier.verifyIfAvailable(refreshTokenAuthentication);
 
+		if (dPoPProof != null
+				& clientPrincipal.getClientAuthenticationMethod().equals(ClientAuthenticationMethod.NONE)) {
+			// For public clients, verify the DPoP Proof public key is same as (current)
+			// access token public key binding
+			Map<String, Object> accessTokenClaims = authorization.getAccessToken().getClaims();
+			verifyDPoPProofPublicKey(dPoPProof, () -> accessTokenClaims);
+		}
+
 		if (this.logger.isTraceEnabled()) {
 			this.logger.trace("Validated token request parameters");
 		}
@@ -275,4 +291,54 @@ public final class OAuth2RefreshTokenAuthenticationProvider implements Authentic
 		return OAuth2RefreshTokenAuthenticationToken.class.isAssignableFrom(authentication);
 	}
 
+	private static void verifyDPoPProofPublicKey(Jwt dPoPProof, ClaimAccessor accessTokenClaims) {
+		PublicKey publicKey = null;
+		@SuppressWarnings("unchecked")
+		Map<String, Object> jwkJson = (Map<String, Object>) dPoPProof.getHeaders().get("jwk");
+		try {
+			JWK jwk = JWK.parse(jwkJson);
+			if (jwk instanceof AsymmetricJWK) {
+				publicKey = ((AsymmetricJWK) jwk).toPublicKey();
+			}
+		}
+		catch (Exception ignored) {
+		}
+		if (publicKey == null) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
+					"jwk header is missing or invalid.", null);
+			throw new OAuth2AuthenticationException(error);
+		}
+
+		String jwkThumbprint;
+		try {
+			jwkThumbprint = computeSHA256(publicKey);
+		}
+		catch (Exception ex) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF,
+					"Failed to compute SHA-256 Thumbprint for jwk.", null);
+			throw new OAuth2AuthenticationException(error);
+		}
+
+		String jwkThumbprintClaim = null;
+		Map<String, Object> confirmationMethodClaim = accessTokenClaims.getClaimAsMap("cnf");
+		if (!CollectionUtils.isEmpty(confirmationMethodClaim) && confirmationMethodClaim.containsKey("jkt")) {
+			jwkThumbprintClaim = (String) confirmationMethodClaim.get("jkt");
+		}
+		if (jwkThumbprintClaim == null) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jkt claim is missing.", null);
+			throw new OAuth2AuthenticationException(error);
+		}
+
+		if (!jwkThumbprint.equals(jwkThumbprintClaim)) {
+			OAuth2Error error = new OAuth2Error(OAuth2ErrorCodes.INVALID_DPOP_PROOF, "jwk header is invalid.", null);
+			throw new OAuth2AuthenticationException(error);
+		}
+	}
+
+	private static String computeSHA256(PublicKey publicKey) throws Exception {
+		MessageDigest md = MessageDigest.getInstance("SHA-256");
+		byte[] digest = md.digest(publicKey.getEncoded());
+		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
+	}
+
 }

+ 108 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/config/annotation/web/configurers/OAuth2RefreshTokenGrantTests.java

@@ -17,9 +17,12 @@ package org.springframework.security.oauth2.server.authorization.config.annotati
 
 import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
 import java.security.Principal;
+import java.security.PublicKey;
 import java.time.Instant;
 import java.util.Base64;
+import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
@@ -279,6 +282,105 @@ public class OAuth2RefreshTokenGrantTests {
 			.andExpect(jsonPath("$.scope").isNotEmpty());
 	}
 
+	@Test
+	public void requestWhenRefreshTokenRequestWithPublicClientAndDPoPProofThenReturnDPoPBoundAccessToken()
+			throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithPublicClientAuthentication.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
+			.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+			.build();
+		this.registeredClientRepository.save(registeredClient);
+
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.DPOP,
+				"dpop-bound-access-token", Instant.now(), Instant.now().plusSeconds(300));
+		Map<String, Object> accessTokenClaims = new HashMap<>();
+		PublicKey publicKey = TestJwks.DEFAULT_EC_JWK.toPublicKey();
+		Map<String, Object> cnfClaim = new HashMap<>();
+		cnfClaim.put("jkt", computeSHA256(publicKey));
+		accessTokenClaims.put("cnf", cnfClaim);
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+			.authorization(registeredClient, accessToken, accessTokenClaims)
+			.build();
+		this.authorizationService.save(authorization);
+
+		String tokenEndpointUri = "http://localhost" + DEFAULT_TOKEN_ENDPOINT_URI;
+		String dPoPProof = generateDPoPProof(tokenEndpointUri);
+
+		this.mvc
+			.perform(post(DEFAULT_TOKEN_ENDPOINT_URI).params(getRefreshTokenRequestParameters(authorization))
+				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
+				.header(OAuth2AccessToken.TokenType.DPOP.getValue(), dPoPProof))
+			.andExpect(status().isOk())
+			.andExpect(jsonPath("$.token_type").value(OAuth2AccessToken.TokenType.DPOP.getValue()));
+
+		authorization = this.authorizationService.findById(authorization.getId());
+		assertThat(authorization.getAccessToken().getClaims()).containsKey("cnf");
+		@SuppressWarnings("unchecked")
+		Map<String, Object> cnfClaims = (Map<String, Object>) authorization.getAccessToken().getClaims().get("cnf");
+		assertThat(cnfClaims).containsKey("jkt");
+	}
+
+	@Test
+	public void requestWhenRefreshTokenRequestWithPublicClientAndDPoPProofAndAccessTokenNotBoundThenBadRequest()
+			throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithPublicClientAuthentication.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
+			.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+			.build();
+		this.registeredClientRepository.save(registeredClient);
+
+		OAuth2Authorization authorization = TestOAuth2Authorizations.authorization(registeredClient).build();
+		this.authorizationService.save(authorization);
+
+		String tokenEndpointUri = "http://localhost" + DEFAULT_TOKEN_ENDPOINT_URI;
+		String dPoPProof = generateDPoPProof(tokenEndpointUri);
+
+		this.mvc
+			.perform(post(DEFAULT_TOKEN_ENDPOINT_URI).params(getRefreshTokenRequestParameters(authorization))
+				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
+				.header(OAuth2AccessToken.TokenType.DPOP.getValue(), dPoPProof))
+			.andExpect(status().isBadRequest())
+			.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_DPOP_PROOF))
+			.andExpect(jsonPath("$.error_description").value("jkt claim is missing."));
+	}
+
+	@Test
+	public void requestWhenRefreshTokenRequestWithPublicClientAndDPoPProofAndDifferentPublicKeyThenBadRequest()
+			throws Exception {
+		this.spring.register(AuthorizationServerConfigurationWithPublicClientAuthentication.class).autowire();
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient()
+			.authorizationGrantType(AuthorizationGrantType.REFRESH_TOKEN)
+			.build();
+		this.registeredClientRepository.save(registeredClient);
+
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.DPOP,
+				"dpop-bound-access-token", Instant.now(), Instant.now().plusSeconds(300));
+		Map<String, Object> accessTokenClaims = new HashMap<>();
+		// Bind access token to different public key
+		PublicKey publicKey = TestJwks.DEFAULT_RSA_JWK.toPublicKey();
+		Map<String, Object> cnfClaim = new HashMap<>();
+		cnfClaim.put("jkt", computeSHA256(publicKey));
+		accessTokenClaims.put("cnf", cnfClaim);
+		OAuth2Authorization authorization = TestOAuth2Authorizations
+			.authorization(registeredClient, accessToken, accessTokenClaims)
+			.build();
+		this.authorizationService.save(authorization);
+
+		String tokenEndpointUri = "http://localhost" + DEFAULT_TOKEN_ENDPOINT_URI;
+		String dPoPProof = generateDPoPProof(tokenEndpointUri);
+
+		this.mvc
+			.perform(post(DEFAULT_TOKEN_ENDPOINT_URI).params(getRefreshTokenRequestParameters(authorization))
+				.param(OAuth2ParameterNames.CLIENT_ID, registeredClient.getClientId())
+				.header(OAuth2AccessToken.TokenType.DPOP.getValue(), dPoPProof))
+			.andExpect(status().isBadRequest())
+			.andExpect(jsonPath("$.error").value(OAuth2ErrorCodes.INVALID_DPOP_PROOF))
+			.andExpect(jsonPath("$.error_description").value("jwk header is invalid."));
+	}
+
 	@Test
 	public void requestWhenRefreshTokenRequestWithDPoPProofThenReturnDPoPBoundAccessToken() throws Exception {
 		this.spring.register(AuthorizationServerConfiguration.class).autowire();
@@ -327,6 +429,12 @@ public class OAuth2RefreshTokenGrantTests {
 		return jwt.getTokenValue();
 	}
 
+	private static String computeSHA256(PublicKey publicKey) throws Exception {
+		MessageDigest md = MessageDigest.getInstance("SHA-256");
+		byte[] digest = md.digest(publicKey.getEncoded());
+		return Base64.getUrlEncoder().withoutPadding().encodeToString(digest);
+	}
+
 	private static MultiValueMap<String, String> getRefreshTokenRequestParameters(OAuth2Authorization authorization) {
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
 		parameters.set(OAuth2ParameterNames.GRANT_TYPE, AuthorizationGrantType.REFRESH_TOKEN.getValue());