瀏覽代碼

Add SecurityContextRepository to all Authentication Filters

Closes gh-10949
Rob Winch 3 年之前
父節點
當前提交
d2f24ae5f5
共有 16 個文件被更改,包括 363 次插入0 次删除
  1. 17 0
      cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java
  2. 36 0
      cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java
  3. 17 0
      oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java
  4. 25 0
      oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java
  5. 17 0
      web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java
  6. 17 0
      web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java
  7. 17 0
      web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java
  8. 17 0
      web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java
  9. 17 0
      web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java
  10. 17 0
      web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java
  11. 35 0
      web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java
  12. 35 0
      web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java
  13. 29 0
      web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java
  14. 19 0
      web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java
  15. 24 0
      web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java
  16. 24 0
      web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java

+ 17 - 0
cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java

@@ -42,6 +42,8 @@ import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -205,6 +207,8 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil
 
 	private AuthenticationFailureHandler proxyFailureHandler = new SimpleUrlAuthenticationFailureHandler();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	public CasAuthenticationFilter() {
 		super("/login/cas");
 		setAuthenticationFailureHandler(new SimpleUrlAuthenticationFailureHandler());
@@ -223,6 +227,7 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authResult);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		if (this.eventPublisher != null) {
 			this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
 		}
@@ -274,6 +279,18 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil
 		return result;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	/**
 	 * Sets the {@link AuthenticationFailureHandler} for proxy requests.
 	 * @param proxyFailureHandler

+ 36 - 0
cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java

@@ -21,6 +21,7 @@ import javax.servlet.FilterChain;
 import org.jasig.cas.client.proxy.ProxyGrantingTicketStorage;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -32,12 +33,15 @@ import org.springframework.security.cas.ServiceProperties;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
+import org.springframework.security.web.context.SecurityContextRepository;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -182,6 +186,38 @@ public class CasAuthenticationFilterTests {
 		verify(successHandler).onAuthenticationSuccess(request, response, authentication);
 	}
 
+	@Test
+	public void testSecurityContextHolder() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		AuthenticationManager manager = mock(AuthenticationManager.class);
+		Authentication authentication = new TestingAuthenticationToken("un", "pwd", "ROLE_USER");
+		given(manager.authenticate(any(Authentication.class))).willReturn(authentication);
+		ServiceProperties serviceProperties = new ServiceProperties();
+		serviceProperties.setAuthenticateAllArtifacts(true);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter("ticket", "ST-1-123");
+		request.setServletPath("/authenticate");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = mock(FilterChain.class);
+		CasAuthenticationFilter filter = new CasAuthenticationFilter();
+		filter.setServiceProperties(serviceProperties);
+		filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class));
+		filter.setAuthenticationManager(manager);
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.afterPropertiesSet();
+		filter.doFilter(request, response, chain);
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull()
+				.withFailMessage("Authentication should not be null");
+		verify(chain).doFilter(request, response);
+		// validate for when the filterProcessUrl matches
+		filter.setFilterProcessesUrl(request.getServletPath());
+		SecurityContextHolder.clearContext();
+		filter.doFilter(request, response, chain);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response));
+		assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authentication.getName());
+	}
+
 	// SEC-1592
 	@Test
 	public void testChainNotInvokedForProxyReceptor() throws Exception {

+ 17 - 0
oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java

@@ -38,6 +38,8 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
 
@@ -75,6 +77,8 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 
 	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource = new WebAuthenticationDetailsSource();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	/**
 	 * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s)
 	 * @param authenticationManagerResolver
@@ -131,6 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 			SecurityContext context = SecurityContextHolder.createEmptyContext();
 			context.setAuthentication(authenticationResult);
 			SecurityContextHolder.setContext(context);
+			this.securityContextRepository.saveContext(context, request, response);
 			if (this.logger.isDebugEnabled()) {
 				this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authenticationResult));
 			}
@@ -143,6 +148,18 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 		}
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	/**
 	 * Set the {@link BearerTokenResolver} to use. Defaults to
 	 * {@link DefaultBearerTokenResolver}.

+ 25 - 0
oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java

@@ -36,18 +36,23 @@ import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationManagerResolver;
 import org.springframework.security.authentication.AuthenticationServiceException;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.BearerTokenError;
 import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.context.SecurityContextRepository;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
@@ -102,6 +107,26 @@ public class BearerTokenAuthenticationFilterTests {
 		assertThat(captor.getValue().getPrincipal()).isEqualTo("token");
 	}
 
+	@Test
+	public void doFilterWhenSecurityContextRepositoryThenSaves() throws ServletException, IOException {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		String token = "token";
+		given(this.bearerTokenResolver.resolve(this.request)).willReturn(token);
+		TestingAuthenticationToken expectedAuthentication = new TestingAuthenticationToken("test", "password");
+		given(this.authenticationManager.authenticate(any())).willReturn(expectedAuthentication);
+		BearerTokenAuthenticationFilter filter = addMocks(
+				new BearerTokenAuthenticationFilter(this.authenticationManager));
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.doFilter(this.request, this.response, this.filterChain);
+		ArgumentCaptor<BearerTokenAuthenticationToken> captor = ArgumentCaptor
+				.forClass(BearerTokenAuthenticationToken.class);
+		verify(this.authenticationManager).authenticate(captor.capture());
+		assertThat(captor.getValue().getPrincipal()).isEqualTo(token);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(this.response));
+		assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(expectedAuthentication.getName());
+	}
+
 	@Test
 	public void doFilterWhenUsingAuthenticationManagerResolverThenAuthenticates() throws Exception {
 		BearerTokenAuthenticationFilter filter = addMocks(

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java

@@ -42,6 +42,8 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -134,6 +136,8 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
 
 	private AuthenticationFailureHandler failureHandler = new SimpleUrlAuthenticationFailureHandler();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	/**
 	 * @param defaultFilterProcessesUrl the default value for <tt>filterProcessesUrl</tt>.
 	 */
