瀏覽代碼

CasAuthenticationFilter.securityContextRepository

Issue gh-10953
Rob Winch 3 年之前
父節點
當前提交
2e9b04ed48

+ 17 - 0
cas/src/main/java/org/springframework/security/cas/web/CasAuthenticationFilter.java

@@ -42,6 +42,8 @@ import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.SimpleUrlAuthenticationFailureHandler;
+import org.springframework.security.web.context.NullSecurityContextRepository;
+import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -205,6 +207,8 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil
 
 	private AuthenticationFailureHandler proxyFailureHandler = new SimpleUrlAuthenticationFailureHandler();
 
+	private SecurityContextRepository securityContextRepository = new NullSecurityContextRepository();
+
 	public CasAuthenticationFilter() {
 		super("/login/cas");
 		setAuthenticationFailureHandler(new SimpleUrlAuthenticationFailureHandler());
@@ -223,6 +227,7 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil
 		SecurityContext context = SecurityContextHolder.createEmptyContext();
 		context.setAuthentication(authResult);
 		SecurityContextHolder.setContext(context);
+		this.securityContextRepository.saveContext(context, request, response);
 		if (this.eventPublisher != null) {
 			this.eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(authResult, this.getClass()));
 		}
@@ -274,6 +279,18 @@ public class CasAuthenticationFilter extends AbstractAuthenticationProcessingFil
 		return result;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextRepository} to save the {@link SecurityContext} on
+	 * authentication success. The default action is not to save the
+	 * {@link SecurityContext}.
+	 * @param securityContextRepository the {@link SecurityContextRepository} to use.
+	 * Cannot be null.
+	 */
+	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	/**
 	 * Sets the {@link AuthenticationFailureHandler} for proxy requests.
 	 * @param proxyFailureHandler

+ 36 - 0
cas/src/test/java/org/springframework/security/cas/web/CasAuthenticationFilterTests.java

@@ -21,6 +21,7 @@ import javax.servlet.FilterChain;
 import org.jasig.cas.client.proxy.ProxyGrantingTicketStorage;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
@@ -32,12 +33,15 @@ import org.springframework.security.cas.ServiceProperties;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
+import org.springframework.security.web.context.SecurityContextRepository;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -182,6 +186,38 @@ public class CasAuthenticationFilterTests {
 		verify(successHandler).onAuthenticationSuccess(request, response, authentication);
 	}
 
+	@Test
+	public void testSecurityContextHolder() throws Exception {
+		SecurityContextRepository securityContextRepository = mock(SecurityContextRepository.class);
+		AuthenticationManager manager = mock(AuthenticationManager.class);
+		Authentication authentication = new TestingAuthenticationToken("un", "pwd", "ROLE_USER");
+		given(manager.authenticate(any(Authentication.class))).willReturn(authentication);
+		ServiceProperties serviceProperties = new ServiceProperties();
+		serviceProperties.setAuthenticateAllArtifacts(true);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter("ticket", "ST-1-123");
+		request.setServletPath("/authenticate");
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain chain = mock(FilterChain.class);
+		CasAuthenticationFilter filter = new CasAuthenticationFilter();
+		filter.setServiceProperties(serviceProperties);
+		filter.setProxyGrantingTicketStorage(mock(ProxyGrantingTicketStorage.class));
+		filter.setAuthenticationManager(manager);
+		filter.setSecurityContextRepository(securityContextRepository);
+		filter.afterPropertiesSet();
+		filter.doFilter(request, response, chain);
+		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull()
+				.withFailMessage("Authentication should not be null");
+		verify(chain).doFilter(request, response);
+		// validate for when the filterProcessUrl matches
+		filter.setFilterProcessesUrl(request.getServletPath());
+		SecurityContextHolder.clearContext();
+		filter.doFilter(request, response, chain);
+		ArgumentCaptor<SecurityContext> contextArg = ArgumentCaptor.forClass(SecurityContext.class);
+		verify(securityContextRepository).saveContext(contextArg.capture(), eq(request), eq(response));
+		assertThat(contextArg.getValue().getAuthentication().getPrincipal()).isEqualTo(authentication.getName());
+	}
+
 	// SEC-1592
 	@Test
 	public void testChainNotInvokedForProxyReceptor() throws Exception {