Browse Source

AuthenticationFilter.securityContextRepository

Issue gh-10953
Rob Winch 3 years ago
parent
commit
7c5b939bbd

+ 17 - 0
web/src/main/java/org/springframework/security/web/authentication/AuthenticationFilter.java

@@ -32,6 +32,8 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -74,6 +76,8 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 	private AuthenticationFailureHandler failureHandler = new AuthenticationEntryPointFailureHandler(
 			new HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED));
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	private AuthenticationManagerResolver<HttpServletRequest> authenticationManagerResolver;
 
 	public AuthenticationFilter(AuthenticationManager authenticationManager,
@@ -135,6 +139,18 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 		this.authenticationManagerResolver = authenticationManagerResolver;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
@@ -173,6 +189,7 @@ public class AuthenticationFilter extends OncePerRequestFilter {
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authentication);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		this.successHandler.onAuthenticationSuccess(request, response, chain, authentication);
 	}
 

+ 35 - 0
web/src/test/java/org/springframework/security/web/authentication/AuthenticationFilterTests.java

@@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletRequest;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
@@ -38,7 +39,9 @@ import org.springframework.security.authentication.AuthenticationManagerResolver
 import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -256,4 +259,36 @@ public class AuthenticationFilterTests {
 		assertThat(session.getId()).isNotEqualTo(sessionId);
 	}
 
+	@Test
+	public void filterWhenSuccessfulAuthenticationThenNoSessionCreated() throws Exception {
+		Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER");
+		given(this.authenticationConverter.convert(any())).willReturn(authentication);
+		given(this.authenticationManager.authenticate(any())).willReturn(authentication);
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = new MockFilterChain();
+		AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
+				this.authenticationConverter);
+		filter.doFilter(request, response, chain);
+		assertThat(request.getSession(false)).isNull();
+	}
+
+	@Test
+	public void filterWhenCustomSecurityContextRepositoryAndSuccessfulAuthenticationRepositoryUsed() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		ArgumentCaptor<SecurityContext> securityContextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		Authentication authentication = new TestingAuthenticationToken("test", "this", "ROLE_USER");
+		given(this.authenticationConverter.convert(any())).willReturn(authentication);
+		given(this.authenticationManager.authenticate(any())).willReturn(authentication);
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = new MockFilterChain();
+		AuthenticationFilter filter = new AuthenticationFilter(this.authenticationManager,
+				this.authenticationConverter);
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.doFilter(request, response, chain);
+		verify(securityContextRepository).saveContext(securityContextArg.capture(), eq(request), eq(response));
+		assertThat(securityContextArg.getValue().getAuthentication()).isEqualTo(authentication);
+	}
+
 }