@@ -314,6 +318,7 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authResult);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		if (this.logger.isDebugEnabled()) {
 			this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
 		}
@@ -435,6 +440,18 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
 		this.failureHandler = failureHandler;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	protected AuthenticationSuccessHandler getSuccessHandler() {
 		return this.successHandler;
 	}

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java

@@ -32,6 +32,8 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -74,6 +76,8 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 	private AuthenticationFailureHandler failureHandler = new AuthenticationEntryPointFailureHandler(
 			new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED));
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	private AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver;
 
 	public AuthenticationFilter(AuthenticationManager authenticationManager,
@@ -135,6 +139,18 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 		this.authenticationManagerResolver = authenticationManagerResolver;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
@@ -173,6 +189,7 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authentication);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		this.successHandler.onAuthenticationSuccess(request, response, chain, authentication);
 	}
 

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java

@@ -40,6 +40,8 @@ import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.GenericFilterBean;
@@ -104,6 +106,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 
 	private RequestMatcher requiresAuthenticationRequestMatcher = new PreAuthenticatedProcessingRequestMatcher();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	/**
 	 * Check whether all required properties have been set.
 	 */
@@ -210,6 +214,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authResult);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		if (this.eventPublisher != null) {
 			this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
 		}
@@ -242,6 +247,18 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 		this.eventPublisher = anApplicationEventPublisher;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	/**
 	 * @param authenticationDetailsSource The AuthenticationDetailsSource to use
 	 */

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilter.java

@@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.RememberMeServices;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.GenericFilterBean;
 
@@ -73,6 +75,8 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
 
 	private RememberMeServices rememberMeServices;
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	public RememberMeAuthenticationFilter(AuthenticationManager authenticationManager,
 			RememberMeServices rememberMeServices) {
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
@@ -114,6 +118,7 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
 				onSuccessfulAuthentication(request, response, rememberMeAuth);
 				this.logger.debug(LogMessage.of(() -> "SecurityContextHolder populated with remember-me token: '"
 						+ SecurityContextHolder.getContext().getAuthentication() + "'"));
+				this.securityContextRepository.saveContext(context, request, response);
 				if (this.eventPublisher != null) {
 					this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
 							SecurityContextHolder.getContext().getAuthentication(), this.getClass()));
@@ -179,4 +184,16 @@ public class RememberMeAuthenticationFilter extends GenericFilterBean implements
 		this.successHandler = successHandler;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 }

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java

@@ -36,6 +36,8 @@ import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.authentication.NullRememberMeServices;
 import org.springframework.security.web.authentication.RememberMeServices;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
 
