Quellcode durchsuchen

ServletOAuth2AuthorizedClientExchangeFilterFunction supports chaining

Fixes gh-6483
Joe Grandja vor 6 Jahren
Ursprung
Commit
0c27f64338

+ 113 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -16,6 +16,9 @@
 
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
+import org.reactivestreams.Subscription;
+import org.springframework.beans.factory.DisposableBean;
+import org.springframework.beans.factory.InitializingBean;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
@@ -44,8 +47,12 @@ import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 import org.springframework.web.reactive.function.client.ExchangeFunction;
 import org.springframework.web.reactive.function.client.WebClient;
+import reactor.core.CoreSubscriber;
+import reactor.core.publisher.Hooks;
 import reactor.core.publisher.Mono;
+import reactor.core.publisher.Operators;
 import reactor.core.scheduler.Schedulers;
+import reactor.util.context.Context;
 
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
@@ -98,7 +105,9 @@ import static org.springframework.security.oauth2.core.web.reactive.function.OAu
  * @author Rob Winch
  * @since 5.1
  */
-public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
+public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
+		implements ExchangeFilterFunction, InitializingBean, DisposableBean {
+
 	/**
 	 * The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
 	 */
@@ -108,6 +117,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 	private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
 	private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
 
+	private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
+
 	private Clock clock = Clock.systemUTC();
 
 	private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
@@ -123,7 +134,8 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 
 	private String defaultClientRegistrationId;
 
-	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {}
+	public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
+	}
 
 	public ServletOAuth2AuthorizedClientExchangeFilterFunction(
 			ClientRegistrationRepository clientRegistrationRepository,
@@ -132,6 +144,16 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 		this.authorizedClientRepository = authorizedClientRepository;
 	}
 
