|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2002-2022 the original author or authors.
|
|
|
+ * Copyright 2002-2024 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletResponse;
|
|
|
|
|
|
import org.springframework.core.log.LogMessage;
|
|
|
import org.springframework.http.HttpStatus;
|
|
|
+import org.springframework.security.core.AuthenticationException;
|
|
|
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
|
|
import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
|
|
@@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
|
|
|
import org.springframework.security.web.DefaultRedirectStrategy;
|
|
|
import org.springframework.security.web.RedirectStrategy;
|
|
|
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
|
|
|
import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
|
|
|
import org.springframework.security.web.savedrequest.RequestCache;
|
|
|
import org.springframework.security.web.util.ThrowableAnalyzer;
|
|
@@ -97,6 +99,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
|
|
|
|
|
private RequestCache requestCache = new HttpSessionRequestCache();
|
|
|
|
|
|
+ private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization;
|
|
|
+
|
|
|
/**
|
|
|
* Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided
|
|
|
* parameters.
|
|
@@ -163,6 +167,18 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
|
|
this.requestCache = requestCache;
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to
|
|
|
+ * the Authorization Server's Authorization Endpoint.
|
|
|
+ * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used
|
|
|
+ * to handle errors redirecting to the Authorization Server's Authorization Endpoint
|
|
|
+ * @since 6.3
|
|
|
+ */
|
|
|
+ public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
|
|
|
+ Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
|
|
|
+ this.authenticationFailureHandler = authenticationFailureHandler;
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
|
|
|
throws ServletException, IOException {
|
|
@@ -174,7 +190,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
|
|
}
|
|
|
}
|
|
|
catch (Exception ex) {
|
|
|
- this.unsuccessfulRedirectForAuthorization(request, response, ex);
|
|
|
+ AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
|
|
|
+ this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
|
|
|
return;
|
|
|
}
|
|
|
try {
|
|
@@ -199,7 +216,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
|
|
this.sendRedirectForAuthorization(request, response, authorizationRequest);
|
|
|
}
|
|
|
catch (Exception failed) {
|
|
|
- this.unsuccessfulRedirectForAuthorization(request, response, failed);
|
|
|
+ AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
|
|
|
+ this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
|
|
|
}
|
|
|
return;
|
|
|
}
|
|
@@ -223,9 +241,10 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
|
|
}
|
|
|
|
|
|
private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
|
|
|
- Exception ex) throws IOException {
|
|
|
- LogMessage message = LogMessage.format("Authorization Request failed: %s", ex);
|
|
|
- if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) {
|
|
|
+ AuthenticationException ex) throws IOException {
|
|
|
+ Throwable cause = ex.getCause();
|
|
|
+ LogMessage message = LogMessage.format("Authorization Request failed: %s", cause);
|
|
|
+ if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
|
|
|
// Log an invalid registrationId at WARN level to allow these errors to be
|
|
|
// tuned separately from other errors
|
|
|
this.logger.warn(message, ex);
|
|
@@ -250,4 +269,12 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
|
|
|
|
|
|
}
|
|
|
|
|
|
+ private static final class OAuth2AuthorizationRequestException extends AuthenticationException {
|
|
|
+
|
|
|
+ OAuth2AuthorizationRequestException(Throwable cause) {
|
|
|
+ super(cause.getMessage(), cause);
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
}
|