Просмотр исходного кода

Use SecurityContextHolderStrategy for Async Requests

Issue gh-11060
Issue gh-11061
Josh Cummings 3 лет назад
Родитель
Сommit
a218d3e140

+ 13 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfiguration.java

@@ -33,6 +33,8 @@ import org.springframework.security.config.annotation.authentication.configurati
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.annotation.web.configurers.DefaultLoginPageConfigurer;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
 
 import static org.springframework.security.config.Customizer.withDefaults;
@@ -58,6 +60,9 @@ class HttpSecurityConfiguration {
 
 	private ApplicationContext context;
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	@Autowired
 	void setObjectPostProcessor(ObjectPostProcessor<Object> objectPostProcessor) {
 		this.objectPostProcessor = objectPostProcessor;
@@ -77,6 +82,11 @@ class HttpSecurityConfiguration {
 		this.context = context;
 	}
 
+	@Autowired(required = false)
+	void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	@Bean(HTTPSECURITY_BEAN_NAME)
 	@Scope("prototype")
 	HttpSecurity httpSecurity() throws Exception {
@@ -86,10 +96,12 @@ class HttpSecurityConfiguration {
 				this.objectPostProcessor, passwordEncoder);
 		authenticationBuilder.parentAuthenticationManager(authenticationManager());
 		HttpSecurity http = new HttpSecurity(this.objectPostProcessor, authenticationBuilder, createSharedObjects());
+		WebAsyncManagerIntegrationFilter webAsyncManagerIntegrationFilter = new WebAsyncManagerIntegrationFilter();
+		webAsyncManagerIntegrationFilter.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
 		// @formatter:off
 		http
 			.csrf(withDefaults())
-			.addFilter(new WebAsyncManagerIntegrationFilter())
+			.addFilter(webAsyncManagerIntegrationFilter)
 			.exceptionHandling(withDefaults())
 			.headers(withDefaults())
 			.sessionManagement(withDefaults())

+ 1 - 0
config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java

@@ -587,6 +587,7 @@ class HttpConfigurationBuilder {
 		boolean asyncSupported = ClassUtils.hasMethod(ServletRequest.class, "startAsync");
 		if (asyncSupported) {
 			this.webAsyncManagerFilter = new RootBeanDefinition(WebAsyncManagerIntegrationFilter.class);
+			this.webAsyncManagerFilter.getPropertyValues().add("securityContextHolderStrategy", this.holderStrategyRef);
 		}
 	}
 

+ 23 - 3
config/src/test/java/org/springframework/security/config/annotation/web/configuration/HttpSecurityConfigurationTests.java

@@ -35,11 +35,13 @@ import org.springframework.core.io.support.SpringFactoriesLoader;
 import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
 import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
-import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetailsService;
@@ -54,6 +56,8 @@ import org.springframework.web.bind.annotation.RestController;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.verify;
 import static org.springframework.security.config.Customizer.withDefaults;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.authentication;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
@@ -134,6 +138,22 @@ public class HttpSecurityConfigurationTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void asyncDispatchWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		this.spring.register(DefaultWithFilterChainConfig.class, SecurityContextChangedListenerConfig.class,
+				NameController.class).autowire();
+		// @formatter:off
+		MockHttpServletRequestBuilder requestWithBob = get("/name").with(user("Bob"));
+		MvcResult mvcResult = this.mockMvc.perform(requestWithBob)
+				.andExpect(request().asyncStarted())
+				.andReturn();
+		this.mockMvc.perform(asyncDispatch(mvcResult))
+				.andExpect(status().isOk())
+				.andExpect(content().string("Bob"));
+		// @formatter:on
+		verify(this.spring.getContext().getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext();
+	}
+
 	@Test
 	public void getWhenDefaultFilterChainBeanThenAnonymousPermitted() throws Exception {
 		this.spring.register(AuthorizeRequestsConfig.class, UserDetailsConfig.class, BaseController.class).autowire();
@@ -243,8 +263,8 @@ public class HttpSecurityConfigurationTests {
 	static class NameController {
 
 		@GetMapping("/name")
-		Callable<String> name() {
-			return () -> SecurityContextHolder.getContext().getAuthentication().getName();
+		Callable<String> name(Authentication authentication) {
+			return () -> authentication.getName();
 		}
 
 	}

+ 24 - 0
config/src/test/java/org/springframework/security/config/http/MiscHttpConfigTests.java

@@ -27,6 +27,7 @@ import java.util.HashSet;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.Callable;
 import java.util.stream.Collectors;
 
 import javax.security.auth.Subject;
@@ -127,12 +128,15 @@ import static org.mockito.Mockito.verify;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.csrf;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.httpBasic;
+import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.user;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.x509;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.asyncDispatch;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.delete;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.request;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
 
 /**
@@ -762,6 +766,21 @@ public class MiscHttpConfigTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void asyncDispatchWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		this.spring.configLocations(xml("WithSecurityContextHolderStrategy")).autowire();
+		// @formatter:off
+		MockHttpServletRequestBuilder requestWithBob = get("/name").with(user("Bob"));
+		MvcResult mvcResult = this.mvc.perform(requestWithBob)
+				.andExpect(request().asyncStarted())
+				.andReturn();
+		this.mvc.perform(asyncDispatch(mvcResult))
+				.andExpect(status().isOk())
+				.andExpect(content().string("Bob"));
+		// @formatter:on
+		verify(this.spring.getContext().getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext();
+	}
+
 	/**
 	 * SEC-1893
 	 */
@@ -905,6 +924,11 @@ public class MiscHttpConfigTests {
 			return authentication.getDetails().getClass().getName();
 		}
 
+		@GetMapping("/name")
+		Callable<String> name(Authentication authentication) {
+			return () -> authentication.getName();
+		}
+
 	}
 
 	@RestController

+ 51 - 0
config/src/test/resources/org/springframework/security/config/http/MiscHttpConfigTests-WithSecurityContextHolderStrategy.xml

@@ -0,0 +1,51 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Copyright 2002-2018 the original author or authors.
+  ~
+  ~ Licensed under the Apache License, Version 2.0 (the "License");
+  ~ you may not use this file except in compliance with the License.
+  ~ You may obtain a copy of the License at
+  ~
+  ~       https://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<b:beans xmlns:b="http://www.springframework.org/schema/beans"
+		xmlns:mvc="http://www.springframework.org/schema/mvc"
+		xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+		xmlns="http://www.springframework.org/schema/security"
+		xsi:schemaLocation="
+			http://www.springframework.org/schema/security
+			https://www.springframework.org/schema/security/spring-security.xsd
+			http://www.springframework.org/schema/beans
+			https://www.springframework.org/schema/beans/spring-beans.xsd
+			http://www.springframework.org/schema/mvc
+			https://www.springframework.org/schema/mvc/spring-mvc.xsd">
+
+	<http auto-config="true" security-context-holder-strategy-ref="ref">
+		<intercept-url pattern="/**" access="authenticated"/>
+	</http>
+
+	<b:bean id="ref" class="org.mockito.Mockito" factory-method="spy">
+		<b:constructor-arg>
+			<b:bean class="org.springframework.security.config.MockSecurityContextHolderStrategy"/>
+		</b:constructor-arg>
+	</b:bean>
+
+	<mvc:annotation-driven>
+		<mvc:argument-resolvers>
+			<b:bean class="org.springframework.security.web.method.annotation.AuthenticationPrincipalArgumentResolver">
+				<b:property name="securityContextHolderStrategy" ref="ref"/>
+			</b:bean>
+		</mvc:argument-resolvers>
+	</mvc:annotation-driven>
+
+	<b:bean class="org.springframework.security.config.http.MiscHttpConfigTests.AuthenticationController"/>
+
+	<b:import resource="userservice.xml"/>
+</b:beans>

+ 19 - 4
web/src/main/java/org/springframework/security/web/context/request/async/SecurityContextCallableProcessingInterceptor.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import java.util.concurrent.Callable;
 
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.util.Assert;
 import org.springframework.web.context.request.NativeWebRequest;
 import org.springframework.web.context.request.async.CallableProcessingInterceptor;
@@ -43,6 +44,9 @@ public final class SecurityContextCallableProcessingInterceptor implements Calla
 
 	private volatile SecurityContext securityContext;
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	/**
 	 * Create a new {@link SecurityContextCallableProcessingInterceptor} that uses the
 	 * {@link SecurityContext} from the {@link SecurityContextHolder} at the time
@@ -67,18 +71,29 @@ public final class SecurityContextCallableProcessingInterceptor implements Calla
 	@Override
 	public <T> void beforeConcurrentHandling(NativeWebRequest request, Callable<T> task) {
 		if (this.securityContext == null) {
-			setSecurityContext(SecurityContextHolder.getContext());
+			setSecurityContext(this.securityContextHolderStrategy.getContext());
 		}
 	}
 
 	@Override
 	public <T> void preProcess(NativeWebRequest request, Callable<T> task) {
-		SecurityContextHolder.setContext(this.securityContext);
+		this.securityContextHolderStrategy.setContext(this.securityContext);
 	}
 
 	@Override
 	public <T> void postProcess(NativeWebRequest request, Callable<T> task, Object concurrentResult) {
-		SecurityContextHolder.clearContext();
+		this.securityContextHolderStrategy.clearContext();
+	}
+
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
 	}
 
 	private void setSecurityContext(SecurityContext securityContext) {

+ 21 - 3
web/src/main/java/org/springframework/security/web/context/request/async/WebAsyncManagerIntegrationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,9 @@ import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
 
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.util.Assert;
 import org.springframework.web.context.request.async.WebAsyncManager;
 import org.springframework.web.context.request.async.WebAsyncUtils;
 import org.springframework.web.filter.OncePerRequestFilter;
@@ -42,6 +45,9 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter
 
 	private static final Object CALLABLE_INTERCEPTOR_KEY = new Object();
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
@@ -49,10 +55,22 @@ public final class WebAsyncManagerIntegrationFilter extends OncePerRequestFilter
 		SecurityContextCallableProcessingInterceptor securityProcessingInterceptor = (SecurityContextCallableProcessingInterceptor) asyncManager
 				.getCallableInterceptor(CALLABLE_INTERCEPTOR_KEY);
 		if (securityProcessingInterceptor == null) {
-			asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY,
-					new SecurityContextCallableProcessingInterceptor());
+			SecurityContextCallableProcessingInterceptor interceptor = new SecurityContextCallableProcessingInterceptor();
+			interceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
+			asyncManager.registerCallableInterceptor(CALLABLE_INTERCEPTOR_KEY, interceptor);
 		}
 		filterChain.doFilter(request, response);
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 }