Browse Source

BearerTokenAuthenticationFilter exposes AuthenticationFailureHandler

Make BearerTokenAuthenticationFilter expose an AuthenticationFailureHandler which, by default, invokes the AuthenticationEntryPoint set in the filter.

Fixes gh-7009
Thomas Vitale 6 years ago
parent
commit
f9747e6591

+ 15 - 1
oauth2/oauth2-resource-server/src/main/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilter.java

@@ -33,6 +33,7 @@ import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.server.resource.BearerTokenAuthenticationToken;
 import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationProvider;
 import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
@@ -61,6 +62,9 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 
 	private AuthenticationEntryPoint authenticationEntryPoint = new BearerTokenAuthenticationEntryPoint();
 
+	private AuthenticationFailureHandler authenticationFailureHandler = (request, response, exception) ->
+		authenticationEntryPoint.commence(request, response, exception);
+
 	/**
 	 * Construct a {@code BearerTokenAuthenticationFilter} using the provided parameter(s)
 	 * @param authenticationManagerResolver
@@ -131,7 +135,7 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 				this.logger.debug("Authentication request for failed: " + failed);
 			}
 
-			this.authenticationEntryPoint.commence(request, response, failed);
+			this.authenticationFailureHandler.onAuthenticationFailure(request, response, failed);
 		}
 	}
 
@@ -152,4 +156,14 @@ public final class BearerTokenAuthenticationFilter extends OncePerRequestFilter
 		Assert.notNull(authenticationEntryPoint, "authenticationEntryPoint cannot be null");
 		this.authenticationEntryPoint = authenticationEntryPoint;
 	}
+
+	/**
+	 * Set the {@link AuthenticationFailureHandler} to use. Default implementation invokes {@link AuthenticationEntryPoint}.
+	 * @param authenticationFailureHandler the {@code AuthenticationFailureHandler} to use
+	 * @since 5.2
+	 */
+	public final void setAuthenticationFailureHandler(final AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
+	}
 }

+ 28 - 1
oauth2/oauth2-resource-server/src/test/java/org/springframework/security/oauth2/server/resource/web/BearerTokenAuthenticationFilterTests.java

@@ -37,6 +37,7 @@ import org.springframework.security.oauth2.server.resource.BearerTokenAuthentica
 import org.springframework.security.oauth2.server.resource.BearerTokenError;
 import org.springframework.security.oauth2.server.resource.BearerTokenErrorCodes;
 import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
@@ -55,6 +56,9 @@ public class BearerTokenAuthenticationFilterTests {
 	@Mock
 	AuthenticationEntryPoint authenticationEntryPoint;
 
+	@Mock
+	AuthenticationFailureHandler authenticationFailureHandler;
+
 	@Mock
 	AuthenticationManager authenticationManager;
 
@@ -138,7 +142,7 @@ public class BearerTokenAuthenticationFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenAuthenticationFailsThenPropagatesError() throws ServletException, IOException {
+	public void doFilterWhenAuthenticationFailsWithDefaultHandlerThenPropagatesError() throws ServletException, IOException {
 		BearerTokenError error = new BearerTokenError(
 				BearerTokenErrorCodes.INVALID_TOKEN,
 				HttpStatus.UNAUTHORIZED,
@@ -159,6 +163,29 @@ public class BearerTokenAuthenticationFilterTests {
 		verify(this.authenticationEntryPoint).commence(this.request, this.response, exception);
 	}
 
+	@Test
+	public void doFilterWhenAuthenticationFailsWithCustomHandlerThenPropagatesError() throws ServletException, IOException {
+		BearerTokenError error = new BearerTokenError(
+				BearerTokenErrorCodes.INVALID_TOKEN,
+				HttpStatus.UNAUTHORIZED,
+				"description",
+				"uri"
+		);
+
+		OAuth2AuthenticationException exception = new OAuth2AuthenticationException(error);
+
+		when(this.bearerTokenResolver.resolve(this.request)).thenReturn("token");
+		when(this.authenticationManager.authenticate(any(BearerTokenAuthenticationToken.class)))
+				.thenThrow(exception);
+
+		BearerTokenAuthenticationFilter filter =
+				addMocks(new BearerTokenAuthenticationFilter(this.authenticationManager));
+		filter.setAuthenticationFailureHandler(this.authenticationFailureHandler);
+		filter.doFilter(this.request, this.response, this.filterChain);
+
+		verify(this.authenticationFailureHandler).onAuthenticationFailure(this.request, this.response, exception);
+	}
+
 	@Test
 	public void setAuthenticationEntryPointWhenNullThenThrowsException() {
 		BearerTokenAuthenticationFilter filter = new BearerTokenAuthenticationFilter(this.authenticationManager);