Pārlūkot izejas kodu

Add OAuth2ErrorAuthenticationFailureHandler

Related gh-1369

Closes gh-1384
Dmitriy Dubson 1 gadu atpakaļ
vecāks
revīzija
96c90dded7

+ 4 - 4
docs/modules/ROOT/pages/protocol-endpoints.adoc

@@ -167,7 +167,7 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h
 * `*AuthenticationConverter*` -- An `OAuth2DeviceAuthorizationRequestAuthenticationConverter`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2DeviceAuthorizationRequestAuthenticationProvider`.
 * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OAuth2DeviceAuthorizationRequestAuthenticationToken` and returns the `OAuth2DeviceAuthorizationResponse`.
-* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
+* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
 
 [[oauth2-device-verification-endpoint]]
 == OAuth2 Device Verification Endpoint
@@ -264,7 +264,7 @@ The supported https://datatracker.ietf.org/doc/html/rfc6749#section-1.3[authoriz
 * `*AuthenticationConverter*` -- A `DelegatingAuthenticationConverter` composed of `OAuth2AuthorizationCodeAuthenticationConverter`, `OAuth2RefreshTokenAuthenticationConverter`, `OAuth2ClientCredentialsAuthenticationConverter`, and `OAuth2DeviceCodeAuthenticationConverter`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2AuthorizationCodeAuthenticationProvider`, `OAuth2RefreshTokenAuthenticationProvider`, `OAuth2ClientCredentialsAuthenticationProvider`, and `OAuth2DeviceCodeAuthenticationProvider`.
 * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an `OAuth2AccessTokenAuthenticationToken` and returns the `OAuth2AccessTokenResponse`.
-* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
+* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
 
 [[oauth2-token-introspection-endpoint]]
 == OAuth2 Token Introspection Endpoint
@@ -311,7 +311,7 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h
 * `*AuthenticationConverter*` -- An `OAuth2TokenIntrospectionAuthenticationConverter`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2TokenIntrospectionAuthenticationProvider`.
 * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OAuth2TokenIntrospectionAuthenticationToken` and returns the `OAuth2TokenIntrospection` response.
-* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
+* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
 
 [[oauth2-token-revocation-endpoint]]
 == OAuth2 Token Revocation Endpoint
@@ -358,7 +358,7 @@ public SecurityFilterChain authorizationServerSecurityFilterChain(HttpSecurity h
 * `*AuthenticationConverter*` -- An `OAuth2TokenRevocationAuthenticationConverter`.
 * `*AuthenticationManager*` -- An `AuthenticationManager` composed of `OAuth2TokenRevocationAuthenticationProvider`.
 * `*AuthenticationSuccessHandler*` -- An internal implementation that handles an "`authenticated`" `OAuth2TokenRevocationAuthenticationToken` and returns the OAuth2 revocation response.
-* `*AuthenticationFailureHandler*` -- An internal implementation that uses the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
+* `*AuthenticationFailureHandler*` -- An `OAuth2ErrorAuthenticationFailureHandler` instance that handles the `OAuth2Error` associated with the `OAuth2AuthenticationException` and returns the `OAuth2Error` response.
 
 [[oauth2-authorization-server-metadata-endpoint]]
 == OAuth2 Authorization Server Metadata Endpoint

+ 2 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2DeviceAuthorizationEndpointFilter.java

@@ -24,14 +24,12 @@ import jakarta.servlet.http.HttpServletResponse;
 
 import org.springframework.core.log.LogMessage;
 import org.springframework.http.HttpMethod;
-import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2DeviceCode;
@@ -40,10 +38,10 @@ import org.springframework.security.oauth2.core.OAuth2UserCode;
 import org.springframework.security.oauth2.core.endpoint.OAuth2DeviceAuthorizationResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.http.converter.OAuth2DeviceAuthorizationResponseHttpMessageConverter;
-import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceAuthorizationRequestAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.context.AuthorizationServerContextHolder;
+import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2DeviceAuthorizationRequestAuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
@@ -76,13 +74,11 @@ public final class OAuth2DeviceAuthorizationEndpointFilter extends OncePerReques
 	private final RequestMatcher deviceAuthorizationEndpointMatcher;
 	private final HttpMessageConverter<OAuth2DeviceAuthorizationResponse> deviceAuthorizationHttpResponseConverter =
 			new OAuth2DeviceAuthorizationResponseHttpMessageConverter();
-	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
-			new OAuth2ErrorHttpMessageConverter();
 	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
 			new WebAuthenticationDetailsSource();
 	private AuthenticationConverter authenticationConverter;
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendDeviceAuthorizationResponse;
-	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();
 	private String verificationUri = OAuth2DeviceVerificationEndpointFilter.DEFAULT_DEVICE_VERIFICATION_ENDPOINT_URI;
 
 	/**
@@ -225,13 +221,4 @@ public final class OAuth2DeviceAuthorizationEndpointFilter extends OncePerReques
 		this.deviceAuthorizationHttpResponseConverter.write(deviceAuthorizationResponse, null, httpResponse);
 	}
 
-	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
-			AuthenticationException authenticationException) throws IOException {
-
-		OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
-		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
-		this.errorHttpResponseConverter.write(error, null, httpResponse);
-	}
-
 }

+ 2 - 15
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenEndpointFilter.java

@@ -27,14 +27,12 @@ import jakarta.servlet.http.HttpServletResponse;
 
 import org.springframework.core.log.LogMessage;
 import org.springframework.http.HttpMethod;
-import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
@@ -44,13 +42,13 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
 import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
-import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AccessTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2AuthorizationGrantAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2DeviceCodeAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2RefreshTokenAuthenticationProvider;
+import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
 import org.springframework.security.oauth2.server.authorization.web.authentication.DelegatingAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2AuthorizationCodeAuthenticationConverter;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ClientCredentialsAuthenticationConverter;
@@ -107,13 +105,11 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 	private final RequestMatcher tokenEndpointMatcher;
 	private final HttpMessageConverter<OAuth2AccessTokenResponse> accessTokenHttpResponseConverter =
 			new OAuth2AccessTokenResponseHttpMessageConverter();
-	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
-			new OAuth2ErrorHttpMessageConverter();
 	private AuthenticationDetailsSource<HttpServletRequest, ?> authenticationDetailsSource =
 			new WebAuthenticationDetailsSource();
 	private AuthenticationConverter authenticationConverter;
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendAccessTokenResponse;
-	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();
 
 	/**
 	 * Constructs an {@code OAuth2TokenEndpointFilter} using the provided parameters.
@@ -250,15 +246,6 @@ public final class OAuth2TokenEndpointFilter extends OncePerRequestFilter {
 		this.accessTokenHttpResponseConverter.write(accessTokenResponse, null, httpResponse);
 	}
 
-	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
-			AuthenticationException exception) throws IOException {
-
-		OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
-		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
-		this.errorHttpResponseConverter.write(error, null, httpResponse);
-	}
-
 	private static void throwError(String errorCode, String parameterName) {
 		OAuth2Error error = new OAuth2Error(errorCode, "OAuth 2.0 Parameter: " + parameterName, DEFAULT_ERROR_URI);
 		throw new OAuth2AuthenticationException(error);

+ 3 - 14
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenIntrospectionEndpointFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,20 +24,18 @@ import jakarta.servlet.http.HttpServletResponse;
 
 import org.springframework.core.log.LogMessage;
 import org.springframework.http.HttpMethod;
-import org.springframework.http.HttpStatus;
 import org.springframework.http.converter.HttpMessageConverter;
 import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.OAuth2TokenIntrospection;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenIntrospectionAuthenticationToken;
 import org.springframework.security.oauth2.server.authorization.http.converter.OAuth2TokenIntrospectionHttpMessageConverter;
+import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2TokenIntrospectionAuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
@@ -69,9 +67,8 @@ public final class OAuth2TokenIntrospectionEndpointFilter extends OncePerRequest
 	private AuthenticationConverter authenticationConverter;
 	private final HttpMessageConverter<OAuth2TokenIntrospection> tokenIntrospectionHttpResponseConverter =
 			new OAuth2TokenIntrospectionHttpMessageConverter();
-	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendIntrospectionResponse;
-	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();
 
 	/**
 	 * Constructs an {@code OAuth2TokenIntrospectionEndpointFilter} using the provided parameters.
@@ -166,12 +163,4 @@ public final class OAuth2TokenIntrospectionEndpointFilter extends OncePerRequest
 		this.tokenIntrospectionHttpResponseConverter.write(tokenClaims, null, httpResponse);
 	}
 
-	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
-			AuthenticationException exception) throws IOException {
-		OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
-		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
-		this.errorHttpResponseConverter.write(error, null, httpResponse);
-	}
-
 }

+ 3 - 16
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/OAuth2TokenRevocationEndpointFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020-2022 the original author or authors.
+ * Copyright 2020-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -25,17 +25,14 @@ import jakarta.servlet.http.HttpServletResponse;
 import org.springframework.core.log.LogMessage;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
-import org.springframework.http.converter.HttpMessageConverter;
-import org.springframework.http.server.ServletServerHttpResponse;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.core.Authentication;
-import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
-import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationProvider;
 import org.springframework.security.oauth2.server.authorization.authentication.OAuth2TokenRevocationAuthenticationToken;
+import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2ErrorAuthenticationFailureHandler;
 import org.springframework.security.oauth2.server.authorization.web.authentication.OAuth2TokenRevocationAuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
@@ -65,10 +62,8 @@ public final class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFil
 	private final AuthenticationManager authenticationManager;
 	private final RequestMatcher tokenRevocationEndpointMatcher;
 	private AuthenticationConverter authenticationConverter;
-	private final HttpMessageConverter<OAuth2Error> errorHttpResponseConverter =
-			new OAuth2ErrorHttpMessageConverter();
 	private AuthenticationSuccessHandler authenticationSuccessHandler = this::sendRevocationSuccessResponse;
-	private AuthenticationFailureHandler authenticationFailureHandler = this::sendErrorResponse;
+	private AuthenticationFailureHandler authenticationFailureHandler = new OAuth2ErrorAuthenticationFailureHandler();
 
 	/**
 	 * Constructs an {@code OAuth2TokenRevocationEndpointFilter} using the provided parameters.
@@ -157,12 +152,4 @@ public final class OAuth2TokenRevocationEndpointFilter extends OncePerRequestFil
 		response.setStatus(HttpStatus.OK.value());
 	}
 
-	private void sendErrorResponse(HttpServletRequest request, HttpServletResponse response,
-			AuthenticationException exception) throws IOException {
-		OAuth2Error error = ((OAuth2AuthenticationException) exception).getError();
-		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
-		httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
-		this.errorHttpResponseConverter.write(error, null, httpResponse);
-	}
-
 }

+ 72 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandler.java

@@ -0,0 +1,72 @@
+/*
+ * Copyright 2020-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web.authentication;
+
+import java.io.IOException;
+
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.core.log.LogMessage;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.server.ServletServerHttpResponse;
+import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.http.converter.OAuth2ErrorHttpMessageConverter;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
+
+/**
+ * A default implementation of an {@link AuthenticationFailureHandler} used for handling an {@link OAuth2AuthenticationException}
+ * and returning the {@link OAuth2Error Error Response}.
+ *
+ * @author Dmitriy Dubson
+ * @see AuthenticationFailureHandler
+ * @since 1.2.0
+ */
+public final class OAuth2ErrorAuthenticationFailureHandler implements AuthenticationFailureHandler {
+
+	private final Log logger = LogFactory.getLog(getClass());
+
+	private HttpMessageConverter<OAuth2Error> errorHttpResponseConverter = new OAuth2ErrorHttpMessageConverter();
+
+	@Override
+	public void onAuthenticationFailure(HttpServletRequest request, HttpServletResponse response, AuthenticationException authenticationException) throws IOException, ServletException {
+		ServletServerHttpResponse httpResponse = new ServletServerHttpResponse(response);
+		httpResponse.setStatusCode(HttpStatus.BAD_REQUEST);
+
+		if (authenticationException instanceof OAuth2AuthenticationException) {
+			OAuth2Error error = ((OAuth2AuthenticationException) authenticationException).getError();
+			this.errorHttpResponseConverter.write(error, null, httpResponse);
+		} else {
+			if (this.logger.isWarnEnabled()) {
+				this.logger.warn(LogMessage.format("Authentication exception must be of type 'org.springframework.security.oauth2.core.OAuth2AuthenticationException'. Provided exception was '%s'", authenticationException.getClass().getName()));
+			}
+		}
+	}
+
+	/**
+	 * Sets OAuth error HTTP message converter to write to upon authentication failure
+	 *
+	 * @param errorHttpResponseConverter the error HTTP message converter to set
+	 */
+	public void setErrorHttpResponseConverter(HttpMessageConverter<OAuth2Error> errorHttpResponseConverter) {
+		this.errorHttpResponseConverter = errorHttpResponseConverter;
+	}
+}

+ 89 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/web/authentication/OAuth2ErrorAuthenticationFailureHandlerTests.java

@@ -0,0 +1,89 @@
+/*
+ * Copyright 2020-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization.web.authentication;
+
+import java.io.IOException;
+
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.converter.HttpMessageConverter;
+import org.springframework.http.server.ServletServerHttpResponse;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.BadCredentialsException;
+import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.isNull;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+/**
+ * Tests for {@link OAuth2ErrorAuthenticationFailureHandler}
+ *
+ * @author Dmitriy Dubson
+ */
+public class OAuth2ErrorAuthenticationFailureHandlerTests {
+
+	private HttpMessageConverter<OAuth2Error> errorHttpMessageConverter;
+
+	private HttpServletRequest request;
+
+	private HttpServletResponse response;
+
+	@BeforeEach
+	@SuppressWarnings("unchecked")
+	public void setUp() {
+		errorHttpMessageConverter = (HttpMessageConverter<OAuth2Error>) mock(HttpMessageConverter.class);
+		request = new MockHttpServletRequest();
+		response = new MockHttpServletResponse();
+	}
+
+	@Test
+	public void onAuthenticationFailure() throws IOException, ServletException {
+		OAuth2Error invalidRequestError = new OAuth2Error(OAuth2ErrorCodes.INVALID_REQUEST);
+		AuthenticationException authenticationException = new OAuth2AuthenticationException(invalidRequestError);
+		OAuth2ErrorAuthenticationFailureHandler handler = new OAuth2ErrorAuthenticationFailureHandler();
+		handler.setErrorHttpResponseConverter(errorHttpMessageConverter);
+
+		handler.onAuthenticationFailure(request, response, authenticationException);
+
+		verify(errorHttpMessageConverter).write(eq(invalidRequestError), isNull(), any(ServletServerHttpResponse.class));
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+	}
+
+	@Test
+	public void onAuthenticationFailure_ifExceptionProvidedIsNotOAuth2AuthenticationException() throws ServletException, IOException {
+		OAuth2ErrorAuthenticationFailureHandler handler = new OAuth2ErrorAuthenticationFailureHandler();
+		handler.setErrorHttpResponseConverter(errorHttpMessageConverter);
+
+		handler.onAuthenticationFailure(request, response, new BadCredentialsException("Not a valid exception."));
+
+		verifyNoInteractions(errorHttpMessageConverter);
+		assertThat(response.getStatus()).isEqualTo(HttpStatus.BAD_REQUEST.value());
+	}
+
+}