浏览代码

Add WebClientReactiveClientCredentialsTokenResponseClient setWebClient

Added the ability to specify a custom WebClient in
WebClientReactiveClientCredentialsTokenResponseClient.
Also added testing to ensure the custom WebClient is not null and is
used.

Fixes: gh-6051
jer051 6 年之前
父节点
当前提交
fdc81822ec

+ 6 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClient.java

@@ -21,6 +21,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.function.BodyInserters;
@@ -112,4 +113,9 @@ public class WebClientReactiveClientCredentialsTokenResponseClient implements Re
 		}
 		return body;
 	}
+
+	public void setWebClient(WebClient webClient) {
+		Assert.notNull(webClient, "webClient cannot be null");
+		this.webClient = webClient;
+	}
 }

+ 29 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java

@@ -28,9 +28,11 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.web.reactive.function.client.WebClient;
 import org.springframework.web.reactive.function.client.WebClientResponseException;
 
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.*;
 
 /**
  * @author Rob Winch
@@ -55,6 +57,7 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
 
 	@After
 	public void cleanup() throws Exception {
+		validateMockitoUsage();
 		this.server.shutdown();
 	}
 
@@ -117,6 +120,31 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
 		assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes());
 	}
 
+	@Test(expected=IllegalArgumentException.class)
+	public void setWebClientNullThenIllegalArgumentException(){
+		client.setWebClient(null);
+	}
+
+	@Test
+	public void setWebClientCustomThenCustomClientIsUsed() {
+		WebClient customClient = mock(WebClient.class);
+		when(customClient.post()).thenReturn(WebClient.builder().build().post());
+
+		this.client.setWebClient(customClient);
+		ClientRegistration registration = this.clientRegistration.build();
+		enqueueJson("{\n"
+				+ "  \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+				+ "  \"token_type\":\"bearer\",\n"
+				+ "  \"expires_in\":3600,\n"
+				+ "  \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+				+ "}");
+		OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
+
+		OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
+
+		verify(customClient, atLeastOnce()).post();
+	}
+
 	@Test(expected = WebClientResponseException.class)
 	// gh-6089
 	public void getTokenResponseWhenInvalidResponse() throws WebClientResponseException {