Pārlūkot izejas kodu

Add AuthenticationSuccessHandler support to AbstractPreAuthenticatedProcessingFilter

Fixes gh-3389
Shazin Sadakath 9 gadi atpakaļ
vecāks
revīzija
1bc7060c93

+ 28 - 4
web/src/main/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilter.java

@@ -33,7 +33,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.WebAttributes;
-import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
+import org.springframework.security.web.authentication.*;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.GenericFilterBean;
 
@@ -84,6 +84,8 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 	private boolean continueFilterChainOnUnsuccessfulAuthentication = true;
 	private boolean checkForPrincipalChanges;
 	private boolean invalidateSessionOnPrincipalChange = true;
+	private AuthenticationSuccessHandler authenticationSuccessHandler = null;
+	private AuthenticationFailureHandler authenticationFailureHandler = null;
 
 	/**
 	 * Check whether all required properties have been set.
@@ -156,7 +158,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 	/**
 	 * Do the actual authentication for a pre-authenticated user.
 	 */
-	private void doAuthenticate(HttpServletRequest request, HttpServletResponse response) {
+	private void doAuthenticate(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException {
 		Authentication authResult;
 
 		Object principal = getPreAuthenticatedPrincipal(request);
@@ -229,7 +231,7 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 	 * manager into the secure context.
 	 */
 	protected void successfulAuthentication(HttpServletRequest request,
-			HttpServletResponse response, Authentication authResult) {
+			HttpServletResponse response, Authentication authResult) throws IOException, ServletException {
 		if (logger.isDebugEnabled()) {
 			logger.debug("Authentication success: " + authResult);
 		}
@@ -239,6 +241,10 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 			eventPublisher.publishEvent(new InteractiveAuthenticationSuccessEvent(
 					authResult, this.getClass()));
 		}
+
+		if(authenticationSuccessHandler != null) {
+			authenticationSuccessHandler.onAuthenticationSuccess(request, response, authResult);
+		}
 	}
 
 	/**
@@ -248,13 +254,17 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 	 * Caches the failure exception as a request attribute
 	 */
 	protected void unsuccessfulAuthentication(HttpServletRequest request,
-			HttpServletResponse response, AuthenticationException failed) {
+			HttpServletResponse response, AuthenticationException failed) throws IOException, ServletException {
 		SecurityContextHolder.clearContext();
 
 		if (logger.isDebugEnabled()) {
 			logger.debug("Cleared security context due to exception", failed);
 		}
 		request.setAttribute(WebAttributes.AUTHENTICATION_EXCEPTION, failed);
+
+		if(authenticationFailureHandler != null) {
+			authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
+		}
 	}
 
 	/**
@@ -324,6 +334,20 @@ public abstract class AbstractPreAuthenticatedProcessingFilter extends GenericFi
 		this.invalidateSessionOnPrincipalChange = invalidateSessionOnPrincipalChange;
 	}
 
+	/**
+	 * Sets the strategy used to handle a successful authentication.
+	 */
+	public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
+	}
+
+	/**
+	 * Sets the strategy used to handle a failed authentication.
+	 */
+	public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
+
 	/**
 	 * Override to extract the principal information from the current request
 	 */

+ 49 - 0
web/src/test/java/org/springframework/security/web/authentication/preauth/AbstractPreAuthenticatedProcessingFilterTests.java

@@ -40,9 +40,13 @@ import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.userdetails.User;
+import org.springframework.security.web.WebAttributes;
+import org.springframework.security.web.authentication.ForwardAuthenticationFailureHandler;
+import org.springframework.security.web.authentication.ForwardAuthenticationSuccessHandler;
 
 /**
  *
@@ -206,6 +210,51 @@ public class AbstractPreAuthenticatedProcessingFilterTests {
 		verify(am).authenticate(any(PreAuthenticatedAuthenticationToken.class));
 	}
 
+	@Test
+	public void callsAuthenticationSuccessHandlerOnSuccessfulAuthentication() throws Exception {
+		Object currentPrincipal = "currentUser";
+		TestingAuthenticationToken authRequest = new TestingAuthenticationToken(
+				currentPrincipal, "something", "ROLE_USER");
+		SecurityContextHolder.getContext().setAuthentication(authRequest);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain();
+
+		ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter();
+		filter.setAuthenticationSuccessHandler(new ForwardAuthenticationSuccessHandler("/forwardUrl"));
+		filter.setCheckForPrincipalChanges(true);
+		filter.principal = "newUser";
+		AuthenticationManager am = mock(AuthenticationManager.class);
+		filter.setAuthenticationManager(am);
+		filter.afterPropertiesSet();
+
+		filter.doFilter(request, response, chain);
+
+		verify(am).authenticate(any(PreAuthenticatedAuthenticationToken.class));
+		assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl");
+	}
+
+	@Test
+	public void callsAuthenticationFailureHandlerOnFailedAuthentication() throws Exception {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		MockFilterChain chain = new MockFilterChain();
+
+		ConcretePreAuthenticatedProcessingFilter filter = new ConcretePreAuthenticatedProcessingFilter();
+		filter.setAuthenticationFailureHandler(new ForwardAuthenticationFailureHandler("/forwardUrl"));
+		filter.setCheckForPrincipalChanges(true);
+		AuthenticationManager am = mock(AuthenticationManager.class);
+		when(am.authenticate(any(PreAuthenticatedAuthenticationToken.class))).thenThrow(new PreAuthenticatedCredentialsNotFoundException("invalid"));
+		filter.setAuthenticationManager(am);
+		filter.afterPropertiesSet();
+
+		filter.doFilter(request, response, chain);
+
+		verify(am).authenticate(any(PreAuthenticatedAuthenticationToken.class));
+		assertThat(response.getForwardedUrl()).isEqualTo("/forwardUrl");
+		assertThat(request.getAttribute(WebAttributes.AUTHENTICATION_EXCEPTION)).isNotNull();
+	}
+
 	// SEC-2078
 	@Test
 	public void requiresAuthenticationFalsePrincipalNotString() throws Exception {