فهرست منبع

Add setBodyExtractor

Closes gh-10260
bishoy basily 3 سال پیش
والد
کامیت
860690491a

+ 19 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/endpoint/AbstractWebClientReactiveOAuth2AccessTokenResponseClient.java

@@ -27,6 +27,7 @@ import reactor.core.publisher.Mono;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
+import org.springframework.http.ReactiveHttpInputMessage;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -35,6 +36,7 @@ import org.springframework.security.oauth2.core.web.reactive.function.OAuth2Body
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
+import org.springframework.web.reactive.function.BodyExtractor;
 import org.springframework.web.reactive.function.BodyInserters;
 import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.WebClient;
@@ -68,6 +70,9 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
 
 	private Converter<T, HttpHeaders> headersConverter = this::populateTokenRequestHeaders;
 
+	private BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor = OAuth2BodyExtractors
+			.oauth2AccessTokenResponse();
+
 	AbstractWebClientReactiveOAuth2AccessTokenResponseClient() {
 	}
 
@@ -204,7 +209,7 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
 	 * @return the token response from the response body.
 	 */
 	private Mono<OAuth2AccessTokenResponse> readTokenResponse(T grantRequest, ClientResponse response) {
-		return response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse())
+		return response.body(this.bodyExtractor)
 				.map((tokenResponse) -> populateTokenResponse(grantRequest, tokenResponse));
 	}
 
@@ -291,4 +296,17 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
 		};
 	}
 
+	/**
+	 * Sets the {@link BodyExtractor} that will be used to decode the
+	 * {@link OAuth2AccessTokenResponse}
+	 * @param bodyExtractor the {@link BodyExtractor} that will be used to decode the
+	 * {@link OAuth2AccessTokenResponse}
+	 * @since 5.6
+	 */
+	public final void setBodyExtractor(
+			BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> bodyExtractor) {
+		Assert.notNull(bodyExtractor, "bodyExtractor cannot be null");
+		this.bodyExtractor = bodyExtractor;
+	}
+
 }

+ 26 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveAuthorizationCodeTokenResponseClientTests.java

@@ -27,11 +27,13 @@ import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
+import org.springframework.http.ReactiveHttpInputMessage;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -41,11 +43,14 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExch
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
+import org.springframework.web.reactive.function.BodyExtractor;
 import org.springframework.web.reactive.function.client.WebClient;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
@@ -409,4 +414,25 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
 				.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
 	}
 
+	// gh-10260
+	@Test
+	public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
+
+		String accessTokenSuccessResponse = "{}";
+
+		WebClientReactiveAuthorizationCodeTokenResponseClient customClient = new WebClientReactiveAuthorizationCodeTokenResponseClient();
+
+		BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
+		OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(extractor.extract(any(), any())).willReturn(Mono.just(response));
+
+		customClient.setBodyExtractor(extractor);
+
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+		OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(authorizationCodeGrantRequest())
+				.block();
+		assertThat(accessTokenResponse.getAccessToken()).isNotNull();
+
+	}
+
 }

+ 27 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveClientCredentialsTokenResponseClientTests.java

@@ -27,21 +27,26 @@ import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
+import org.springframework.http.ReactiveHttpInputMessage;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
+import org.springframework.web.reactive.function.BodyExtractor;
 import org.springframework.web.reactive.function.client.WebClient;
 import org.springframework.web.reactive.function.client.WebClientResponseException;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
@@ -280,4 +285,26 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
 				.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
 	}
 
+	// gh-10260
+	@Test
+	public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
+
+		enqueueJson("{}");
+
+		WebClientReactiveClientCredentialsTokenResponseClient customClient = new WebClientReactiveClientCredentialsTokenResponseClient();
+
+		BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
+		OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(extractor.extract(any(), any())).willReturn(Mono.just(response));
+
+		customClient.setBodyExtractor(extractor);
+
+		OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
+				this.clientRegistration.build());
+
+		OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(request).block();
+		assertThat(accessTokenResponse.getAccessToken()).isNotNull();
+
+	}
+
 }

+ 30 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactivePasswordTokenResponseClientTests.java

@@ -25,21 +25,26 @@ import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
+import org.springframework.http.ReactiveHttpInputMessage;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
+import org.springframework.web.reactive.function.BodyExtractor;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -286,4 +291,29 @@ public class WebClientReactivePasswordTokenResponseClientTests {
 				.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
 	}
 
+	// gh-10260
+	@Test
+	public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
+
+		String accessTokenSuccessResponse = "{}";
+
+		WebClientReactivePasswordTokenResponseClient customClient = new WebClientReactivePasswordTokenResponseClient();
+
+		BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
+		OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(extractor.extract(any(), any())).willReturn(Mono.just(response));
+
+		customClient.setBodyExtractor(extractor);
+
+		ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
+		OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
+				this.username, this.password);
+
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+
+		OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(passwordGrantRequest).block();
+		assertThat(accessTokenResponse.getAccessToken()).isNotNull();
+
+	}
+
 }

+ 28 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/endpoint/WebClientReactiveRefreshTokenTokenResponseClientTests.java

@@ -25,11 +25,13 @@ import okhttp3.mockwebserver.RecordedRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
+import org.springframework.http.ReactiveHttpInputMessage;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@@ -39,10 +41,13 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
+import org.springframework.web.reactive.function.BodyExtractor;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -289,4 +294,27 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
 				.isEqualTo("Basic Y2xpZW50LWlkOmNsaWVudC1zZWNyZXQ=");
 	}
 
+	// gh-10260
+	@Test
+	public void getTokenResponseWhenSuccessCustomResponseThenReturnAccessTokenResponse() {
+
+		String accessTokenSuccessResponse = "{}";
+
+		WebClientReactiveRefreshTokenTokenResponseClient customClient = new WebClientReactiveRefreshTokenTokenResponseClient();
+
+		BodyExtractor<Mono<OAuth2AccessTokenResponse>, ReactiveHttpInputMessage> extractor = mock(BodyExtractor.class);
+		OAuth2AccessTokenResponse response = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(extractor.extract(any(), any())).willReturn(Mono.just(response));
+
+		customClient.setBodyExtractor(extractor);
+
+		OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
+				this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
+
+		this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
+		OAuth2AccessTokenResponse accessTokenResponse = customClient.getTokenResponse(refreshTokenGrantRequest).block();
+		assertThat(accessTokenResponse.getAccessToken()).isNotNull();
+
+	}
+
 }