Răsfoiți Sursa

Introduce Customizable AuthorizationFailureHandler

Closes gh-13793
greg.lee 1 an în urmă
părinte
comite
07ac0b616b

+ 33 - 6
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,6 +25,7 @@ import jakarta.servlet.http.HttpServletResponse;
 
 import org.springframework.core.log.LogMessage;
 import org.springframework.http.HttpStatus;
+import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -32,6 +33,7 @@ import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.RedirectStrategy;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
 import org.springframework.security.web.savedrequest.RequestCache;
 import org.springframework.security.web.util.ThrowableAnalyzer;
@@ -97,6 +99,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 
 	private RequestCache requestCache = new HttpSessionRequestCache();
 
+	private AuthenticationFailureHandler authenticationFailureHandler = this::unsuccessfulRedirectForAuthorization;
+
 	/**
 	 * Constructs an {@code OAuth2AuthorizationRequestRedirectFilter} using the provided
 	 * parameters.
@@ -163,6 +167,18 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 		this.requestCache = requestCache;
 	}
 
+	/**
+	 * Sets the {@link AuthenticationFailureHandler} used to handle errors redirecting to
+	 * the Authorization Server's Authorization Endpoint.
+	 * @param authenticationFailureHandler the {@link AuthenticationFailureHandler} used
+	 * to handle errors redirecting to the Authorization Server's Authorization Endpoint
+	 * @since 6.3
+	 */
+	public void setAuthenticationFailureHandler(AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
@@ -174,7 +190,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 			}
 		}
 		catch (Exception ex) {
-			this.unsuccessfulRedirectForAuthorization(request, response, ex);
+			AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
 			return;
 		}
 		try {
@@ -199,7 +216,8 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 					this.sendRedirectForAuthorization(request, response, authorizationRequest);
 				}
 				catch (Exception failed) {
-					this.unsuccessfulRedirectForAuthorization(request, response, failed);
+					AuthenticationException wrappedException = new OAuth2AuthorizationRequestException(ex);
+					this.authenticationFailureHandler.onAuthenticationFailure(request, response, wrappedException);
 				}
 				return;
 			}
@@ -223,9 +241,10 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 	}
 
 	private void unsuccessfulRedirectForAuthorization(HttpServletRequest request, HttpServletResponse response,
-			Exception ex) throws IOException {
-		LogMessage message = LogMessage.format("Authorization Request failed: %s", ex);
-		if (InvalidClientRegistrationIdException.class.isAssignableFrom(ex.getClass())) {
+			AuthenticationException ex) throws IOException {
+		Throwable cause = ex.getCause();
+		LogMessage message = LogMessage.format("Authorization Request failed: %s", cause);
+		if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
 			// Log an invalid registrationId at WARN level to allow these errors to be
 			// tuned separately from other errors
 			this.logger.warn(message, ex);
@@ -250,4 +269,12 @@ public class OAuth2AuthorizationRequestRedirectFilter extends OncePerRequestFilt
 
 	}
 
+	private static final class OAuth2AuthorizationRequestException extends AuthenticationException {
+
+		OAuth2AuthorizationRequestException(Throwable cause) {
+			super(cause.getMessage(), cause);
+		}
+
+	}
+
 }

+ 31 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/OAuth2AuthorizationRequestRedirectFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -119,6 +119,11 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestCache(null));
 	}
 
+	@Test
+	public void setAuthenticationFailureHandlerIsNullThenThrowIllegalArgumentException() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setAuthenticationFailureHandler(null));
+	}
+
 	@Test
 	public void doFilterWhenNotAuthorizationRequestThenNextFilter() throws Exception {
 		String requestUri = "/path";
@@ -144,6 +149,31 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
 		assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
 	}
 
+	@Test
+	public void doFilterWhenAuthorizationRequestWithInvalidClientAndCustomFailureHandlerThenCustomError()
+			throws Exception {
+		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"
+				+ this.registration1.getRegistrationId() + "-invalid";
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
+		request.setServletPath(requestUri);
+		MockHttpServletResponse response = new MockHttpServletResponse();
+		FilterChain filterChain = mock(FilterChain.class);
+		this.filter.setAuthenticationFailureHandler((request1, response1, ex) -> {
+			Throwable cause = ex.getCause();
+			if (InvalidClientRegistrationIdException.class.isAssignableFrom(cause.getClass())) {
+				response1.sendError(HttpStatus.BAD_REQUEST.value(), HttpStatus.BAD_REQUEST.getReasonPhrase());
+			}
+			else {
+				response1.sendError(HttpStatus.INTERNAL_SERVER_ERROR.value(),
+						HttpStatus.INTERNAL_SERVER_ERROR.getReasonPhrase());
+			}
+		});
+		this.filter.doFilter(request, response, filterChain);
+		verifyNoMoreInteractions(filterChain);
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+		assertThat(response.getErrorMessage()).isEqualTo(HttpStatus.BAD_REQUEST.getReasonPhrase());
+	}
+
 	@Test
 	public void doFilterWhenAuthorizationRequestOAuth2LoginThenRedirectForAuthorization() throws Exception {
 		String requestUri = OAuth2AuthorizationRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI + "/"