Browse Source

Provide extension for processing access token response

Issue gh-319
Joe Grandja 4 years ago
parent
commit
23732187a9

+ 38 - 6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

@@ -36,6 +36,7 @@ import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -55,6 +56,8 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationToken;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -105,6 +108,8 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 			new OAuth2AccessTokenResponseHttpMessageConverter();
 	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
 			new OAuth2ErrorHttpMessageConverter();
+	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
 
 	/**
 	 * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
@@ -155,16 +160,40 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 
 			OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
 					(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication);
-			sendAccessTokenResponse(response, accessTokenAuthentication);
-
+			this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, accessTokenAuthentication);
 		} catch (OAuth2AuthenticationException ex) {
 			SecurityContextHolder.clearContext();
-			sendErrorResponse(response, ex.getError());
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
 		}
 	}
 
-	private void sendAccessTokenResponse(HttpServletResponse response,
-			OAuth2AccessTokenAuthenticationToken accessTokenAuthentication) throws IOException {
+	/**
+	 * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AccessTokenAuthenticationToken}
+	 * and returning the {@link OAuth2AccessTokenResponse Access Token Response}.
+	 *
+	 * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AccessTokenAuthenticationToken}
+	 */
+	public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
+		Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
+		this.authenticationSuccessHandler = authenticationSuccessHandler;
+	}
+
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 * and returning the {@link OAuth2Error Error Response}.
+	 *
+	 * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+	 */
+	public final void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
+
+	private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response,
+			Authentication authentication) throws IOException {
+
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				(OAuth2AccessTokenAuthenticationToken) authentication;
 
 		OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
 		OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
@@ -188,7 +217,10 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 		this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
 	}
 
-	private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
+	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
+			AuthenticationException exception) throws IOException {
+
+		OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
 		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
 		httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
 		this.errorHttpResponseConverter.write(error, null, httpResponse);

+ 63 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilterTests.java

@@ -57,6 +57,8 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -108,6 +110,20 @@ public class OAuth2TokenEndpointFilterTests {
 				.hasMessage("tokenEndpointUri cannot be empty");
 	}
 
+	@Test
+	public void setAuthenticationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.filter.setAuthenticationSuccessHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationSuccessHandler cannot be null");
+	}
+
+	@Test
+	public void setAuthenticationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.filter.setAuthenticationFailureHandler(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authenticationFailureHandler cannot be null");
+	}
+
 	@Test
 	public void doFilterWhenNotTokenRequestThenNotProcessed() throws Exception {
 		String requestUri = "/path";
@@ -397,6 +413,53 @@ public class OAuth2TokenEndpointFilterTests {
 		assertThat(refreshTokenResult.getTokenValue()).isEqualTo(refreshToken.getTokenValue());
 	}
 
+	@Test
+	public void doFilterWhenCustomAuthenticationSuccessHandlerThenUsed() throws Exception {
+		AuthenticationSuccessHandler authenticationSuccessHandler = mock(AuthenticationSuccessHandler.class);
+		this.filter.setAuthenticationSuccessHandler(authenticationSuccessHandler);
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+		Authentication clientPrincipal = new OAuth2ClientAuthenticationToken(registeredClient);
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(
+				OAuth2AccessToken.TokenType.BEARER, "token",
+				Instant.now(), Instant.now().plus(Duration.ofHours(1)),
+				new HashSet<>(Arrays.asList("scope1", "scope2")));
+		OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
+				new OAuth2AccessTokenAuthenticationToken(registeredClient, clientPrincipal, accessToken);
+
+		when(this.authenticationManager.authenticate(any())).thenReturn(accessTokenAuthentication);
+
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(clientPrincipal);
+		SecurityContextHolder.setContext(securityContext);
+
+		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(registeredClient);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(authenticationSuccessHandler).onAuthenticationSuccess(any(), any(), any());
+	}
+
+	@Test
+	public void doFilterWhenCustomAuthenticationFailureHandlerThenUsed() throws Exception {
+		AuthenticationFailureHandler authenticationFailureHandler = mock(AuthenticationFailureHandler.class);
+		this.filter.setAuthenticationFailureHandler(authenticationFailureHandler);
+
+		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
+
+		MockHttpServletRequest request = createAuthorizationCodeTokenRequest(registeredClient);
+		request.removeParameter(OAuth2ParameterNames.GRANT_TYPE);
+
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+
+		this.filter.doFilter(request, response, filterChain);
+
+		verify(authenticationFailureHandler).onAuthenticationFailure(any(), any(), any());
+	}
+
 	private void doFilterWhenTokenRequestInvalidParameterThenError(String parameterName, String errorCode,
 			MockHttpServletRequest request) throws Exception {