瀏覽代碼

AuthenticationWebFilter uses AuthenticationFailureHandler

Issue gh-4533
Rob Winch 8 年之前
父節點
當前提交
45bac0fd2c

+ 11 - 7
webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java

@@ -52,7 +52,7 @@ public class AuthenticationWebFilter implements WebFilter {
 
 
 	private Function<ServerWebExchange,Mono<Authentication>> authenticationConverter = new HttpBasicAuthenticationConverter();
 	private Function<ServerWebExchange,Mono<Authentication>> authenticationConverter = new HttpBasicAuthenticationConverter();
 
 
-	private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint();
+	private AuthenticationFailureHandler authenticationFailureHandler = new AuthenticationEntryPointFailureHandler(new HttpBasicAuthenticationEntryPoint());
 
 
 	private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository();
 	private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository();
 
 
@@ -79,16 +79,18 @@ public class AuthenticationWebFilter implements WebFilter {
 
 
 	private Mono<Void> authenticate(ServerWebExchange wrappedExchange,
 	private Mono<Void> authenticate(ServerWebExchange wrappedExchange,
 		WebFilterChain chain, Authentication token) {
 		WebFilterChain chain, Authentication token) {
+		WebFilterExchange webFilterExchange = new WebFilterExchange(wrappedExchange, chain);
 		return this.authenticationManager.authenticate(token)
 		return this.authenticationManager.authenticate(token)
-			.flatMap(authentication -> onAuthenticationSuccess(authentication, wrappedExchange, chain))
-			.onErrorResume(AuthenticationException.class, e -> this.entryPoint.commence(wrappedExchange, e));
+			.flatMap(authentication -> onAuthenticationSuccess(authentication, webFilterExchange))
+			.onErrorResume(AuthenticationException.class, e -> this.authenticationFailureHandler.onAuthenticationFailure(webFilterExchange, e));
 	}
 	}
 
 
-	private Mono<Void> onAuthenticationSuccess(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain) {
+	private Mono<Void> onAuthenticationSuccess(Authentication authentication, WebFilterExchange webFilterExchange) {
+		ServerWebExchange exchange = webFilterExchange.getExchange();
 		SecurityContextImpl securityContext = new SecurityContextImpl();
 		SecurityContextImpl securityContext = new SecurityContextImpl();
 		securityContext.setAuthentication(authentication);
 		securityContext.setAuthentication(authentication);
 		return this.securityContextRepository.save(exchange, securityContext)
 		return this.securityContextRepository.save(exchange, securityContext)
-			.then(this.authenticationSuccessHandler.success(authentication, new WebFilterExchange(exchange, chain)));
+			.then(this.authenticationSuccessHandler.success(authentication, webFilterExchange));
 	}
 	}
 
 
 	public void setSecurityContextRepository(
 	public void setSecurityContextRepository(
@@ -105,8 +107,10 @@ public class AuthenticationWebFilter implements WebFilter {
 		this.authenticationConverter = authenticationConverter;
 		this.authenticationConverter = authenticationConverter;
 	}
 	}
 
 
-	public void setEntryPoint(AuthenticationEntryPoint entryPoint) {
-		this.entryPoint = entryPoint;
+	public void setAuthenticationFailureHandler(
+		AuthenticationFailureHandler authenticationFailureHandler) {
+		Assert.notNull(authenticationFailureHandler, "authenticationFailureHandler cannot be null");
+		this.authenticationFailureHandler = authenticationFailureHandler;
 	}
 	}
 
 
 	public void setRequiresAuthenticationMatcher(
 	public void setRequiresAuthenticationMatcher(

+ 9 - 10
webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java

@@ -32,7 +32,6 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
-import org.springframework.security.web.server.AuthenticationEntryPoint;
 import org.springframework.security.web.server.context.SecurityContextRepository;
 import org.springframework.security.web.server.context.SecurityContextRepository;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.test.web.reactive.server.EntityExchangeResult;
 import org.springframework.test.web.reactive.server.EntityExchangeResult;
@@ -62,7 +61,7 @@ public class AuthenticationWebFilterTests {
 	@Mock
 	@Mock
 	private ReactiveAuthenticationManager authenticationManager;
 	private ReactiveAuthenticationManager authenticationManager;
 	@Mock
 	@Mock
-	private AuthenticationEntryPoint entryPoint;
+	private AuthenticationFailureHandler failureHandler;
 	@Mock
 	@Mock
 	private SecurityContextRepository securityContextRepository;
 	private SecurityContextRepository securityContextRepository;
 
 
@@ -73,8 +72,8 @@ public class AuthenticationWebFilterTests {
 		this.filter = new AuthenticationWebFilter(this.authenticationManager);
 		this.filter = new AuthenticationWebFilter(this.authenticationManager);
 		this.filter.setAuthenticationSuccessHandler(this.successHandler);
 		this.filter.setAuthenticationSuccessHandler(this.successHandler);
 		this.filter.setAuthenticationConverter(this.authenticationConverter);
 		this.filter.setAuthenticationConverter(this.authenticationConverter);
-		this.filter.setEntryPoint(this.entryPoint);
 		this.filter.setSecurityContextRepository(this.securityContextRepository);
 		this.filter.setSecurityContextRepository(this.securityContextRepository);
+		this.filter.setAuthenticationFailureHandler(this.failureHandler);
 	}
 	}
 
 
 	@Test
 	@Test
@@ -160,7 +159,7 @@ public class AuthenticationWebFilterTests {
 
 
 		verify(this.securityContextRepository, never()).save(any(), any());
 		verify(this.securityContextRepository, never()).save(any(), any());
 		verifyZeroInteractions(this.authenticationManager, this.successHandler,
 		verifyZeroInteractions(this.authenticationManager, this.successHandler,
-			this.entryPoint);
+			this.failureHandler);
 	}
 	}
 
 
 	@Test
 	@Test
@@ -180,7 +179,7 @@ public class AuthenticationWebFilterTests {
 
 
 		verify(this.securityContextRepository, never()).save(any(), any());
 		verify(this.securityContextRepository, never()).save(any(), any());
 		verifyZeroInteractions(this.authenticationManager, this.successHandler,
 		verifyZeroInteractions(this.authenticationManager, this.successHandler,
-			this.entryPoint);
+			this.failureHandler);
 	}
 	}
 
 
 	@Test
 	@Test
@@ -204,7 +203,7 @@ public class AuthenticationWebFilterTests {
 
 
 		verify(this.successHandler).success(eq(authentication.block()), any());
 		verify(this.successHandler).success(eq(authentication.block()), any());
 		verify(this.securityContextRepository).save(any(), any());
 		verify(this.securityContextRepository).save(any(), any());
-		verifyZeroInteractions(this.entryPoint);
+		verifyZeroInteractions(this.failureHandler);
 	}
 	}
 
 
 	@Test
 	@Test
@@ -235,7 +234,7 @@ public class AuthenticationWebFilterTests {
 		Mono<Authentication> authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER"));
 		Mono<Authentication> authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER"));
 		when(this.authenticationConverter.apply(any())).thenReturn(authentication);
 		when(this.authenticationConverter.apply(any())).thenReturn(authentication);
 		when(this.authenticationManager.authenticate(any())).thenReturn(Mono.error(new BadCredentialsException("Failed")));
 		when(this.authenticationManager.authenticate(any())).thenReturn(Mono.error(new BadCredentialsException("Failed")));
-		when(this.entryPoint.commence(any(),any())).thenReturn(Mono.empty());
+		when(this.failureHandler.onAuthenticationFailure(any(),any())).thenReturn(Mono.empty());
 
 
 		WebTestClient client = WebTestClientBuilder
 		WebTestClient client = WebTestClientBuilder
 			.bindToWebFilters(this.filter)
 			.bindToWebFilters(this.filter)
@@ -248,7 +247,7 @@ public class AuthenticationWebFilterTests {
 			.expectStatus().isOk()
 			.expectStatus().isOk()
 			.expectBody().isEmpty();
 			.expectBody().isEmpty();
 
 
-		verify(this.entryPoint).commence(any(),any());
+		verify(this.failureHandler).onAuthenticationFailure(any(),any());
 		verify(this.securityContextRepository, never()).save(any(), any());
 		verify(this.securityContextRepository, never()).save(any(), any());
 		verifyZeroInteractions(this.successHandler);
 		verifyZeroInteractions(this.successHandler);
 	}
 	}
@@ -258,7 +257,7 @@ public class AuthenticationWebFilterTests {
 		Mono<Authentication> authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER"));
 		Mono<Authentication> authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER"));
 		when(this.authenticationConverter.apply(any())).thenReturn(authentication);
 		when(this.authenticationConverter.apply(any())).thenReturn(authentication);
 		when(this.authenticationManager.authenticate(any())).thenReturn(Mono.error(new RuntimeException("Failed")));
 		when(this.authenticationManager.authenticate(any())).thenReturn(Mono.error(new RuntimeException("Failed")));
-		when(this.entryPoint.commence(any(),any())).thenReturn(Mono.empty());
+		when(this.failureHandler.onAuthenticationFailure(any(),any())).thenReturn(Mono.empty());
 
 
 		WebTestClient client = WebTestClientBuilder
 		WebTestClient client = WebTestClientBuilder
 			.bindToWebFilters(this.filter)
 			.bindToWebFilters(this.filter)
@@ -272,7 +271,7 @@ public class AuthenticationWebFilterTests {
 			.expectBody().isEmpty();
 			.expectBody().isEmpty();
 
 
 		verify(this.securityContextRepository, never()).save(any(), any());
 		verify(this.securityContextRepository, never()).save(any(), any());
-		verifyZeroInteractions(this.successHandler, this.entryPoint);
+		verifyZeroInteractions(this.successHandler, this.failureHandler);
 	}
 	}
 
 
 	@Test(expected = IllegalArgumentException.class)
 	@Test(expected = IllegalArgumentException.class)