+	@Override
+	public void afterPropertiesSet() throws Exception {
+		Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.lift((s, sub) -> createRequestContextSubscriber(sub)));
+	}
+
+	@Override
+	public void destroy() throws Exception {
+		Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
+	}
+
 	/**
 	 * Sets the {@link OAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
 	 * client_credentials grant.
@@ -266,15 +288,36 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 
 	@Override
 	public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
-		Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
-				.map(OAuth2AuthorizedClient.class::cast);
-		return Mono.justOrEmpty(attribute)
-				.flatMap(authorizedClient -> authorizedClient(request, next, authorizedClient))
+		return Mono.just(request)
+				.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
+				.switchIfEmpty(mergeRequestAttributesFromContext(request))
+				.filter(req -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
+				.flatMap(req -> authorizedClient(req, next, getOAuth2AuthorizedClient(req.attributes())))
 				.map(authorizedClient -> bearer(request, authorizedClient))
 				.flatMap(next::exchange)
 				.switchIfEmpty(next.exchange(request));
 	}
 
+	private Mono<ClientRequest> mergeRequestAttributesFromContext(ClientRequest request) {
+		return Mono.just(ClientRequest.from(request))
+				.flatMap(builder -> Mono.subscriberContext()
+						.map(ctx -> builder.attributes(attrs -> populateRequestAttributes(attrs, ctx))))
+				.map(ClientRequest.Builder::build);
+	}
+
+	private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
+		if (ctx.hasKey(HTTP_SERVLET_REQUEST_ATTR_NAME)) {
+			attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, ctx.get(HTTP_SERVLET_REQUEST_ATTR_NAME));
+		}
+		if (ctx.hasKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
+			attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, ctx.get(HTTP_SERVLET_RESPONSE_ATTR_NAME));
+		}
+		if (ctx.hasKey(AUTHENTICATION_ATTR_NAME)) {
+			attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, ctx.get(AUTHENTICATION_ATTR_NAME));
+		}
+		populateDefaultOAuth2AuthorizedClient(attrs);
+	}
+
 	private void populateDefaultRequestResponse(Map<String, Object> attrs) {
 		if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(
 				HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
@@ -435,6 +478,19 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 					.build();
 	}
 
+	private <T> CoreSubscriber<T> createRequestContextSubscriber(CoreSubscriber<T> delegate) {
+		HttpServletRequest request = null;
+		HttpServletResponse response = null;
+		ServletRequestAttributes requestAttributes =
+				(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
+		if (requestAttributes != null) {
+			request = requestAttributes.getRequest();
+			response = requestAttributes.getResponse();
+		}
+		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+		return new RequestContextSubscriber<>(delegate, request, response, authentication);
+	}
+
 	private static BodyInserters.FormInserter<String> refreshTokenBody(String refreshToken) {
 		return BodyInserters
 				.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
@@ -508,4 +564,55 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
 			return new UnsupportedOperationException("Not Supported");
 		}
 	}
+
+	private static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
+		private static final String CONTEXT_DEFAULTED_ATTR_NAME = RequestContextSubscriber.class.getName().concat(".CONTEXT_DEFAULTED_ATTR_NAME");
+		private final CoreSubscriber<T> delegate;
+		private final HttpServletRequest request;
+		private final HttpServletResponse response;
+		private final Authentication authentication;
+
+		private RequestContextSubscriber(CoreSubscriber<T> delegate,
+											HttpServletRequest request,
+											HttpServletResponse response,
+											Authentication authentication) {
+			this.delegate = delegate;
+			this.request = request;
+			this.response = response;
+			this.authentication = authentication;
+		}
+
+		@Override
+		public Context currentContext() {
+			Context context = this.delegate.currentContext();
+			if (context.hasKey(CONTEXT_DEFAULTED_ATTR_NAME)) {
+				return context;
+			}
+			return Context.of(
+					CONTEXT_DEFAULTED_ATTR_NAME, Boolean.TRUE,
+					HTTP_SERVLET_REQUEST_ATTR_NAME, this.request,
+					HTTP_SERVLET_RESPONSE_ATTR_NAME, this.response,
+					AUTHENTICATION_ATTR_NAME, this.authentication);
+		}
+
+		@Override
+		public void onSubscribe(Subscription s) {
+			this.delegate.onSubscribe(s);
+		}
+
+		@Override
+		public void onNext(T t) {
+			this.delegate.onNext(t);
+		}
+
+		@Override
+		public void onError(Throwable t) {
+			this.delegate.onError(t);
+		}
+
+		@Override
+		public void onComplete() {
+			this.delegate.onComplete();
+		}
+	}
 }

+ 118 - 6
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -74,14 +74,11 @@ import java.util.Map;
 import java.util.Optional;
 import java.util.function.Consumer;
 
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.verifyZeroInteractions;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
 import static org.springframework.http.HttpMethod.GET;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.*;
 
@@ -647,6 +644,121 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(getBody(request0)).isEmpty();
 	}
 
+	// gh-6483
+	@Test
+	public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
+				this.clientRegistrationRepository, this.authorizedClientRepository);
+		this.function.afterPropertiesSet();			// Hooks.onLastOperator() initialized
+		this.function.setDefaultOAuth2AuthorizedClient(true);
+
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
+
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
+				user, authorities, this.registration.getRegistrationId());
+		SecurityContextHolder.getContext().setAuthentication(authentication);
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				this.registration, "principalName", this.accessToken);
+		when(this.authorizedClientRepository.loadAuthorizedClient(eq(authentication.getAuthorizedClientRegistrationId()),
+				eq(authentication), eq(servletRequest))).thenReturn(authorizedClient);
+
+		// Default request attributes set
+		final ClientRequest request1 = ClientRequest.create(GET, URI.create("https://example1.com"))
+				.attributes(attrs -> attrs.putAll(getDefaultRequestAttributes())).build();
+
+		// Default request attributes NOT set
+		final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
+
+		this.function.filter(request1, this.exchange)
+				.flatMap(response -> this.function.filter(request2, this.exchange))
+				.block();
+
+		this.function.destroy();		// Hooks.onLastOperator() released
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(2);
+
+		ClientRequest request = requests.get(0);
+		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request.url().toASCIIString()).isEqualTo("https://example1.com");
+		assertThat(request.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request)).isEmpty();
+
+		request = requests.get(1);
+		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer token-0");
+		assertThat(request.url().toASCIIString()).isEqualTo("https://example2.com");
+		assertThat(request.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() throws Exception {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
+				this.clientRegistrationRepository, this.authorizedClientRepository);
+//		this.function.afterPropertiesSet();		// Hooks.onLastOperator() NOT initialized
+		this.function.setDefaultOAuth2AuthorizedClient(true);
+
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
+
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
+				user, authorities, this.registration.getRegistrationId());
+		SecurityContextHolder.getContext().setAuthentication(authentication);
+
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
+
+		this.function.filter(request, this.exchange).block();
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		request = requests.get(0);
+		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
+		assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request)).isEmpty();
+	}
+
+	@Test
+	public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
+		this.function = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
+				this.clientRegistrationRepository, this.authorizedClientRepository);
+		this.function.afterPropertiesSet();			// Hooks.onLastOperator() initialized
+		this.function.destroy();					// Hooks.onLastOperator() released
+		this.function.setDefaultOAuth2AuthorizedClient(true);
+
+		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
+		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
+		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
+
+		OAuth2User user = mock(OAuth2User.class);
+		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
+		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
+				user, authorities, this.registration.getRegistrationId());
+		SecurityContextHolder.getContext().setAuthentication(authentication);
+
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
+
+		this.function.filter(request, this.exchange).block();
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+
+		request = requests.get(0);
+		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
+		assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request)).isEmpty();
+	}
+
 	private static String getBody(ClientRequest request) {
 		final List<HttpMessageWriter<?>> messageWriters = new ArrayList<>();
 		messageWriters.add(new EncoderHttpMessageWriter<>(new ByteBufferEncoder()));