Преглед изворни кода

ServerOAuth2AuthorizedClientExchangeFilterFunction uses ServerOAuth2AuthorizedClientRepository

Issue: gh-4921
Rob Winch пре 7 година
родитељ
комит
5bcbb1c40f

+ 43 - 11
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -24,8 +24,8 @@ import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.util.Assert;
@@ -34,6 +34,7 @@ import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 import org.springframework.web.reactive.function.client.ExchangeFunction;
+import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
 
 import java.net.URI;
@@ -60,16 +61,22 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	 */
 	private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
 
+	/**
+	 * The request attribute name used to locate the {@link org.springframework.web.server.ServerWebExchange}.
+	 */
+	private static final String SERVER_WEB_EXCHANGE_ATTR_NAME = ServerWebExchange.class.getName();
+
 	private Clock clock = Clock.systemUTC();
 
 	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
 
-	private ReactiveOAuth2AuthorizedClientService authorizedClientService;
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 
 	public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
 
-	public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveOAuth2AuthorizedClientService authorizedClientService) {
-		this.authorizedClientService = authorizedClientService;
+	public ServerOAuth2AuthorizedClientExchangeFilterFunction(
+			ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
 	/**
@@ -78,7 +85,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	 *
 	 * <pre>
 	 * WebClient webClient = WebClient.builder()
-	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientService))
+	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
 	 *    .build();
 	 * Mono<String> response = webClient
 	 *    .get()
@@ -110,6 +117,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 		return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
 	}
 
+
+	/**
+	 * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
+	 * providing the Bearer Token. Example usage:
+	 *
+	 * <pre>
+	 * WebClient webClient = WebClient.builder()
+	 *    .filter(new OAuth2AuthorizedClientExchangeFilterFunction(authorizedClientRepository))
+	 *    .build();
+	 * Mono<String> response = webClient
+	 *    .get()
+	 *    .uri(uri)
+	 *    .attributes(serverWebExchange(serverWebExchange))
+	 *    // ...
+	 *    .retrieve()
+	 *    .bodyToMono(String.class);
+	 * </pre>
+	 * @param serverWebExchange the {@link ServerWebExchange} to use
+	 * @return the {@link Consumer} to populate the client request attributes
+	 */
+	public static Consumer<Map<String, Object>> serverWebExchange(ServerWebExchange serverWebExchange) {
+		return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
+	}
+
 	/**
 	 * An access token will be considered expired by comparing its expiration to now +
 	 * this skewed Duration. The default is 1 minute.
@@ -124,22 +155,23 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
 		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
 				.map(OAuth2AuthorizedClient.class::cast);
+		ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
 		return Mono.justOrEmpty(attribute)
-				.flatMap(authorizedClient -> authorizedClient(next, authorizedClient))
+				.flatMap(authorizedClient -> authorizedClient(next, authorizedClient, exchange))
 				.map(authorizedClient -> bearer(request, authorizedClient))
 				.flatMap(next::exchange)
 				.switchIfEmpty(next.exchange(request));
 	}
 
-	private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
+	private Mono<OAuth2AuthorizedClient> authorizedClient(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
 		if (shouldRefresh(authorizedClient)) {
-			return refreshAuthorizedClient(next, authorizedClient);
+			return refreshAuthorizedClient(next, authorizedClient, exchange);
 		}
 		return Mono.just(authorizedClient);
 	}
 
 	private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
-			OAuth2AuthorizedClient authorizedClient) {
+			OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
 		ClientRegistration clientRegistration = authorizedClient
 				.getClientRegistration();
 		String tokenUri = clientRegistration
@@ -155,12 +187,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 				.flatMap(result -> ReactiveSecurityContextHolder.getContext()
 						.map(SecurityContext::getAuthentication)
 						.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
-						.flatMap(principal -> this.authorizedClientService.saveAuthorizedClient(result, principal))
+						.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
 						.thenReturn(result));
 	}
 
 	private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
-		if (this.authorizedClientService == null) {
+		if (this.authorizedClientRepository == null) {
 			return false;
 		}
 		OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();

+ 10 - 10
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -36,9 +36,9 @@ import org.springframework.mock.http.client.reactive.MockClientHttpRequest;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -70,7 +70,7 @@ import static org.springframework.security.oauth2.client.web.reactive.function.c
 @RunWith(MockitoJUnitRunner.class)
 public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	@Mock
-	private ReactiveOAuth2AuthorizedClientService authorizedClientService;
+	private ServerOAuth2AuthorizedClientRepository auth2AuthorizedClientRepository;
 
 	private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
 
@@ -124,7 +124,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenRefreshRequiredThenRefresh() {
-		when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
+		when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
 				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 				.expiresIn(3600)
@@ -139,7 +139,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				this.accessToken.getTokenValue(),
 				issuedAt,
 				accessTokenExpiresAt);
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -153,7 +153,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
 				.block();
 
-		verify(this.authorizedClientService).saveAuthorizedClient(any(), eq(authentication));
+		verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
 
 		List<ClientRequest> requests = this.exchange.getRequests();
 		assertThat(requests).hasSize(2);
@@ -173,7 +173,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenRefreshRequiredAndEmptyReactiveSecurityContextThenSaved() {
-		when(this.authorizedClientService.saveAuthorizedClient(any(), any())).thenReturn(Mono.empty());
+		when(this.auth2AuthorizedClientRepository.saveAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
 		OAuth2AccessTokenResponse response = OAuth2AccessTokenResponse.withToken("token-1")
 				.tokenType(OAuth2AccessToken.TokenType.BEARER)
 				.expiresIn(3600)
@@ -188,7 +188,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 				this.accessToken.getTokenValue(),
 				issuedAt,
 				accessTokenExpiresAt);
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@@ -200,7 +200,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		this.function.filter(request, this.exchange)
 				.block();
 
-		verify(this.authorizedClientService).saveAuthorizedClient(any(), any());
+		verify(this.auth2AuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
 
 		List<ClientRequest> requests = this.exchange.getRequests();
 		assertThat(requests).hasSize(2);
@@ -220,7 +220,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
 				"principalName", this.accessToken);
@@ -242,7 +242,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 
 	@Test
 	public void filterWhenNotExpiredThenShouldRefreshFalse() {
-		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.authorizedClientService);
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.auth2AuthorizedClientRepository);
 
 		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,