2
0
Эх сурвалжийг харах

Support overriding WebClient in ReactiveOidcIdTokenDecoderFactory

Closes gh-14178
Armin Krezović 1 жил өмнө
parent
commit
0041c658de

+ 20 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactory.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,6 +49,7 @@ import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder;
 import org.springframework.security.oauth2.jwt.ReactiveJwtDecoderFactory;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
+import org.springframework.web.reactive.function.client.WebClient;
 
 /**
  * A {@link ReactiveJwtDecoderFactory factory} that provides a {@link ReactiveJwtDecoder}
@@ -89,6 +90,8 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 	private Function<ClientRegistration, Converter<Map<String, Object>, Map<String, Object>>> claimTypeConverterFactory = (
 			clientRegistration) -> DEFAULT_CLAIM_TYPE_CONVERTER;
 
+	private Function<ClientRegistration, WebClient> webClientFactory = (clientRegistration) -> WebClient.create();
+
 	/**
 	 * Returns the default {@link Converter}'s used for type conversion of claim values
 	 * for an {@link OidcIdToken}.
@@ -165,6 +168,7 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 			}
 			return NimbusReactiveJwtDecoder.withJwkSetUri(jwkSetUri)
 				.jwsAlgorithm((SignatureAlgorithm) jwsAlgorithm)
+				.webClient(webClientFactory.apply(clientRegistration))
 				.build();
 		}
 		if (jwsAlgorithm != null && MacAlgorithm.class.isAssignableFrom(jwsAlgorithm.getClass())) {
@@ -241,4 +245,19 @@ public final class ReactiveOidcIdTokenDecoderFactory implements ReactiveJwtDecod
 		this.claimTypeConverterFactory = claimTypeConverterFactory;
 	}
 
+	/**
+	 * Sets the factory that provides a {@link WebClient} used by
+	 * {@link NimbusReactiveJwtDecoder} to coordinate with the authorization servers
+	 * indicated in the <a href="https://tools.ietf.org/html/rfc7517#section-5">JWK
+	 * Set</a> uri.
+	 * @param webClientFactory the factory that provides a {@link WebClient} used by
+	 * {@link NimbusReactiveJwtDecoder}
+	 *
+	 * @since 6.3
+	 */
+	public void setWebClientFactory(Function<ClientRegistration, WebClient> webClientFactory) {
+		Assert.notNull(webClientFactory, "webClientFactory cannot be null");
+		this.webClientFactory = webClientFactory;
+	}
+
 }

+ 19 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/oidc/authentication/ReactiveOidcIdTokenDecoderFactoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,6 +34,7 @@ import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
 import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
 import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
 import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.web.reactive.function.client.WebClient;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -94,6 +95,12 @@ public class ReactiveOidcIdTokenDecoderFactoryTests {
 			.isThrownBy(() -> this.idTokenDecoderFactory.setClaimTypeConverterFactory(null));
 	}
 
+	@Test
+	public void setWebClientFactoryWhenNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.idTokenDecoderFactory.setWebClientFactory(null));
+	}
+
 	@Test
 	public void createDecoderWhenClientRegistrationNullThenThrowIllegalArgumentException() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.idTokenDecoderFactory.createDecoder(null));
@@ -176,4 +183,15 @@ public class ReactiveOidcIdTokenDecoderFactoryTests {
 		verify(customClaimTypeConverterFactory).apply(same(clientRegistration));
 	}
 
+	@Test
+	public void createDecoderWhenCustomWebClientFactorySetThenApplied() {
+		Function<ClientRegistration, WebClient> customWebClientFactory = mock(
+				Function.class);
+		this.idTokenDecoderFactory.setWebClientFactory(customWebClientFactory);
+		ClientRegistration clientRegistration = this.registration.build();
+		given(customWebClientFactory.apply(same(clientRegistration)))
+				.willReturn(WebClient.create());
+		this.idTokenDecoderFactory.createDecoder(clientRegistration);
+		verify(customWebClientFactory).apply(same(clientRegistration));
+	}
 }