Browse Source

Add authenticationFailureHandler

- To ServerHttpSecurity#httpBasic
- To ServerHttpSecurity#oauthResourceServer

Closes gh-12132
Josh Cummings 2 years ago
parent
commit
3192618220

+ 35 - 8
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -2023,6 +2023,8 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationEntryPoint entryPoint;
 
+		private ServerAuthenticationFailureHandler authenticationFailureHandler;
+
 		private HttpBasicSpec() {
 			List<DelegateEntry> entryPoints = new ArrayList<>();
 			entryPoints
@@ -2071,6 +2073,13 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
+		public HttpBasicSpec authenticationFailureHandler(
+				ServerAuthenticationFailureHandler authenticationFailureHandler) {
+			Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+			this.authenticationFailureHandler = authenticationFailureHandler;
+			return this;
+		}
+
 		/**
 		 * Allows method chaining to continue configuring the {@link ServerHttpSecurity}
 		 * @return the {@link ServerHttpSecurity} to continue configuring
@@ -2102,13 +2111,19 @@ public class ServerHttpSecurity {
 					Arrays.asList(this.xhrMatcher, restNotHtmlMatcher));
 			ServerHttpSecurity.this.defaultEntryPoints.add(new DelegateEntry(preferredMatcher, this.entryPoint));
 			AuthenticationWebFilter authenticationFilter = new AuthenticationWebFilter(this.authenticationManager);
-			authenticationFilter
-					.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
+			authenticationFilter.setAuthenticationFailureHandler(authenticationFailureHandler());
 			authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter());
 			authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
 			http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC);
 		}
 
+		private ServerAuthenticationFailureHandler authenticationFailureHandler() {
+			if (this.authenticationFailureHandler != null) {
+				return this.authenticationFailureHandler;
+			}
+			return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint);
+		}
+
 	}
 
 	/**
@@ -3996,6 +4011,8 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationEntryPoint entryPoint = new BearerTokenServerAuthenticationEntryPoint();
 
+		private ServerAuthenticationFailureHandler authenticationFailureHandler;
+
 		private ServerAccessDeniedHandler accessDeniedHandler = new BearerTokenServerAccessDeniedHandler();
 
 		private ServerAuthenticationConverter bearerTokenConverter = new ServerBearerTokenAuthenticationConverter();
@@ -4038,6 +4055,12 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
+		public OAuth2ResourceServerSpec authenticationFailureHandler(
+				ServerAuthenticationFailureHandler authenticationFailureHandler) {
+			this.authenticationFailureHandler = authenticationFailureHandler;
+			return this;
+		}
+
 		/**
 		 * Configures the {@link ServerAuthenticationConverter} to use for requests
 		 * authenticating with
@@ -4127,8 +4150,7 @@ public class ServerHttpSecurity {
 			if (this.authenticationManagerResolver != null) {
 				AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(this.authenticationManagerResolver);
 				oauth2.setServerAuthenticationConverter(this.bearerTokenConverter);
-				oauth2.setAuthenticationFailureHandler(
-						new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
+				oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
 				http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
 			}
 			else if (this.jwt != null) {
@@ -4181,6 +4203,13 @@ public class ServerHttpSecurity {
 			}
 		}
 
+		private ServerAuthenticationFailureHandler authenticationFailureHandler() {
+			if (this.authenticationFailureHandler != null) {
+				return this.authenticationFailureHandler;
+			}
+			return new ServerAuthenticationEntryPointFailureHandler(this.entryPoint);
+		}
+
 		public ServerHttpSecurity and() {
 			return ServerHttpSecurity.this;
 		}
@@ -4262,8 +4291,7 @@ public class ServerHttpSecurity {
 				ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
 				AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
 				oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter);
-				oauth2.setAuthenticationFailureHandler(
-						new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint));
+				oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
 				http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
 			}
 
@@ -4398,8 +4426,7 @@ public class ServerHttpSecurity {
 				ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
 				AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager);
 				oauth2.setServerAuthenticationConverter(OAuth2ResourceServerSpec.this.bearerTokenConverter);
-				oauth2.setAuthenticationFailureHandler(
-						new ServerAuthenticationEntryPointFailureHandler(OAuth2ResourceServerSpec.this.entryPoint));
+				oauth2.setAuthenticationFailureHandler(authenticationFailureHandler());
 				http.addFilterAt(oauth2, SecurityWebFiltersOrder.AUTHENTICATION);
 			}
 

+ 3 - 0
config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDsl.kt

@@ -21,6 +21,7 @@ import org.springframework.security.core.Authentication
 import org.springframework.security.core.context.SecurityContext
 import org.springframework.security.web.authentication.www.BasicAuthenticationFilter
 import org.springframework.security.web.server.ServerAuthenticationEntryPoint
+import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
 import org.springframework.security.web.server.context.ReactorContextWebFilter
 import org.springframework.security.web.server.context.ServerSecurityContextRepository
 
@@ -42,6 +43,7 @@ import org.springframework.security.web.server.context.ServerSecurityContextRepo
 class ServerHttpBasicDsl {
     var authenticationManager: ReactiveAuthenticationManager? = null
     var securityContextRepository: ServerSecurityContextRepository? = null
+    var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null
     var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null
 
     private var disabled = false
@@ -57,6 +59,7 @@ class ServerHttpBasicDsl {
         return { httpBasic ->
             authenticationManager?.also { httpBasic.authenticationManager(authenticationManager) }
             securityContextRepository?.also { httpBasic.securityContextRepository(securityContextRepository) }
+            authenticationFailureHandler?.also { httpBasic.authenticationFailureHandler(authenticationFailureHandler) }
             authenticationEntryPoint?.also { httpBasic.authenticationEntryPoint(authenticationEntryPoint) }
             if (disabled) {
                 httpBasic.disable()

+ 3 - 0
config/src/main/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDsl.kt

@@ -19,6 +19,7 @@ package org.springframework.security.config.web.server
 import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver
 import org.springframework.security.web.server.ServerAuthenticationEntryPoint
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter
+import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
 import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler
 import org.springframework.web.server.ServerWebExchange
 
@@ -38,6 +39,7 @@ import org.springframework.web.server.ServerWebExchange
 @ServerSecurityMarker
 class ServerOAuth2ResourceServerDsl {
     var accessDeniedHandler: ServerAccessDeniedHandler? = null
+    var authenticationFailureHandler: ServerAuthenticationFailureHandler? = null
     var authenticationEntryPoint: ServerAuthenticationEntryPoint? = null
     var bearerTokenConverter: ServerAuthenticationConverter? = null
     var authenticationManagerResolver: ReactiveAuthenticationManagerResolver<ServerWebExchange>? = null
@@ -107,6 +109,7 @@ class ServerOAuth2ResourceServerDsl {
     internal fun get(): (ServerHttpSecurity.OAuth2ResourceServerSpec) -> Unit {
         return { oauth2ResourceServer ->
             accessDeniedHandler?.also { oauth2ResourceServer.accessDeniedHandler(accessDeniedHandler) }
+            authenticationFailureHandler?.also { oauth2ResourceServer.authenticationFailureHandler(authenticationFailureHandler) }
             authenticationEntryPoint?.also { oauth2ResourceServer.authenticationEntryPoint(authenticationEntryPoint) }
             bearerTokenConverter?.also { oauth2ResourceServer.bearerTokenConverter(bearerTokenConverter) }
             authenticationManagerResolver?.also { oauth2ResourceServer.authenticationManagerResolver(authenticationManagerResolver!!) }

+ 50 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java

@@ -51,6 +51,7 @@ import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
+import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.authentication.ReactiveAuthenticationManagerResolver;
 import org.springframework.security.authentication.TestingAuthenticationToken;
@@ -73,6 +74,7 @@ import org.springframework.security.oauth2.server.resource.introspection.Reactiv
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint;
 import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
+import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
 import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
 import org.springframework.test.web.reactive.server.WebTestClient;
@@ -348,6 +350,25 @@ public class OAuth2ResourceServerSpecTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void getWhenUsingCustomAuthenticationFailureHandlerThenUsesIsAccordingly() {
+		this.spring.register(CustomAuthenticationFailureHandlerConfig.class).autowire();
+		ServerAuthenticationFailureHandler handler = this.spring.getContext()
+				.getBean(ServerAuthenticationFailureHandler.class);
+		ReactiveAuthenticationManager authenticationManager = this.spring.getContext()
+				.getBean(ReactiveAuthenticationManager.class);
+		given(authenticationManager.authenticate(any()))
+				.willReturn(Mono.error(() -> new BadCredentialsException("bad")));
+		given(handler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty());
+		// @formatter:off
+		this.client.get()
+				.headers((headers) -> headers.setBearerAuth(this.messageReadToken))
+				.exchange()
+				.expectStatus().isOk();
+		// @formatter:on
+		verify(handler).onAuthenticationFailure(any(), any());
+	}
+
 	@Test
 	public void postWhenSignedThenReturnsOk() {
 		this.spring.register(PublicKeyConfig.class, RootController.class).autowire();
@@ -893,6 +914,35 @@ public class OAuth2ResourceServerSpecTests {
 
 	}
 
+	@EnableWebFlux
+	@EnableWebFluxSecurity
+	static class CustomAuthenticationFailureHandlerConfig {
+
+		@Bean
+		SecurityWebFilterChain springSecurity(ServerHttpSecurity http) {
+			// @formatter:off
+			http
+				.authorizeExchange((authorize) -> authorize.anyExchange().authenticated())
+				.oauth2ResourceServer((oauth2) -> oauth2
+					.authenticationFailureHandler(authenticationFailureHandler())
+					.jwt((jwt) -> jwt.authenticationManager(authenticationManager()))
+				);
+			// @formatter:on
+			return http.build();
+		}
+
+		@Bean
+		ReactiveAuthenticationManager authenticationManager() {
+			return mock(ReactiveAuthenticationManager.class);
+		}
+
+		@Bean
+		ServerAuthenticationFailureHandler authenticationFailureHandler() {
+			return mock(ServerAuthenticationFailureHandler.class);
+		}
+
+	}
+
 	@EnableWebFlux
 	@EnableWebFluxSecurity
 	static class CustomBearerTokenServerAuthenticationConverter {

+ 23 - 0
config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

@@ -35,6 +35,7 @@ import reactor.test.publisher.TestPublisher;
 import org.springframework.http.HttpStatus;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
@@ -57,6 +58,7 @@ import org.springframework.security.web.server.WebFilterChainProxy;
 import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
 import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
 import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint;
+import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
 import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
 import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
@@ -218,6 +220,27 @@ public class ServerHttpSecurityTests {
 		verify(authenticationEntryPoint).commence(any(), any());
 	}
 
+	@Test
+	public void basicWhenCustomAuthenticationFailureHandlerThenUses() {
+		ReactiveAuthenticationManager authenticationManager = mock(ReactiveAuthenticationManager.class);
+		ServerAuthenticationFailureHandler authenticationFailureHandler = mock(
+				ServerAuthenticationFailureHandler.class);
+		this.http.httpBasic().authenticationFailureHandler(authenticationFailureHandler);
+		this.http.httpBasic().authenticationManager(authenticationManager);
+		this.http.authorizeExchange().anyExchange().authenticated();
+		given(authenticationManager.authenticate(any()))
+				.willReturn(Mono.error(() -> new BadCredentialsException("bad")));
+		given(authenticationFailureHandler.onAuthenticationFailure(any(), any())).willReturn(Mono.empty());
+		WebTestClient client = buildClient();
+		// @formatter:off
+		client.get().uri("/")
+			.headers((headers) -> headers.setBasicAuth("user", "password"))
+			.exchange()
+			.expectStatus().isOk();
+		// @formatter:on
+		verify(authenticationFailureHandler).onAuthenticationFailure(any(), any());
+	}
+
 	@Test
 	public void buildWhenServerWebExchangeFromContextThenFound() {
 		SecurityWebFilterChain filter = this.http.build();

+ 39 - 1
config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpBasicDslTests.kt

@@ -19,7 +19,6 @@ package org.springframework.security.config.web.server
 import io.mockk.every
 import io.mockk.mockkObject
 import io.mockk.verify
-import java.util.Base64
 import org.junit.jupiter.api.Test
 import org.junit.jupiter.api.extension.ExtendWith
 import org.springframework.beans.factory.annotation.Autowired
@@ -36,6 +35,7 @@ import org.springframework.security.core.userdetails.MapReactiveUserDetailsServi
 import org.springframework.security.core.userdetails.User
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.ServerAuthenticationEntryPoint
+import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
 import org.springframework.security.web.server.context.ServerSecurityContextRepository
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository
 import org.springframework.test.web.reactive.server.WebTestClient
@@ -43,6 +43,7 @@ import org.springframework.web.bind.annotation.RequestMapping
 import org.springframework.web.bind.annotation.RestController
 import org.springframework.web.reactive.config.EnableWebFlux
 import reactor.core.publisher.Mono
+import java.util.*
 
 /**
  * Tests for [ServerHttpBasicDsl]
@@ -216,6 +217,43 @@ class ServerHttpBasicDslTests {
         }
     }
 
+    @Test
+    fun `http basic when custom authentication failure handler then failure handler used`() {
+        this.spring.register(CustomAuthenticationFailureHandlerConfig::class.java, UserDetailsConfig::class.java).autowire()
+        mockkObject(CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER)
+        every {
+            CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any())
+        } returns Mono.empty()
+
+        this.client.get()
+            .uri("/")
+            .header("Authorization", "Basic " + Base64.getEncoder().encodeToString("user:wrong".toByteArray()))
+            .exchange()
+
+        verify(exactly = 1) { CustomAuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) }
+    }
+
+    @EnableWebFluxSecurity
+    @EnableWebFlux
+    open class CustomAuthenticationFailureHandlerConfig {
+
+        companion object {
+            val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() }
+        }
+
+        @Bean
+        open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+            return http {
+                authorizeExchange {
+                    authorize(anyExchange, authenticated)
+                }
+                httpBasic {
+                    authenticationFailureHandler = FAILURE_HANDLER
+                }
+            }
+        }
+    }
+
     @Configuration
     open class UserDetailsConfig {
         @Bean

+ 46 - 4
config/src/test/kotlin/org/springframework/security/config/web/server/ServerOAuth2ResourceServerDslTests.kt

@@ -19,10 +19,6 @@ package org.springframework.security.config.web.server
 import io.mockk.every
 import io.mockk.mockkObject
 import io.mockk.verify
-import java.math.BigInteger
-import java.security.KeyFactory
-import java.security.interfaces.RSAPublicKey
-import java.security.spec.RSAPublicKeySpec
 import org.junit.jupiter.api.Test
 import org.junit.jupiter.api.extension.ExtendWith
 import org.springframework.beans.factory.annotation.Autowired
@@ -36,11 +32,16 @@ import org.springframework.security.config.test.SpringTestContextExtension
 import org.springframework.security.oauth2.server.resource.web.server.authentication.ServerBearerTokenAuthenticationConverter
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.authentication.HttpStatusServerEntryPoint
+import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler
 import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.reactive.config.EnableWebFlux
 import org.springframework.web.server.ServerWebExchange
 import reactor.core.publisher.Mono
+import java.math.BigInteger
+import java.security.KeyFactory
+import java.security.interfaces.RSAPublicKey
+import java.security.spec.RSAPublicKeySpec
 
 /**
  * Tests for [ServerOAuth2ResourceServerDsl]
@@ -125,6 +126,47 @@ class ServerOAuth2ResourceServerDslTests {
         }
     }
 
+    @Test
+    fun `http basic when custom authentication failure handler then failure handler used`() {
+        this.spring.register(AuthenticationFailureHandlerConfig::class.java).autowire()
+        mockkObject(AuthenticationFailureHandlerConfig.FAILURE_HANDLER)
+        every {
+            AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any())
+        } returns Mono.empty()
+
+        this.client.get()
+            .uri("/")
+            .header("Authorization", "Bearer token")
+            .exchange()
+            .expectStatus().isOk
+
+        verify(exactly = 1) { AuthenticationFailureHandlerConfig.FAILURE_HANDLER.onAuthenticationFailure(any(), any()) }
+    }
+
+    @EnableWebFluxSecurity
+    @EnableWebFlux
+    open class AuthenticationFailureHandlerConfig {
+
+        companion object {
+            val FAILURE_HANDLER: ServerAuthenticationFailureHandler = ServerAuthenticationFailureHandler { _, _ -> Mono.empty() }
+        }
+
+        @Bean
+        open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+            return http {
+                authorizeExchange {
+                    authorize(anyExchange, authenticated)
+                }
+                oauth2ResourceServer {
+                    authenticationFailureHandler = FAILURE_HANDLER
+                    jwt {
+                        publicKey = publicKey()
+                    }
+                }
+            }
+        }
+    }
+
     @Test
     fun `request when custom bearer token converter configured then custom converter used`() {
         this.spring.register(BearerTokenConverterConfig::class.java).autowire()