소스 검색

Use SecurityContextHolderStrategy for AuthenticationFilter

Issue gh-11060
Josh Cummings 3 년 전
부모
커밋
f3d99f557b

+ 19 - 4
web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,6 +32,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.web.context.NullSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
@@ -67,6 +68,9 @@ import org.springframework.web.filter.OncePerRequestFilter;
  */
 public class AuthenticationFilter extends OncePerRequestFilter {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;
 
 	private AuthenticationConverter authenticationConverter;
@@ -151,6 +155,17 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 		this.securityContextRepository = securityContextRepository;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
@@ -180,15 +195,15 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 
 	private void unsuccessfulAuthentication(HttpServletRequest request, HttpServletResponse response,
 			AuthenticationException failed) throws IOException, ServletException {
-		SecurityContextHolder.clearContext();
+		this.securityContextHolderStrategy.clearContext();
 		this.failureHandler.onAuthenticationFailure(request, response, failed);
 	}
 
 	private void successfulAuthentication(HttpServletRequest request, HttpServletResponse response, FilterChain chain,
 			Authentication authentication) throws IOException, ServletException {
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
 		context.setAuthentication(authentication);
-		SecurityContextHolder.setContext(context);
+		this.securityContextHolderStrategy.setContext(context);
 		this.securityContextRepository.saveContext(context, request, response);
 		this.successHandler.onAuthenticationSuccess(request, response, chain, authentication);
 	}

+ 22 - 1
web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -40,6 +40,8 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 
@@ -128,6 +130,25 @@ public class AuthenticationFilterTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
 	}
 
+	@Test
+	public void filterWhenCustomSecurityContextHolderStrategyThenUses() throws Exception {
+		Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE");
+		given(this.authenticationConverter.convert(any())).willReturn(authentication);
+		given(this.authenticationManager.authenticate(any())).willReturn(authentication);
+		AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
+				this.authenticationConverter);
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.createEmptyContext()).willReturn(new SecurityContextImpl());
+		filter.setSecurityContextHolderStrategy(strategy);
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = mock(FilterChain.class);
+		filter.doFilter(request, response, chain);
+		verify(this.authenticationManager).authenticate(any(Authentication.class));
+		verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+		verify(strategy).setContext(any());
+	}
+
 	@Test
 	public void filterWhenAuthenticationManagerResolverDefaultsAndAuthenticationSuccessThenContinues()
 			throws Exception {