瀏覽代碼

Align Servlet ExchangeFilterFunction CoreSubscriber

Fixes gh-7422
Joe Grandja 6 年之前
父節點
當前提交
2a5bd6e719

+ 23 - 11
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ImportSelector.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2019 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,23 +15,28 @@
  */
 package org.springframework.security.config.annotation.web.configuration;
 
-import java.util.ArrayList;
-import java.util.List;
-
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
 import org.springframework.util.ClassUtils;
 
+import java.util.ArrayList;
+import java.util.List;
+
 /**
- * Used by {@link EnableWebSecurity} to conditionally import {@link OAuth2ClientConfiguration}
- * when the {@code spring-security-oauth2-client} module is present on the classpath and
- * {@link OAuth2ResourceServerConfiguration} when the {@code spring-security-oauth2-resource-server}
- * module is on the classpath.
+ * Used by {@link EnableWebSecurity} to conditionally import:
+ *
+ * <ul>
+ * 	<li>{@link OAuth2ClientConfiguration} when the {@code spring-security-oauth2-client} module is present on the classpath</li>
+ * 	<li>{@link SecurityReactorContextConfiguration} when the {@code spring-webflux} and {@code spring-security-oauth2-client} module is present on the classpath</li>
+ * 	<li>{@link OAuth2ResourceServerConfiguration} when the {@code spring-security-oauth2-resource-server} module is present on the classpath</li>
+ * </ul>
  *
  * @author Joe Grandja
  * @author Josh Cummings
  * @since 5.1
  * @see OAuth2ClientConfiguration
+ * @see SecurityReactorContextConfiguration
+ * @see OAuth2ResourceServerConfiguration
  */
 final class OAuth2ImportSelector implements ImportSelector {
 
@@ -39,13 +44,20 @@ final class OAuth2ImportSelector implements ImportSelector {
 	public String[] selectImports(AnnotationMetadata importingClassMetadata) {
 		List<String> imports = new ArrayList<>();
 
-		if (ClassUtils.isPresent(
-			"org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader())) {
+		boolean oauth2ClientPresent = ClassUtils.isPresent(
+				"org.springframework.security.oauth2.client.registration.ClientRegistration", getClass().getClassLoader());
+		if (oauth2ClientPresent) {
 			imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ClientConfiguration");
 		}
 
+		boolean webfluxPresent = ClassUtils.isPresent(
+				"org.springframework.web.reactive.function.client.ExchangeFilterFunction", getClass().getClassLoader());
+		if (webfluxPresent && oauth2ClientPresent) {
+			imports.add("org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration");
+		}
+
 		if (ClassUtils.isPresent(
-			"org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) {
+				"org.springframework.security.oauth2.server.resource.BearerTokenError", getClass().getClassLoader())) {
 			imports.add("org.springframework.security.config.annotation.web.configuration.OAuth2ResourceServerConfiguration");
 		}
 

+ 165 - 0
config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java

@@ -0,0 +1,165 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.annotation.web.configuration;
+
+import org.reactivestreams.Publisher;
+import org.reactivestreams.Subscription;
+import org.springframework.beans.factory.DisposableBean;
+import org.springframework.beans.factory.InitializingBean;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.util.CollectionUtils;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import reactor.core.CoreSubscriber;
+import reactor.core.publisher.Hooks;
+import reactor.core.publisher.Operators;
+import reactor.util.context.Context;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+
+import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
+
+/**
+ * {@link Configuration} that (potentially) adds a "decorating" {@code Publisher}
+ * for the last operator created in every {@code Mono} or {@code Flux}.
+ *
+ * <p>
+ * The {@code Publisher} is solely responsible for adding
+ * the current {@code HttpServletRequest}, {@code HttpServletResponse} and {@code Authentication}
+ * to the Reactor {@code Context} so that it's accessible in every flow, if required.
+ *
+ * @author Joe Grandja
+ * @since 5.2
+ * @see OAuth2ImportSelector
+ */
+@Configuration(proxyBeanMethods = false)
+class SecurityReactorContextConfiguration {
+
+	@Bean
+	SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() {
+		return new SecurityReactorContextSubscriberRegistrar();
+	}
+
+	static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean {
+		private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
+
+		@Override
+		public void afterPropertiesSet() throws Exception {
+			Function<? super Publisher<Object>, ? extends Publisher<Object>> lifter =
+					Operators.liftPublisher((pub, sub) -> createSubscriberIfNecessary(sub));
+
+			Hooks.onLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY, pub -> {
+				if (CollectionUtils.isEmpty(getContextAttributes())) {
+					// No need to decorate so return original Publisher
+					return pub;
+				}
+				return lifter.apply(pub);
+			});
+		}
+
+		@Override
+		public void destroy() throws Exception {
+			Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY);
+		}
+
+		<T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
+			if (delegate.currentContext().hasKey(SECURITY_CONTEXT_ATTRIBUTES)) {
+				// Already enriched. No need to create Subscriber so return original
+				return delegate;
+			}
+			return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
+		}
+
+		private static Map<Object, Object> getContextAttributes() {
+			HttpServletRequest servletRequest = null;
+			HttpServletResponse servletResponse = null;
+			ServletRequestAttributes requestAttributes =
+					(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
+			if (requestAttributes != null) {
+				servletRequest = requestAttributes.getRequest();
+				servletResponse = requestAttributes.getResponse();
+			}
+			Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+			if (authentication == null && servletRequest == null && servletResponse == null) {
+				return Collections.emptyMap();
+			}
+
+			Map<Object, Object> contextAttributes = new HashMap<>();
+			if (servletRequest != null) {
+				contextAttributes.put(HttpServletRequest.class, servletRequest);
+			}
+			if (servletResponse != null) {
+				contextAttributes.put(HttpServletResponse.class, servletResponse);
+			}
+			if (authentication != null) {
+				contextAttributes.put(Authentication.class, authentication);
+			}
+
+			return contextAttributes;
+		}
+	}
+
+	static class SecurityReactorContextSubscriber<T> implements CoreSubscriber<T> {
+		static final String SECURITY_CONTEXT_ATTRIBUTES = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES";
+		private final CoreSubscriber<T> delegate;
+		private final Context context;
+
+		SecurityReactorContextSubscriber(CoreSubscriber<T> delegate, Map<Object, Object> attributes) {
+			this.delegate = delegate;
+			Context currentContext = this.delegate.currentContext();
+			Context context;
+			if (currentContext.hasKey(SECURITY_CONTEXT_ATTRIBUTES)) {
+				context = currentContext;
+			} else {
+				context = currentContext.put(SECURITY_CONTEXT_ATTRIBUTES, attributes);
+			}
+			this.context = context;
+		}
+
+		@Override
+		public Context currentContext() {
+			return this.context;
+		}
+
+		@Override
+		public void onSubscribe(Subscription s) {
+			this.delegate.onSubscribe(s);
+		}
+
+		@Override
+		public void onNext(T t) {
+			this.delegate.onNext(t);
+		}
+
+		@Override
+		public void onError(Throwable t) {
+			this.delegate.onError(t);
+		}
+
+		@Override
+		public void onComplete() {
+			this.delegate.onComplete();
+		}
+	}
+}

+ 195 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java

@@ -0,0 +1,195 @@
+/*
+ * Copyright 2002-2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.annotation.web.configuration;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
+import org.springframework.web.context.request.RequestContextHolder;
+import org.springframework.web.context.request.ServletRequestAttributes;
+import org.springframework.web.reactive.function.client.ClientRequest;
+import org.springframework.web.reactive.function.client.ClientResponse;
+import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
+import reactor.core.CoreSubscriber;
+import reactor.core.publisher.BaseSubscriber;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.util.context.Context;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.net.URI;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.entry;
+import static org.springframework.http.HttpMethod.GET;
+import static org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES;
+
+/**
+ * Tests for {@link SecurityReactorContextConfiguration}.
+ *
+ * @author Joe Grandja
+ * @since 5.2
+ */
+public class SecurityReactorContextConfigurationTests {
+	private MockHttpServletRequest servletRequest;
+	private MockHttpServletResponse servletResponse;
+	private Authentication authentication;
+	private SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar subscriberRegistrar =
+			new SecurityReactorContextConfiguration.SecurityReactorContextSubscriberRegistrar();
+
+	@Rule
+	public final SpringTestRule spring = new SpringTestRule();
+
+	@Before
+	public void setup() {
+		this.servletRequest = new MockHttpServletRequest();
+		this.servletResponse = new MockHttpServletResponse();
+		this.authentication = new TestingAuthenticationToken("principal", "password");
+	}
+
+	@After
+	public void cleanup() {
+		SecurityContextHolder.clearContext();
+		RequestContextHolder.resetRequestAttributes();
+	}
+
+	@Test
+	public void createSubscriberIfNecessaryWhenSubscriberContextContainsSecurityContextAttributesThenReturnOriginalSubscriber() {
+		Context context = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>());
+		BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {
+			@Override
+			public Context currentContext() {
+				return context;
+			}
+		};
+		CoreSubscriber<Object> resultSubscriber = this.subscriberRegistrar.createSubscriberIfNecessary(originalSubscriber);
+		assertThat(resultSubscriber).isSameAs(originalSubscriber);
+	}
+
+	@Test
+	public void createSubscriberIfNecessaryWhenWebSecurityContextAvailableThenCreateWithParentContext() {
+		RequestContextHolder.setRequestAttributes(
+				new ServletRequestAttributes(this.servletRequest, this.servletResponse));
+		SecurityContextHolder.getContext().setAuthentication(this.authentication);
+
+		String testKey = "test_key";
+		String testValue = "test_value";
+
+		BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
+			@Override
+			public Context currentContext() {
+				return Context.of(testKey, testValue);
+			}
+		};
+		CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent);
+
+		Context resultContext = subscriber.currentContext();
+
+		assertThat(resultContext.getOrEmpty(testKey)).hasValue(testValue);
+		Map<Object, Object> securityContextAttributes = resultContext.getOrDefault(SECURITY_CONTEXT_ATTRIBUTES, null);
+		assertThat(securityContextAttributes).hasSize(3);
+		assertThat(securityContextAttributes).contains(
+				entry(HttpServletRequest.class, this.servletRequest),
+				entry(HttpServletResponse.class, this.servletResponse),
+				entry(Authentication.class, this.authentication));
+	}
+
+	@Test
+	public void createSubscriberIfNecessaryWhenParentContextContainsSecurityContextAttributesThenUseParentContext() {
+		RequestContextHolder.setRequestAttributes(
+				new ServletRequestAttributes(this.servletRequest, this.servletResponse));
+		SecurityContextHolder.getContext().setAuthentication(this.authentication);
+
+		Context parentContext = Context.of(SECURITY_CONTEXT_ATTRIBUTES, new HashMap<>());
+		BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
+			@Override
+			public Context currentContext() {
+				return parentContext;
+			}
+		};
+		CoreSubscriber<Object> subscriber = this.subscriberRegistrar.createSubscriberIfNecessary(parent);
+
+		Context resultContext = subscriber.currentContext();
+		assertThat(resultContext).isSameAs(parentContext);
+	}
+
+	@Test
+	public void createPublisherWhenLastOperatorAddedThenSecurityContextAttributesAvailable() {
+		// Trigger the importing of SecurityReactorContextConfiguration via OAuth2ImportSelector
+		this.spring.register(SecurityConfig.class).autowire();
+
+		// Setup for SecurityReactorContextSubscriberRegistrar
+		RequestContextHolder.setRequestAttributes(
+				new ServletRequestAttributes(this.servletRequest, this.servletResponse));
+		SecurityContextHolder.getContext().setAuthentication(this.authentication);
+
+		ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build();
+
+		ExchangeFilterFunction filter = (req, next) ->
+				Mono.subscriberContext()
+						.filter(ctx -> ctx.hasKey(SECURITY_CONTEXT_ATTRIBUTES))
+						.map(ctx -> ctx.get(SECURITY_CONTEXT_ATTRIBUTES))
+						.cast(Map.class)
+						.map(attributes -> {
+							if (attributes.containsKey(HttpServletRequest.class) &&
+									attributes.containsKey(HttpServletResponse.class) &&
+									attributes.containsKey(Authentication.class)) {
+								return clientResponseOk;
+							} else {
+								return ClientResponse.create(HttpStatus.NOT_FOUND).build();
+							}
+						});
+
+		ClientRequest clientRequest = ClientRequest.create(GET, URI.create("https://example.com")).build();
+		MockExchangeFunction exchange = new MockExchangeFunction();
+
+		Map<Object, Object> expectedContextAttributes = new HashMap<>();
+		expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
+		expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
+		expectedContextAttributes.put(Authentication.class, this.authentication);
+
+		Mono<ClientResponse> clientResponseMono = filter.filter(clientRequest, exchange)
+				.flatMap(response -> filter.filter(clientRequest, exchange));
+
+		StepVerifier.create(clientResponseMono)
+				.expectAccessibleContext()
+				.contains(SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes)
+				.then()
+				.expectNext(clientResponseOk)
+				.verifyComplete();
+	}
+
+	@EnableWebSecurity
+	static class SecurityConfig extends WebSecurityConfigurerAdapter {
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+		}
+	}
+}

