Pārlūkot izejas kodu

Jwt client authentication converter detects new key

Closes gh-9814
Joe Grandja 4 gadi atpakaļ
vecāks
revīzija
eb6ed283e0

+ 29 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverter.java

@@ -80,7 +80,7 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
 
 	private final Function<ClientRegistration, JWK> jwkResolver;
 
-	private final Map<String, NimbusJwsEncoder> jwsEncoders = new ConcurrentHashMap<>();
+	private final Map<String, JwsEncoderHolder> jwsEncoders = new ConcurrentHashMap<>();
 
 	/**
 	 * Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
@@ -140,12 +140,16 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
 		JoseHeader joseHeader = headersBuilder.build();
 		JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
 
-		NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(),
-				(clientRegistrationId) -> {
+		JwsEncoderHolder jwsEncoderHolder = this.jwsEncoders.compute(clientRegistration.getRegistrationId(),
+				(clientRegistrationId, currentJwsEncoderHolder) -> {
+					if (currentJwsEncoderHolder != null && currentJwsEncoderHolder.getJwk().equals(jwk)) {
+						return currentJwsEncoderHolder;
+					}
 					JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
-					return new NimbusJwsEncoder(jwkSource);
+					return new JwsEncoderHolder(new NimbusJwsEncoder(jwkSource), jwk);
 				});
 
+		NimbusJwsEncoder jwsEncoder = jwsEncoderHolder.getJwsEncoder();
 		Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
 
 		MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
@@ -180,4 +184,25 @@ public final class NimbusJwtClientAuthenticationParametersConverter<T extends Ab
 		return jwsAlgorithm;
 	}
 
+	private static final class JwsEncoderHolder {
+
+		private final NimbusJwsEncoder jwsEncoder;
+
+		private final JWK jwk;
+
+		private JwsEncoderHolder(NimbusJwsEncoder jwsEncoder, JWK jwk) {
+			this.jwsEncoder = jwsEncoder;
+			this.jwk = jwk;
+		}
+
+		private NimbusJwsEncoder getJwsEncoder() {
+			return this.jwsEncoder;
+		}
+
+		private JWK getJwk() {
+			return this.jwk;
+		}
+
+	}
+
 }

+ 56 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/NimbusJwtClientAuthenticationParametersConverterTests.java

@@ -16,7 +16,12 @@
 
 package org.springframework.security.oauth2.client.endpoint;
 
+import java.security.KeyPair;
+import java.security.KeyPairGenerator;
+import java.security.interfaces.RSAPrivateKey;
+import java.security.interfaces.RSAPublicKey;
 import java.util.Collections;
+import java.util.UUID;
 import java.util.function.Function;
 
 import com.nimbusds.jose.jwk.JWK;
@@ -42,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -172,4 +178,54 @@ public class NimbusJwtClientAuthenticationParametersConverterTests {
 		assertThat(jws.getExpiresAt()).isNotNull();
 	}
 
+	// gh-9814
+	@Test
+	public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
+		// @formatter:off
+		ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
+				.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
+				.build();
+		// @formatter:on
+
+		RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK;
+		given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1);
+
+		OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
+				clientRegistration);
+		MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
+
+		String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
+		NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build();
+		jwtDecoder.decode(encodedJws);
+
+		RSAKey rsaJwk2 = generateRsaJwk();
+		given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2);
+
+		parameters = this.converter.convert(clientCredentialsGrantRequest);
+
+		encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
+		jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build();
+		jwtDecoder.decode(encodedJws);
+	}
+
+	private static RSAKey generateRsaJwk() {
+		KeyPair keyPair;
+		try {
+			KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
+			keyPairGenerator.initialize(2048);
+			keyPair = keyPairGenerator.generateKeyPair();
+		}
+		catch (Exception ex) {
+			throw new IllegalStateException(ex);
+		}
+		RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
+		RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
+		// @formatter:off
+		return new RSAKey.Builder(publicKey)
+				.privateKey(privateKey)
+				.keyID(UUID.randomUUID().toString())
+				.build();
+		// @formatter:on
+	}
+
 }