|
@@ -16,7 +16,12 @@
|
|
|
|
|
|
package org.springframework.security.oauth2.client.endpoint;
|
|
|
|
|
|
+import java.security.KeyPair;
|
|
|
+import java.security.KeyPairGenerator;
|
|
|
+import java.security.interfaces.RSAPrivateKey;
|
|
|
+import java.security.interfaces.RSAPublicKey;
|
|
|
import java.util.Collections;
|
|
|
+import java.util.UUID;
|
|
|
import java.util.function.Function;
|
|
|
|
|
|
import com.nimbusds.jose.jwk.JWK;
|
|
@@ -42,6 +47,7 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
|
|
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.eq;
|
|
|
import static org.mockito.BDDMockito.given;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
@@ -172,4 +178,54 @@ public class NimbusJwtClientAuthenticationParametersConverterTests {
|
|
|
assertThat(jws.getExpiresAt()).isNotNull();
|
|
|
}
|
|
|
|
|
|
+ // gh-9814
|
|
|
+ @Test
|
|
|
+ public void convertWhenClientKeyChangesThenNewKeyUsed() throws Exception {
|
|
|
+ // @formatter:off
|
|
|
+ ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
|
|
|
+ .clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
|
|
|
+ .build();
|
|
|
+ // @formatter:on
|
|
|
+
|
|
|
+ RSAKey rsaJwk1 = TestJwks.DEFAULT_RSA_JWK;
|
|
|
+ given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk1);
|
|
|
+
|
|
|
+ OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
|
|
|
+ clientRegistration);
|
|
|
+ MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
|
|
|
+
|
|
|
+ String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
|
|
|
+ NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk1.toRSAPublicKey()).build();
|
|
|
+ jwtDecoder.decode(encodedJws);
|
|
|
+
|
|
|
+ RSAKey rsaJwk2 = generateRsaJwk();
|
|
|
+ given(this.jwkResolver.apply(eq(clientRegistration))).willReturn(rsaJwk2);
|
|
|
+
|
|
|
+ parameters = this.converter.convert(clientCredentialsGrantRequest);
|
|
|
+
|
|
|
+ encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
|
|
|
+ jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk2.toRSAPublicKey()).build();
|
|
|
+ jwtDecoder.decode(encodedJws);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static RSAKey generateRsaJwk() {
|
|
|
+ KeyPair keyPair;
|
|
|
+ try {
|
|
|
+ KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("RSA");
|
|
|
+ keyPairGenerator.initialize(2048);
|
|
|
+ keyPair = keyPairGenerator.generateKeyPair();
|
|
|
+ }
|
|
|
+ catch (Exception ex) {
|
|
|
+ throw new IllegalStateException(ex);
|
|
|
+ }
|
|
|
+ RSAPublicKey publicKey = (RSAPublicKey) keyPair.getPublic();
|
|
|
+ RSAPrivateKey privateKey = (RSAPrivateKey) keyPair.getPrivate();
|
|
|
+ // @formatter:off
|
|
|
+ return new RSAKey.Builder(publicKey)
|
|
|
+ .privateKey(privateKey)
|
|
|
+ .keyID(UUID.randomUUID().toString())
|
|
|
+ .build();
|
|
|
+ // @formatter:on
|
|
|
+ }
|
|
|
+
|
|
|
}
|