瀏覽代碼

AbstractAuthenticationProcessingFilter.securityContextRepository

Issue gh-10953
Rob Winch 3 年之前
父節點
當前提交
cbba7ea4de

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilter.java

@@ -42,6 +42,8 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.session.NullAuthenticatedSessionStrategy;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -134,6 +136,8 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
 
 	private AuthenticationFailureHandler failureHandler = new SimpleUrlAuthenticationFailureHandler();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	/**
 	 * @param defaultFilterProcessesUrl the default value for <tt>filterProcessesUrl</tt>.
 	 */
@@ -314,6 +318,7 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authResult);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		if (this.logger.isDebugEnabled()) {
 			this.logger.debug(LogMessage.format("Set SecurityContextHolder to %s", authResult));
 		}
@@ -435,6 +440,18 @@ public abstract class AbstractAuthenticationProcessingFilter extends GenericFilt
 		this.failureHandler = failureHandler;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	protected AuthenticationSuccessHandler getSuccessHandler() {
 		return this.successHandler;
 	}

+ 35 - 0
web/src/test/java/org/springframework/security/web/authentication/AbstractAuthenticationProcessingFilterTests.java

@@ -27,6 +27,7 @@ import org.apache.commons.logging.Log;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockFilterConfig;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -34,14 +35,17 @@ import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.InternalAuthenticationServiceException;
+import org.springframework.security.authentication.TestAuthentication;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServicesTests;
 import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.firewall.DefaultHttpFirewall;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -322,6 +326,37 @@ public class AbstractAuthenticationProcessingFilterTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
 	}
 
+	@Test
+	public void testSuccessfulAuthenticationThenDefaultDoesNotCreateSession() throws Exception {
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain(false);
+		MockAuthenticationFilter filter = new MockAuthenticationFilter();
+
+		filter.successfulAuthentication(request, response, chain, authentication);
+
+		assertThat(request.getSession(false)).isNull();
+	}
+
+	@Test
+	public void testSuccessfulAuthenticationWhenCustomSecurityContextRepositoryThenAuthenticationSaved()
+			throws Exception {
+		ArgumentCaptor<SecurityContext> contextCaptor = ArgumentCaptor.forClass(SecurityContext.class);
+		SecurityContextRepository repository = mock(SecurityContextRepository.class);
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain(false);
+		MockAuthenticationFilter filter = new MockAuthenticationFilter();
+		filter.setSecurityContextRepository(repository);
+
+		filter.successfulAuthentication(request, response, chain, authentication);
+
+		verify(repository).saveContext(contextCaptor.capture(), eq(request), eq(response));
+		assertThat(contextCaptor.getValue().getAuthentication()).isEqualTo(authentication);
+	}
+
 	@Test
 	public void testFailedAuthenticationInvokesFailureHandler() throws Exception {
 		// Setup our HTTP request