Procházet zdrojové kódy

Pick up SecurityContextHolderStrategy for WebClient integration

Issue gh-11061
Josh Cummings před 3 roky
rodič
revize
d24a89ad53

+ 33 - 12
config/src/main/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -36,10 +36,13 @@ import reactor.util.context.Context;
 
 import org.springframework.beans.factory.DisposableBean;
 import org.springframework.beans.factory.InitializingBean;
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.util.Assert;
 import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
@@ -61,24 +64,37 @@ import org.springframework.web.context.request.ServletRequestAttributes;
 @Configuration(proxyBeanMethods = false)
 class SecurityReactorContextConfiguration {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	@Bean
 	SecurityReactorContextSubscriberRegistrar securityReactorContextSubscriberRegistrar() {
-		return new SecurityReactorContextSubscriberRegistrar();
+		SecurityReactorContextSubscriberRegistrar registrar = new SecurityReactorContextSubscriberRegistrar();
+		registrar.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
+		return registrar;
+	}
+
+	@Autowired(required = false)
+	void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
 	}
 
 	static class SecurityReactorContextSubscriberRegistrar implements InitializingBean, DisposableBean {
 
 		private static final String SECURITY_REACTOR_CONTEXT_OPERATOR_KEY = "org.springframework.security.SECURITY_REACTOR_CONTEXT_OPERATOR";
 
-		private static final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
+		private final Map<Object, Supplier<Object>> CONTEXT_ATTRIBUTE_VALUE_LOADERS = new HashMap<>();
 
-		static {
-			CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
+		private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+				.getContextHolderStrategy();
+
+		SecurityReactorContextSubscriberRegistrar() {
+			this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletRequest.class,
 					SecurityReactorContextSubscriberRegistrar::getRequest);
-			CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
+			this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(HttpServletResponse.class,
 					SecurityReactorContextSubscriberRegistrar::getResponse);
-			CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class,
-					SecurityReactorContextSubscriberRegistrar::getAuthentication);
+			this.CONTEXT_ATTRIBUTE_VALUE_LOADERS.put(Authentication.class, this::getAuthentication);
 		}
 
 		@Override
@@ -93,6 +109,11 @@ class SecurityReactorContextConfiguration {
 			Hooks.resetOnLastOperator(SECURITY_REACTOR_CONTEXT_OPERATOR_KEY);
 		}
 
+		void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+			Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+			this.securityContextHolderStrategy = securityContextHolderStrategy;
+		}
+
 		<T> CoreSubscriber<T> createSubscriberIfNecessary(CoreSubscriber<T> delegate) {
 			if (delegate.currentContext().hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES)) {
 				// Already enriched. No need to create Subscriber so return original
@@ -101,8 +122,8 @@ class SecurityReactorContextConfiguration {
 			return new SecurityReactorContextSubscriber<>(delegate, getContextAttributes());
 		}
 
-		private static Map<Object, Object> getContextAttributes() {
-			return new LoadingMap<>(CONTEXT_ATTRIBUTE_VALUE_LOADERS);
+		private Map<Object, Object> getContextAttributes() {
+			return new LoadingMap<>(this.CONTEXT_ATTRIBUTE_VALUE_LOADERS);
 		}
 
 		private static HttpServletRequest getRequest() {
@@ -123,8 +144,8 @@ class SecurityReactorContextConfiguration {
 			return null;
 		}
 
-		private static Authentication getAuthentication() {
-			return SecurityContextHolder.getContext().getAuthentication();
+		private Authentication getAuthentication() {
+			return this.securityContextHolderStrategy.getContext().getAuthentication();
 		}
 
 	}

+ 20 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationResourceServerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -28,9 +28,11 @@ import org.junit.jupiter.api.extension.ExtendWith;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.server.resource.authentication.BearerTokenAuthentication;
 import org.springframework.security.oauth2.server.resource.authentication.TestBearerTokenAuthentications;
 import org.springframework.security.oauth2.server.resource.web.reactive.function.client.ServletBearerExchangeFilterFunction;
@@ -40,6 +42,8 @@ import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.reactive.function.client.WebClient;
 
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.verify;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
@@ -85,6 +89,21 @@ public class SecurityReactorContextConfigurationResourceServerTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void requestWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		BearerTokenAuthentication authentication = TestBearerTokenAuthentications.bearer();
+		this.spring.register(BearerFilterConfig.class, WebServerConfig.class, Controller.class,
+				SecurityContextChangedListenerConfig.class).autowire();
+		MockHttpServletRequestBuilder authenticatedRequest = get("/token").with(authentication(authentication));
+		// @formatter:off
+		this.mockMvc.perform(authenticatedRequest)
+				.andExpect(status().isOk())
+				.andExpect(content().string("Bearer token"));
+		// @formatter:on
+		SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
+		verify(strategy, atLeastOnce()).getContext();
+	}
+
 	@EnableWebSecurity
 	static class BearerFilterConfig extends WebSecurityConfigurerAdapter {
 

+ 37 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,12 +38,14 @@ import org.springframework.http.HttpStatus;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.SecurityReactorContextConfiguration.SecurityReactorContextSubscriber;
 import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
 import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
@@ -54,6 +56,8 @@ import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.entry;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
 
 /**
  * Tests for {@link SecurityReactorContextConfiguration}.
@@ -232,6 +236,38 @@ public class SecurityReactorContextConfigurationTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void createPublisherWhenCustomSecurityContextHolderStrategyThenUses() {
+		this.spring.register(SecurityConfig.class, SecurityContextChangedListenerConfig.class).autowire();
+		SecurityContextHolderStrategy strategy = this.spring.getContext().getBean(SecurityContextHolderStrategy.class);
+		strategy.getContext().setAuthentication(this.authentication);
+		ClientResponse clientResponseOk = ClientResponse.create(HttpStatus.OK).build();
+		// @formatter:off
+		ExchangeFilterFunction filter = (req, next) -> Mono.deferContextual(Mono::just)
+				.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
+				.map((ctx) -> ctx.get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
+				.cast(Map.class)
+				.map((attributes) -> clientResponseOk);
+		// @formatter:on
+		ClientRequest clientRequest = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")).build();
+		MockExchangeFunction exchange = new MockExchangeFunction();
+		Map<Object, Object> expectedContextAttributes = new HashMap<>();
+		expectedContextAttributes.put(HttpServletRequest.class, null);
+		expectedContextAttributes.put(HttpServletResponse.class, null);
+		expectedContextAttributes.put(Authentication.class, this.authentication);
+		Mono<ClientResponse> clientResponseMono = filter.filter(clientRequest, exchange)
+				.flatMap((response) -> filter.filter(clientRequest, exchange));
+		// @formatter:off
+		StepVerifier.create(clientResponseMono)
+				.expectAccessibleContext()
+				.contains(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES, expectedContextAttributes)
+				.then()
+				.expectNext(clientResponseOk)
+				.verifyComplete();
+		// @formatter:on
+		verify(strategy, times(2)).getContext();
+	}
+
 	@EnableWebSecurity
 	static class SecurityConfig extends WebSecurityConfigurerAdapter {