@@ -103,6 +105,8 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 
 	private BasicAuthenticationConverter authenticationConverter = new BasicAuthenticationConverter();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	/**
 	 * Creates an instance which will authenticate against the supplied
 	 * {@code AuthenticationManager} and which will ignore failed authentication attempts,
@@ -131,6 +135,18 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 		this.authenticationEntryPoint = authenticationEntryPoint;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	@Override
 	public void afterPropertiesSet() {
 		Assert.notNull(this.authenticationManager, "An AuthenticationManager is required");
@@ -161,6 +177,7 @@ public class BasicAuthenticationFilter extends OncePerRequestFilter {
 					this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
 				}
 				this.rememberMeServices.loginSuccess(request, response, authResult);
+				this.securityContextRepository.saveContext(context, request, response);
 				onSuccessfulAuthentication(request, response, authResult);
 			}
 		}

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilter.java

@@ -49,6 +49,8 @@ import org.springframework.security.core.userdetails.UserDetailsService;
 import org.springframework.security.core.userdetails.UsernameNotFoundException;
 import org.springframework.security.core.userdetails.cache.NullUserCache;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.filter.GenericFilterBean;
@@ -106,6 +108,8 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 
 	private boolean createAuthenticatedToken = false;
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	@Override
 	public void afterPropertiesSet() {
 		Assert.notNull(this.userDetailsService, "A UserDetailsService is required");
@@ -192,6 +196,7 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authentication);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		chain.doFilter(request, response);
 	}
 
@@ -271,6 +276,18 @@ public class DigestAuthenticationFilter extends GenericFilterBean implements Mes
 		this.createAuthenticatedToken = createAuthenticatedToken;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	private class DigestData {
 
 		private final String username;

+ 35 - 0
web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java

@@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockFilterConfig;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -34,14 +35,17 @@ import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.InternalAuthenticationServiceException;
+import org.springframework.security.authentication.TestAuthentication;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServicesTests;
 import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.firewall.DefaultHttpFirewall;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -322,6 +326,37 @@ public class AbstractAuthenticationProcessingFilterTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
 	}
 
+	@Test
+	public void testSuccessfulAuthenticationThenDefaultDoesNotCreateSession() throws Exception {
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain(false);
+		MockAuthenticationFilter filter = new MockAuthenticationFilter();
+
+		filter.successfulAuthentication(request, response, chain, authentication);
+
+		assertThat(request.getSession(false)).isNull();
+	}
+
+	@Test
+	public void testSuccessfulAuthenticationWhenCustomSecurityContextRepositoryThenAuthenticationSaved()
+			throws Exception {
+		ArgumentCaptor<SecurityContext> contextCaptor = ArgumentCaptor.forClass(SecurityContext.class);
+		SecurityContextRepository repository = mock(SecurityContextRepository.class);
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain(false);
+		MockAuthenticationFilter filter = new MockAuthenticationFilter();
+		filter.setSecurityContextRepository(repository);
+
+		filter.successfulAuthentication(request, response, chain, authentication);
+
+		verify(repository).saveContext(contextCaptor.capture(), eq(request), eq(response));
+		assertThat(contextCaptor.getValue().getAuthentication()).isEqualTo(authentication);
+	}
+
 	@Test
 	public void testFailedAuthenticationInvokesFailureHandler() throws Exception {
 		// Setup our HTTP request

+ 35 - 0
web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java

@@ -25,6 +25,7 @@ import javax.servlet.http.HttpServletRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
@@ -38,7 +39,9 @@ import org.springframework.security.authentication.AuthenticationManagerResolver
 import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -256,4 +259,36 @@ public class AuthenticationFilterTests {
 		assertThat(session.getId()).isNotEqualTo(sessionId);
 	}
 
+	@Test
+	public void filterWhenSuccessfulAuthenticationThenNoSessionCreated() throws Exception {
+		Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER");
+		given(this.authenticationConverter.convert(any())).willReturn(authentication);
+		given(this.authenticationManager.authenticate(any())).willReturn(authentication);
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = new MockFilterChain();
+		AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
+				this.authenticationConverter);
+		filter.doFilter(request, response, chain);
+		assertThat(request.getSession(false)).isNull();
+	}
+
+	@Test
+	public void filterWhenCustomSecurityContextRepositoryAndSuccessfulAuthenticationRepositoryUsed() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		ArgumentCaptor<SecurityContext> securityContextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER");
+		given(this.authenticationConverter.convert(any())).willReturn(authentication);
+		given(this.authenticationManager.authenticate(any())).willReturn(authentication);
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = new MockFilterChain();
+		AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
+				this.authenticationConverter);
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.doFilter(request, response, chain);
+		verify(securityContextRepository).saveContext(securityContextArg.capture(), eq(request), eq(response));
+		assertThat(securityContextArg.getValue().getAuthentication()).isEqualTo(authentication);
+	}
+
 }

+ 29 - 0
web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java

@@ -23,6 +23,7 @@ import javax.servlet.http.HttpServletRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.stubbing.Answer;
 
 import org.springframework.mock.web.MockFilterChain;
@@ -34,17 +35,20 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler;
 import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -210,6 +214,31 @@ public class AbstractPreAuthenticatedProcessingFilterTests {
 		assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl");
 	}
 
+	@Test
+	public void securityContextRepository() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		Object currentPrincipal = "currentUser";
+		TestingAuthenticationToken authRequest = new TestingAuthenticationToken(currentPrincipal, "something",
+				"ROLE_USER");
+		SecurityContextHolder.getContext().setAuthentication(authRequest);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain();
+		ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter();
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.setAuthenticationSuccessHandler(new ForwardAuthenticationSuccessHandler("/forwardUrl"));
+		filter.setCheckForPrincipalChanges(true);
+		filter.principal = "newUser";
+		AuthenticationManager am = mock(AuthenticationManager.class);
+		given(am.authenticate(any())).willReturn(authRequest);
+		filter.setAuthenticationManager(am);
+		filter.afterPropertiesSet();
+		filter.doFilter(request, response, chain);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response));
+		assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authRequest.getName());
+	}
+
 	@Test
 	public void callsAuthenticationFailureHandlerOnFailedAuthentication() throws Exception {
 		MockHttpServletRequest request = new MockHttpServletRequest();

+ 19 - 0
web/src/test/java/org/springframework/security/web/authentication/rememberme/RememberMeAuthenticationFilterTests.java

@@ -36,10 +36,12 @@ import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.NullRememberMeServices;
 import org.springframework.security.web.authentication.RememberMeServices;
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler;
+import org.springframework.security.web.context.SecurityContextRepository;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -152,6 +154,23 @@ public class RememberMeAuthenticationFilterTests {
 		verifyZeroInteractions(fc);
 	}
 
+	@Test
+	public void securityContextRepositoryInvokedIfSet() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		AuthenticationManager am = mock(AuthenticationManager.class);
+		given(am.authenticate(this.remembered)).willReturn(this.remembered);
+		RememberMeAuthenticationFilter filter = new RememberMeAuthenticationFilter(am,
+				new MockRememberMeServices(this.remembered));
+		filter.setAuthenticationSuccessHandler(new SimpleUrlAuthenticationSuccessHandler("/target"));
+		filter.setSecurityContextRepository(securityContextRepository);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain fc = mock(FilterChain.class);
+		request.setRequestURI("x");
+		filter.doFilter(request, response, fc);
+		verify(securityContextRepository).saveContext(any(), eq(request), eq(response));
+	}
+
 	private class MockRememberMeServices implements RememberMeServices {
 
 		private Authentication authToReturn;

+ 24 - 0
web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java

@@ -27,6 +27,7 @@ import org.apache.commons.codec.binary.Base64;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -36,8 +37,10 @@ import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.WebAuthenticationDetails;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.web.util.WebUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -364,4 +367,25 @@ public class BasicAuthenticationFilterTests {
 		assertThat(response.getStatus()).isEqualTo(401);
 	}
 
+	@Test
+	public void requestWhenSecurityContextRepository() throws Exception {
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		this.filter.setSecurityContextRepository(securityContextRepository);
+		String token = "rod:koala";
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
+		request.setServletPath("/some_file.html");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		// Test
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
+		FilterChain chain = mock(FilterChain.class);
+		this.filter.doFilter(request, response, chain);
+		verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
+		assertThat(SecurityContextHolder.getContext().getAuthentication().getName()).isEqualTo("rod");
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response));
+		assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo("rod");
+	}
+
 }

+ 24 - 0
web/src/test/java/org/springframework/security/web/authentication/www/DigestAuthenticationFilterTests.java

@@ -29,6 +29,7 @@ import org.apache.commons.codec.digest.DigestUtils;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -40,10 +41,12 @@ import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetailsService;
 import org.springframework.security.core.userdetails.cache.NullUserCache;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -389,4 +392,25 @@ public class DigestAuthenticationFilterTests {
 		assertThat(existingAuthentication).isSameAs(existingContext.getAuthentication());
 	}
 
+	@Test
+	public void testSecurityContextRepository() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		String responseDigest = DigestAuthUtils.generateDigest(false, USERNAME, REALM, PASSWORD, "GET", REQUEST_URI,
+				QOP, NONCE, NC, CNONCE);
+		this.request.addHeader("Authorization",
+				createAuthorizationHeader(USERNAME, REALM, NONCE, REQUEST_URI, responseDigest, QOP, NC, CNONCE));
+		this.filter.setSecurityContextRepository(securityContextRepository);
+		this.filter.setCreateAuthenticatedToken(true);
+		MockHttpServletResponse response = executeFilterInContainerSimulator(this.filter, this.request, true);
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
+		assertThat(((UserDetails) SecurityContextHolder.getContext().getAuthentication().getPrincipal()).getUsername())
+				.isEqualTo(USERNAME);
+		assertThat(SecurityContextHolder.getContext().getAuthentication().isAuthenticated()).isTrue();
+		assertThat(SecurityContextHolder.getContext().getAuthentication().getAuthorities())
+				.isEqualTo(AuthorityUtils.createAuthorityList("ROLE_ONE", "ROLE_TWO"));
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(this.request), eq(response));
+		assertThat(contextArg.getValue().getAuthentication().getName()).isEqualTo(USERNAME);
+	}
+
 }