浏览代码

AbstractPreAuthenticatedProcessingFilter.securityContextRepository

Issue gh-10953
Rob Winch 3 年之前
父节点
当前提交
4462b73fd9

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java

@@ -40,6 +40,8 @@ import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.GenericFilterBean;
@@ -104,6 +106,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 
 	private RequestMatcher requiresAuthenticationRequestMatcher = new PreAuthenticatedProcessingRequestMatcher();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	/**
 	 * Check whether all required properties have been set.
 	 */
@@ -210,6 +214,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authResult);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		if (this.eventPublisher != null) {
 			this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
 		}
@@ -242,6 +247,18 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 		this.eventPublisher = anApplicationEventPublisher;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	/**
 	 * @param authenticationDetailsSource The AuthenticationDetailsSource to use
 	 */

+ 29 - 0
web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java

@@ -23,6 +23,7 @@ import jakarta.servlet.http.HttpServletRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 import org.mockito.stubbing.Answer;
 
 import org.springframework.mock.web.MockFilterChain;
@@ -34,17 +35,20 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler;
 import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -210,6 +214,31 @@ public class AbstractPreAuthenticatedProcessingFilterTests {
 		assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl");
 	}
 
+	@Test
+	public void securityContextRepository() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		Object currentPrincipal = "currentUser";
+		TestingAuthenticationToken authRequest = new TestingAuthenticationToken(currentPrincipal, "something",
+				"ROLE_USER");
+		SecurityContextHolder.getContext().setAuthentication(authRequest);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain();
+		ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter();
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.setAuthenticationSuccessHandler(new ForwardAuthenticationSuccessHandler("/forwardUrl"));
+		filter.setCheckForPrincipalChanges(true);
+		filter.principal = "newUser";
+		AuthenticationManager am = mock(AuthenticationManager.class);
+		given(am.authenticate(any())).willReturn(authRequest);
+		filter.setAuthenticationManager(am);
+		filter.afterPropertiesSet();
+		filter.doFilter(request, response, chain);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response));
+		assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authRequest.getName());
+	}
+
 	@Test
 	public void callsAuthenticationFailureHandlerOnFailedAuthentication() throws Exception {
 		MockHttpServletRequest request = new MockHttpServletRequest();