浏览代码

ServerOAuth2AuthorizedClientExchangeFilterFunction default ServerWebExchange

Leverage ServerWebExchange established by ServerWebExchangeReactorContextWebFilter

Issue: gh-4921
Rob Winch 7 年之前
父节点
当前提交
23726abb1e

+ 18 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -16,6 +16,7 @@
 
 
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
 
+import com.sun.security.ntlm.Server;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
@@ -211,14 +212,25 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 
 
 	@Override
 	@Override
 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
-		ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
 		return authorizedClient(request)
 		return authorizedClient(request)
-				.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, exchange))
+				.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
 				.map(authorizedClient -> bearer(request, authorizedClient))
 				.map(authorizedClient -> bearer(request, authorizedClient))
 				.flatMap(next::exchange)
 				.flatMap(next::exchange)
 				.switchIfEmpty(next.exchange(request));
 				.switchIfEmpty(next.exchange(request));
 	}
 	}
 
 
+	private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
+		ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
+		return Mono.justOrEmpty(exchange)
+				.switchIfEmpty(serverWebExchange());
+	}
+
+	private Mono<ServerWebExchange> serverWebExchange() {
+		return Mono.subscriberContext()
+			.filter(c -> c.hasKey(ServerWebExchange.class))
+			.map(c -> c.get(ServerWebExchange.class));
+	}
+
 	private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
 	private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
 		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
 		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
 				.map(OAuth2AuthorizedClient.class::cast);
 				.map(OAuth2AuthorizedClient.class::cast);
@@ -231,10 +243,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 			return Mono.empty();
 			return Mono.empty();
 		}
 		}
 
 
-		ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
 		return currentAuthentication()
 		return currentAuthentication()
 			.flatMap(principal -> clientRegistrationId(request, principal)
 			.flatMap(principal -> clientRegistrationId(request, principal)
-					.flatMap(clientRegistrationId -> loadAuthorizedClient(clientRegistrationId, exchange, principal))
+					.flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
 			);
 			);
 	}
 	}
 
 
@@ -289,9 +300,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 			});
 			});
 	}
 	}
 
 
-	private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
+	private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
 		if (shouldRefresh(authorizedClient)) {
 		if (shouldRefresh(authorizedClient)) {
-			return refreshAuthorizedClient(next, authorizedClient, exchange);
+			return serverWebExchange(request)
+				.flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
 		}
 		}
 		return Mono.just(authorizedClient);
 		return Mono.just(authorizedClient);
 	}
 	}

+ 29 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -49,7 +49,9 @@ import org.springframework.security.oauth2.core.user.DefaultOAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.security.oauth2.core.user.OAuth2User;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
 import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
 
 
 import java.net.URI;
 import java.net.URI;
 import java.time.Duration;
 import java.time.Duration;
@@ -83,6 +85,9 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	@Mock
 	@Mock
 	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 	private ReactiveClientRegistrationRepository clientRegistrationRepository;
 
 
+	@Mock
+	private ServerWebExchange serverWebExchange;
+
 	private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
 	private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
 
 
 	private MockExchangeFunction exchange = new MockExchangeFunction();
 	private MockExchangeFunction exchange = new MockExchangeFunction();
@@ -352,6 +357,30 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
 		verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
 	}
 	}
 
 
+	@Test
+	public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
+
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
+				"principalName", this.accessToken, refreshToken);
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.just(authorizedClient));
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(this.registration));
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(clientRegistrationId(this.registration.getRegistrationId()))
+				.build();
+
+		this.function.filter(request, this.exchange)
+				.subscriberContext(serverWebExchange())
+				.block();
+
+		verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
+	}
+
+	private Context serverWebExchange() {
+		return Context.of(ServerWebExchange.class, this.serverWebExchange);
+	}
+
 	private static String getBody(ClientRequest request) {
 	private static String getBody(ClientRequest request) {
 		final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
 		final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
 		messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));
 		messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));