|
@@ -36,6 +36,7 @@ import org.springframework.http.converter.HttpMessageConverter;
|
|
|
import org.springframework.http.server.ServletServerHttpResponse;
|
|
|
import org.springframework.security.authentication.AuthenticationManager;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
+import org.springframework.security.core.AuthenticationException;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
@@ -55,6 +56,8 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationProvider;
|
|
|
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationToken;
|
|
|
import org.springframework.security.web.authentication.AuthenticationConverter;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
|
|
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
|
import org.springframework.util.Assert;
|
|
@@ -105,6 +108,8 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
|
|
|
new OAuth2AccessTokenResponseHttpMessageConverter();
|
|
|
private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
|
|
|
new OAuth2ErrorHttpMessageConverter();
|
|
|
+ private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse;
|
|
|
+ private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
|
|
|
|
|
|
/**
|
|
|
* Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
|
|
@@ -155,16 +160,40 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
|
|
|
|
|
|
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
|
|
|
(OAuth2AccessTokenAuthenticationToken) this.authenticationManager.authenticate(authorizationGrantAuthentication);
|
|
|
- sendAccessTokenResponse(response, accessTokenAuthentication);
|
|
|
-
|
|
|
+ this.authenticationSuccessHandler.onAuthenticationSuccess(request, response, accessTokenAuthentication);
|
|
|
} catch (OAuth2AuthenticationException ex) {
|
|
|
SecurityContextHolder.clearContext();
|
|
|
- sendErrorResponse(response, ex.getError());
|
|
|
+ this.authenticationFailureHandler.onAuthenticationFailure(request, response, ex);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- private void sendAccessTokenResponse(HttpServletResponse response,
|
|
|
- OAuth2AccessTokenAuthenticationToken accessTokenAuthentication) throws IOException {
|
|
|
+ /**
|
|
|
+ * Sets the {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AccessTokenAuthenticationToken}
|
|
|
+ * and returning the {@link OAuth2AccessTokenResponse Access Token Response}.
|
|
|
+ *
|
|
|
+ * @param authenticationSuccessHandler the {@link AuthenticationSuccessHandler} used for handling an {@link OAuth2AccessTokenAuthenticationToken}
|
|
|
+ */
|
|
|
+ public final void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
|
|
|
+ Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
|
|
|
+ this.authenticationSuccessHandler = authenticationSuccessHandler;
|
|
|
+ }
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Sets the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
|
|
|
+ * and returning the {@link OAuth2Error Error Response}.
|
|
|
+ *
|
|
|
+ * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
|
|
|
+ */
|
|
|
+ public final void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
|
|
|
+ Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
|
|
|
+ this.authenticationFailureHandler = authenticationFailureHandler;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void sendAccessTokenResponse(HttpServletRequest request, HttpServletResponse response,
|
|
|
+ Authentication authentication) throws IOException {
|
|
|
+
|
|
|
+ OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
|
|
|
+ (OAuth2AccessTokenAuthenticationToken) authentication;
|
|
|
|
|
|
OAuth2AccessToken accessToken = accessTokenAuthentication.getAccessToken();
|
|
|
OAuth2RefreshToken refreshToken = accessTokenAuthentication.getRefreshToken();
|
|
@@ -188,7 +217,10 @@ public class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
|
|
|
this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
|
|
|
}
|
|
|
|
|
|
- private void sendErrorResponse(HttpServletResponse response, OAuth2Error error) throws IOException {
|
|
|
+ private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
|
|
|
+ AuthenticationException exception) throws IOException {
|
|
|
+
|
|
|
+ OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
|
|
|
ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
|
|
|
httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
|
|
|
this.errorHttpResponseConverter.write(error, null, httpResponse);
|