+ 20 - 137
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -16,10 +16,6 @@
 
 package org.springframework.security.oauth2.client.web.reactive.function.client;
 
-import org.reactivestreams.Subscription;
-import org.springframework.beans.factory.DisposableBean;
-import org.springframework.beans.factory.InitializingBean;
-import org.springframework.lang.Nullable;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
@@ -47,10 +43,7 @@ import org.springframework.web.reactive.function.client.ClientResponse;
 import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 import org.springframework.web.reactive.function.client.ExchangeFunction;
 import org.springframework.web.reactive.function.client.WebClient;
-import reactor.core.CoreSubscriber;
-import reactor.core.publisher.Hooks;
 import reactor.core.publisher.Mono;
-import reactor.core.publisher.Operators;
 import reactor.core.scheduler.Schedulers;
 import reactor.util.context.Context;
 
@@ -100,8 +93,10 @@ import java.util.function.Consumer;
  * @since 5.1
  * @see OAuth2AuthorizedClientManager
  */
-public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
-		implements ExchangeFilterFunction, InitializingBean, DisposableBean {
+public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
+
+	// Same key as in SecurityReactorContextConfiguration.SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES
+	static final String SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY = "org.springframework.security.SECURITY_CONTEXT_ATTRIBUTES";
 
 	/**
 	 * The request attribute name used to locate the {@link OAuth2AuthorizedClient}.
@@ -112,8 +107,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 	private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
 	private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();
 
-	private static final String REQUEST_CONTEXT_OPERATOR_KEY = RequestContextSubscriber.class.getName();
-
 	private static final Authentication ANONYMOUS_AUTHENTICATION = new AnonymousAuthenticationToken(
 			"anonymous", "anonymousUser", AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"));
 
@@ -175,16 +168,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 		return authorizedClientManager;
 	}
 
-	@Override
-	public void afterPropertiesSet() {
-		Hooks.onLastOperator(REQUEST_CONTEXT_OPERATOR_KEY, Operators.liftPublisher((s, sub) -> createRequestContextSubscriberIfNecessary(sub)));
-	}
-
-	@Override
-	public void destroy() {
-		Hooks.resetOnLastOperator(REQUEST_CONTEXT_OPERATOR_KEY);
-	}
-
 	/**
 	 * Sets the {@link OAuth2AccessTokenResponseClient} used for getting an {@link OAuth2AuthorizedClient} for the client_credentials grant.
 	 *
@@ -382,22 +365,22 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 	}
 
 	private void populateRequestAttributes(Map<String, Object> attrs, Context ctx) {
-		RequestContextDataHolder holder = RequestContextSubscriber.getRequestContext(ctx);
-		if (holder != null) {
-			HttpServletRequest request = holder.getRequest();
-			if (request != null) {
-				attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
-			}
-
-			HttpServletResponse response = holder.getResponse();
-			if (response != null) {
-				attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
-			}
-
-			Authentication authentication = holder.getAuthentication();
-			if (authentication != null) {
-				attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
-			}
+		// NOTE: SecurityReactorContextConfiguration.SecurityReactorContextSubscriber adds this key
+		if (!ctx.hasKey(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY)) {
+			return;
+		}
+		Map<Object, Object> contextAttributes = ctx.get(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY);
+		HttpServletRequest servletRequest = (HttpServletRequest) contextAttributes.get(HttpServletRequest.class);
+		if (servletRequest != null) {
+			attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, servletRequest);
+		}
+		HttpServletResponse servletResponse = (HttpServletResponse) contextAttributes.get(HttpServletResponse.class);
+		if (servletResponse != null) {
+			attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, servletResponse);
+		}
+		Authentication authentication = (Authentication) contextAttributes.get(Authentication.class);
+		if (authentication != null) {
+			attrs.putIfAbsent(AUTHENTICATION_ATTR_NAME, authentication);
 		}
 	}
 
@@ -503,23 +486,6 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 					.build();
 	}
 
-	<T> CoreSubscriber<T> createRequestContextSubscriberIfNecessary(CoreSubscriber<T> delegate) {
-		HttpServletRequest request = null;
-		HttpServletResponse response = null;
-		ServletRequestAttributes requestAttributes =
-				(ServletRequestAttributes) RequestContextHolder.getRequestAttributes();
-		if (requestAttributes != null) {
-			request = requestAttributes.getRequest();
-			response = requestAttributes.getResponse();
-		}
-		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
-		if (authentication == null && request == null && response == null) {
-			//do not need to create RequestContextSubscriber with empty data
-			return delegate;
-		}
-		return new RequestContextSubscriber<>(delegate, request, response, authentication);
-	}
-
 	static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
 		return (OAuth2AuthorizedClient) attrs.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
 	}
@@ -587,87 +553,4 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction
 			return new UnsupportedOperationException("Not Supported");
 		}
 	}
-
-	static class RequestContextSubscriber<T> implements CoreSubscriber<T> {
-		static final String REQUEST_CONTEXT_DATA_HOLDER =
-				RequestContextSubscriber.class.getName().concat(".REQUEST_CONTEXT_DATA_HOLDER");
-		private final CoreSubscriber<T> delegate;
-		private final Context context;
-
-		RequestContextSubscriber(CoreSubscriber<T> delegate,
-								HttpServletRequest request,
-								HttpServletResponse response,
-								Authentication authentication) {
-			this.delegate = delegate;
-
-			Context parentContext = this.delegate.currentContext();
-			Context context;
-			if (parentContext.hasKey(REQUEST_CONTEXT_DATA_HOLDER)) {
-				context = parentContext;
-			} else {
-				context = parentContext.put(REQUEST_CONTEXT_DATA_HOLDER, new RequestContextDataHolder(request, response, authentication));
-			}
-
-			this.context = context;
-		}
-
-		@Nullable
-		private static RequestContextDataHolder getRequestContext(Context ctx) {
-			return ctx.getOrDefault(REQUEST_CONTEXT_DATA_HOLDER, null);
-		}
-
-		@Override
-		public Context currentContext() {
-			return this.context;
-		}
-
-		@Override
-		public void onSubscribe(Subscription s) {
-			this.delegate.onSubscribe(s);
-		}
-
-		@Override
-		public void onNext(T t) {
-			this.delegate.onNext(t);
-		}
-
-		@Override
-		public void onError(Throwable t) {
-			this.delegate.onError(t);
-		}
-
-		@Override
-		public void onComplete() {
-			this.delegate.onComplete();
-		}
-	}
-
-	static class RequestContextDataHolder {
-		private final HttpServletRequest request;
-		private final HttpServletResponse response;
-		private final Authentication authentication;
-
-		RequestContextDataHolder(@Nullable HttpServletRequest request,
-								@Nullable HttpServletResponse response,
-								@Nullable Authentication authentication) {
-			this.request = request;
-			this.response = response;
-			this.authentication = authentication;
-		}
-
-		@Nullable
-		private HttpServletRequest getRequest() {
-			return this.request;
-		}
-
-		@Nullable
-		private HttpServletResponse getResponse() {
-			return this.response;
-		}
-
-		@Nullable
-		private Authentication getAuthentication() {
-			return this.authentication;
-		}
-	}
 }

+ 13 - 2
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionITests.java

@@ -43,16 +43,20 @@ import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
 import org.springframework.web.reactive.function.client.WebClient;
 import reactor.blockhound.BlockHound;
+import reactor.util.context.Context;
 
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.HashSet;
+import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.Mockito.*;
+import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY;
 import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
 
 /**
@@ -104,7 +108,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
 		});
 		this.authorizedClientFilter = new ServletOAuth2AuthorizedClientExchangeFilterFunction(
 				this.clientRegistrationRepository, this.authorizedClientRepository);
-		this.authorizedClientFilter.afterPropertiesSet();
 		this.server = new MockWebServer();
 		this.server.start();
 		this.serverUrl = this.server.url("/").toString();
@@ -120,7 +123,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
 
 	@After
 	public void cleanup() throws Exception {
-		this.authorizedClientFilter.destroy();
 		this.server.shutdown();
 		SecurityContextHolder.clearContext();
 		RequestContextHolder.resetRequestAttributes();
@@ -248,6 +250,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
 						.attributes(clientRegistrationId(clientRegistration2.getRegistrationId()))
 						.retrieve()
 						.bodyToMono(String.class))
+				.subscriberContext(context())
 				.block();
 
 		assertThat(this.server.getRequestCount()).isEqualTo(4);
@@ -259,6 +262,14 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
 		assertThat(authorizedClientCaptor.getAllValues().get(1).getClientRegistration()).isSameAs(clientRegistration2);
 	}
 
+	private Context context() {
+		Map<Object, Object> contextAttributes = new HashMap<>();
+		contextAttributes.put(HttpServletRequest.class, this.request);
+		contextAttributes.put(HttpServletResponse.class, this.response);
+		contextAttributes.put(Authentication.class, this.authentication);
+		return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes);
+	}
+
 	private MockResponse jsonResponse(String json) {
 		return new MockResponse()
 				.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)

+ 10 - 151
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -76,12 +76,10 @@ import org.springframework.web.context.request.ServletRequestAttributes;
 import org.springframework.web.reactive.function.BodyInserter;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.WebClient;
-import reactor.core.CoreSubscriber;
-import reactor.core.publisher.BaseSubscriber;
-import reactor.core.publisher.Mono;
 import reactor.util.context.Context;
 
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 import java.net.URI;
 import java.time.Duration;
 import java.time.Instant;
@@ -93,7 +91,6 @@ import java.util.Optional;
 import java.util.function.Consumer;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 import static org.mockito.Mockito.*;
 import static org.springframework.http.HttpMethod.GET;
@@ -163,7 +160,6 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	public void cleanup() throws Exception {
 		SecurityContextHolder.clearContext();
 		RequestContextHolder.resetRequestAttributes();
-		this.function.destroy();
 	}
 
 	@Test
@@ -591,18 +587,15 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 	// gh-6483
 	@Test
 	public void filterWhenChainedThenDefaultsStillAvailable() throws Exception {
-		this.function.afterPropertiesSet();			// Hooks.onLastOperator() initialized
 		this.function.setDefaultOAuth2AuthorizedClient(true);
 
 		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
 		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
-		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
 
 		OAuth2User user = mock(OAuth2User.class);
 		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
 		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
 				user, authorities, this.registration.getRegistrationId());
-		SecurityContextHolder.getContext().setAuthentication(authentication);
 
 		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
 				this.registration, "principalName", this.accessToken);
@@ -619,12 +612,13 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		// Default request attributes NOT set
 		final ClientRequest request2 = ClientRequest.create(GET, URI.create("https://example2.com")).build();
 
+		Context context = context(servletRequest, servletResponse, authentication);
+
 		this.function.filter(request1, this.exchange)
 				.flatMap(response -> this.function.filter(request2, this.exchange))
+				.subscriberContext(context)
 				.block();
 
-		this.function.destroy();		// Hooks.onLastOperator() released
-
 		List<ClientRequest> requests = this.exchange.getRequests();
 		assertThat(requests).hasSize(2);
 
@@ -641,147 +635,12 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		assertThat(getBody(request)).isEmpty();
 	}
 
-	@Test
-	public void filterWhenRequestAttributesNotSetAndHooksNotInitThenDefaultsNotAvailable() {
-//		this.function.afterPropertiesSet();		// Hooks.onLastOperator() NOT initialized
-		this.function.setDefaultOAuth2AuthorizedClient(true);
-
-		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
-		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
-		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
-
-		OAuth2User user = mock(OAuth2User.class);
-		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
-		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
-				user, authorities, this.registration.getRegistrationId());
-		SecurityContextHolder.getContext().setAuthentication(authentication);
-
-		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
-
-		this.function.filter(request, this.exchange).block();
-
-		List<ClientRequest> requests = this.exchange.getRequests();
-		assertThat(requests).hasSize(1);
-
-		request = requests.get(0);
-		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
-		assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
-		assertThat(request.method()).isEqualTo(HttpMethod.GET);
-		assertThat(getBody(request)).isEmpty();
-	}
-
-	@Test
-	public void filterWhenRequestAttributesNotSetAndHooksInitHooksResetThenDefaultsNotAvailable() throws Exception {
-		this.function.afterPropertiesSet();			// Hooks.onLastOperator() initialized
-		this.function.destroy();					// Hooks.onLastOperator() released
-		this.function.setDefaultOAuth2AuthorizedClient(true);
-
-		MockHttpServletRequest servletRequest = new MockHttpServletRequest();
-		MockHttpServletResponse servletResponse = new MockHttpServletResponse();
-		RequestContextHolder.setRequestAttributes(new ServletRequestAttributes(servletRequest, servletResponse));
-
-		OAuth2User user = mock(OAuth2User.class);
-		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_USER");
-		OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(
-				user, authorities, this.registration.getRegistrationId());
-		SecurityContextHolder.getContext().setAuthentication(authentication);
-
-		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")).build();
-
-		this.function.filter(request, this.exchange).block();
-
-		List<ClientRequest> requests = this.exchange.getRequests();
-		assertThat(requests).hasSize(1);
-
-		request = requests.get(0);
-		assertThat(request.headers().getFirst(HttpHeaders.AUTHORIZATION)).isNull();
-		assertThat(request.url().toASCIIString()).isEqualTo("https://example.com");
-		assertThat(request.method()).isEqualTo(HttpMethod.GET);
-		assertThat(getBody(request)).isEmpty();
-	}
-
-	// gh-7228
-	@Test
-	public void afterPropertiesSetWhenHooksInitAndOutsideWebSecurityContextThenShouldNotThrowException() throws Exception {
-		this.function.afterPropertiesSet();			// Hooks.onLastOperator() initialized
-		assertThatCode(() -> Mono.subscriberContext().block())
-				.as("RequestContext Hook brakes application outside of web/security context")
-				.doesNotThrowAnyException();
-	}
-
-	@Test
-	public void createRequestContextSubscriberIfNecessaryWhenOutsideWebSecurityContextThenReturnOriginalSubscriber() throws Exception {
-		BaseSubscriber<Object> originalSubscriber = new BaseSubscriber<Object>() {};
-		CoreSubscriber<Object> resultSubscriber = this.function.createRequestContextSubscriberIfNecessary(originalSubscriber);
-		assertThat(resultSubscriber).isSameAs(originalSubscriber);
-	}
-
-	// gh-7228
-	@Test
-	public void createRequestContextSubscriberWhenRequestResponseProvidedThenCreateWithParentContext() throws Exception {
-		testRequestContextSubscriber(new MockHttpServletRequest(), new MockHttpServletResponse(), null);
-	}
-
-	// gh-7228
-	@Test
-	public void createRequestContextSubscriberWhenAuthenticationProvidedThenCreateWithParentContext() throws Exception {
-		testRequestContextSubscriber(null, null, this.authentication);
-	}
-
-	@Test
-	public void createRequestContextSubscriberWhenParentContextHasDataHolderThenShouldReuseParentContext() throws Exception {
-		RequestContextDataHolder testValue = new RequestContextDataHolder(null, null, null);
-		final Context parentContext = Context.of(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, testValue);
-		BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
-			@Override
-			public Context currentContext() {
-				return parentContext;
-			}
-		};
-
-		RequestContextSubscriber<Object> requestContextSubscriber =
-				new RequestContextSubscriber<>(parent, null, null, authentication);
-
-		Context resultContext = requestContextSubscriber.currentContext();
-
-		assertThat(resultContext)
-				.describedAs("parent context was replaced")
-				.isSameAs(parentContext);
-	}
-
-	private void testRequestContextSubscriber(MockHttpServletRequest servletRequest,
-											MockHttpServletResponse servletResponse,
-											Authentication authentication) {
-		String testKey = "test_key";
-		String testValue = "test_value";
-
-		BaseSubscriber<Object> parent = new BaseSubscriber<Object>() {
-			@Override
-			public Context currentContext() {
-				return Context.of(testKey, testValue);
-			}
-		};
-
-		RequestContextSubscriber<Object> requestContextSubscriber =
-				new RequestContextSubscriber<>(parent, servletRequest, servletResponse, authentication);
-
-		Context resultContext = requestContextSubscriber.currentContext();
-
-		assertThat(resultContext)
-				.describedAs("result context is null")
-				.isNotNull();
-
-		assertThat(resultContext.getOrEmpty(testKey))
-				.describedAs("context is replaced")
-				.hasValue(testValue);
-
-		Object dataHolder = resultContext.getOrDefault(RequestContextSubscriber.REQUEST_CONTEXT_DATA_HOLDER, null);
-		assertThat(dataHolder)
-				.describedAs("context is not populated with REQUEST_CONTEXT_DATA_HOLDER")
-				.isNotNull()
-				.hasFieldOrPropertyWithValue("request", servletRequest)
-				.hasFieldOrPropertyWithValue("response", servletResponse)
-				.hasFieldOrPropertyWithValue("authentication", authentication);
+	private Context context(HttpServletRequest servletRequest, HttpServletResponse servletResponse, Authentication authentication) {
+		Map<Object, Object> contextAttributes = new HashMap<>();
+		contextAttributes.put(HttpServletRequest.class, servletRequest);
+		contextAttributes.put(HttpServletResponse.class, servletResponse);
+		contextAttributes.put(Authentication.class, authentication);
+		return Context.of(SECURITY_REACTOR_CONTEXT_ATTRIBUTES_KEY, contextAttributes);
 	}
 
 	private static String getBody(ClientRequest request) {