Browse Source

Provide extension for processing authorization response

Issue gh-342
Joe Grandja 4 years ago
parent
commit
fb276e7a4a

+ 33 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilter.java

@@ -34,6 +34,7 @@ import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.oidc.OidcScopes;
 import org.springframework.security.oauth2.server.authorization.OAuth2AuthorizationCodeRequestAuthenticationException;
@@ -45,6 +46,8 @@ import org.springframework.security.oauth2.server.authorization.web.authenticati
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.RedirectUrlBuilder;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.matcher.AndRequestMatcher;
@@ -82,6 +85,8 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 	private final RequestMatcher authorizationEndpointMatcher;
 	private final RedirectStrategy redirectStrategy = new DefaultRedirectStrategy();
 	private AuthenticationConverter authenticationConverter;
+	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAuthorizationResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
 	private String consentPage;
 
 	/**
@@ -185,11 +190,12 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 				return;
 			}
 
-			sendAuthorizationResponse(request, response, authorizationCodeRequestAuthenticationResult);
+			this.authenticationSuccessHandler.onAuthenticationSuccess(
+					request, response, authorizationCodeRequestAuthenticationResult);
 
 		} catch (OAuth2AuthenticationException ex) {
 			SecurityContextHolder.clearContext();
-			sendErrorResponse(request, response, ex);
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
 		}
 	}
 
@@ -204,6 +210,28 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 		this.authenticationConverter = authenticationConverter;
 	}
 
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AuthorizationCodeRequestAuthenticationToken}
+	 * and returning the {@link OAuth2AuthorizationResponse Authorization Response}.
+	 *
+	 * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AuthorizationCodeRequestAuthenticationToken}
+	 */
+	public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
+		Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthorizationCodeRequestAuthenticationException}
+	 * and returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthorizationCodeRequestAuthenticationException}
+	 */
+	public final void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
+
 	/**
 	 * Specify the URI to redirect Resource Owners to if consent is required. A default consent
 	 * page will be generated when this attribute is not specified.
@@ -255,8 +283,10 @@ public class OAuth2AuthorizationEndpointFilter extends OncePerRequestFilter {
 	}
 
 	private void sendAuthorizationResponse(HttpServletRequest request, HttpServletResponse response,
-			OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication) throws IOException {
+			Authentication authentication) throws IOException {
 
+		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
+				(OAuth2AuthorizationCodeRequestAuthenticationToken) authentication;
 		UriComponentsBuilder uriBuilder = UriComponentsBuilder
 				.fromUriString(authorizationCodeRequestAuthentication.getRedirectUri())
 				.queryParam(OAuth2ParameterNames.CODE, authorizationCodeRequestAuthentication.getAuthorizationCode().getTokenValue());

+ 68 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2AuthorizationEndpointFilterTests.java

@@ -53,11 +53,14 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.same;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -118,6 +121,20 @@ public class OAuth2AuthorizationEndpointFilterTests {
 				.hasMessage("authenticationConverter cannot be null");
 	}
 
+	@Test
+	public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationFailureHandler cannot be null");
+	}
+
 	@Test
 	public void doFilterWhenNotAuthorizationRequestThenNotProcessed() throws Exception {
 		String requestUri = "/path";
@@ -275,6 +292,57 @@ public class OAuth2AuthorizationEndpointFilterTests {
 		verify(filterChain).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
+	@Test
+	public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthenticationResult =
+				authorizationCodeRequestAuthentication(registeredClient, this.principal)
+						.authorizationCode(this.authorizationCode)
+						.build();
+		authorizationCodeRequestAuthenticationResult.setAuthenticated(true);
+		when(this.authenticationManager.authenticate(any()))
+				.thenReturn(authorizationCodeRequestAuthenticationResult);
+
+		AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
+		this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authenticationManager).authenticate(any());
+		verifyNoInteractions(filterChain);
+		verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), same(authorizationCodeRequestAuthenticationResult));
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		OAuth2AuthorizationCodeRequestAuthenticationToken authorizationCodeRequestAuthentication =
+				authorizationCodeRequestAuthentication(registeredClient, this.principal)
+						.build();
+		OAuth2Error error = new OAuth2Error("errorCode", "errorDescription", "errorUri");
+		OAuth2AuthorizationCodeRequestAuthenticationException authenticationException =
+				new OAuth2AuthorizationCodeRequestAuthenticationException(error, authorizationCodeRequestAuthentication);
+		when(this.authenticationManager.authenticate(any()))
+				.thenThrow(authenticationException);
+
+		AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
+		this.filter.setAuthenticationFailureHandler(authenticationFailureHandler);
+
+		MockHttpServletRequest request = createAuthorizationRequest(registeredClient);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(this.authenticationManager).authenticate(any());
+		verifyNoInteractions(filterChain);
+		verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), same(authenticationException));
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationRequestPrincipalNotAuthenticatedThenCommenceAuthentication() throws Exception {
 		this.principal.setAuthenticated(false);