Pārlūkot izejas kodu

Author: Shraiysh Vaishay cs17btech11050@iith.ac.in

Add WebClientReactiveAuthorizationCodeTokenResponseClient.setWebClient

Fixes gh-6182
shraiysh 6 gadi atpakaļ
vecāks
revīzija
e25bea2cf7

+ 9 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClient.java

@@ -23,6 +23,7 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.web.reactive.function.BodyInserters;
 import org.springframework.web.reactive.function.client.WebClient;
+import org.springframework.util.Assert;
 import reactor.core.publisher.Mono;
 
 import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
@@ -48,11 +49,18 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClient implements Re
 	private WebClient webClient = WebClient.builder()
 			.build();
 
+	/**
+	 * @param webClient the webClient to set
+	 */
+	public void setWebClient(WebClient webClient) {
+		Assert.notNull(webClient, "webClient cannot be null");
+		this.webClient = webClient;
+	}
+
 	@Override
 	public Mono<OAuth2AccessTokenResponse> getTokenResponse(OAuth2AuthorizationCodeGrantRequest authorizationGrantRequest) {
 		return Mono.defer(() -> {
 			ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
-
 			OAuth2AuthorizationExchange authorizationExchange = authorizationGrantRequest.getAuthorizationExchange();
 			String tokenUri = clientRegistration.getProviderDetails().getTokenUri();
 			BodyInserters.FormInserter<String> body = body(authorizationExchange);

+ 29 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java

@@ -32,11 +32,13 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
+import org.springframework.web.reactive.function.client.WebClient;
 
 import java.time.Instant;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.Mockito.*;
 
 /**
  * @author Rob Winch
@@ -259,4 +261,31 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
 			.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
 			.setBody(json);
 	}
+
+	@Test(expected=IllegalArgumentException.class)
+	public void setWebClientNullThenIllegalArgumentException(){
+		tokenResponseClient.setWebClient(null);
+	}
+
+	@Test
+	public void setCustomWebClientThenCustomWebClientIsUsed() {
+		WebClient customClient = mock(WebClient.class);
+		when(customClient.post()).thenReturn(WebClient.builder().build().post());
+
+		tokenResponseClient.setWebClient(customClient);
+
+		String accessTokenSuccessResponse = "{\n" +
+				"	\"access_token\": \"access-token-1234\",\n" +
+				"   \"token_type\": \"bearer\",\n" +
+				"   \"expires_in\": \"3600\",\n" +
+				"   \"scope\": \"openid profile\"\n" +
+				"}\n";
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		this.clientRegistration.scope("openid", "profile", "email", "address");
+
+		OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block();
+
+		verify(customClient, atLeastOnce()).post();
+	}
 }