Browse Source

Remove DefaultAuthenticationSuccessHandler

We always need to save the user after authentication, so it should be
part of AuthenticationWebFilter

Fixes gh-4524
Rob Winch 8 years ago
parent
commit
8d997fd079

+ 3 - 5
config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java

@@ -31,7 +31,6 @@ import org.springframework.security.web.server.HttpBasicAuthenticationConverter;
 import org.springframework.security.web.server.MatcherSecurityWebFilterChain;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.authentication.AuthenticationWebFilter;
-import org.springframework.security.web.server.authentication.DefaultAuthenticationSuccessHandler;
 import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint;
 import org.springframework.security.web.server.authorization.AuthorizationContext;
 import org.springframework.security.web.server.authorization.AuthorizationWebFilter;
@@ -40,6 +39,7 @@ import org.springframework.security.web.server.context.AuthenticationReactorCont
 import org.springframework.security.web.server.context.SecurityContextRepositoryWebFilter;
 import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter;
 import org.springframework.security.web.server.context.SecurityContextRepository;
+import org.springframework.security.web.server.context.ServerWebExchangeAttributeSecurityContextRepository;
 import org.springframework.security.web.server.header.CacheControlHttpHeadersWriter;
 import org.springframework.security.web.server.header.CompositeHttpHeadersWriter;
 import org.springframework.security.web.server.header.ContentTypeOptionsHttpHeadersWriter;
@@ -232,7 +232,7 @@ public class HttpSecurity {
 	public class HttpBasicBuilder {
 		private ReactiveAuthenticationManager authenticationManager;
 
-		private SecurityContextRepository securityContextRepository;
+		private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository();
 
 		private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint();
 
@@ -261,9 +261,7 @@ public class HttpSecurity {
 			authenticationFilter.setEntryPoint(this.entryPoint);
 			authenticationFilter.setAuthenticationConverter(new HttpBasicAuthenticationConverter());
 			if(this.securityContextRepository != null) {
-				DefaultAuthenticationSuccessHandler handler = new DefaultAuthenticationSuccessHandler();
-				handler.setSecurityContextRepository(this.securityContextRepository);
-				authenticationFilter.setAuthenticationSuccessHandler(handler);
+				authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
 			}
 			return authenticationFilter;
 		}

+ 20 - 2
webflux/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java

@@ -20,9 +20,12 @@ package org.springframework.security.web.server.authentication;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.web.server.AuthenticationEntryPoint;
 import org.springframework.security.web.server.HttpBasicAuthenticationConverter;
 import org.springframework.security.web.server.authentication.www.HttpBasicAuthenticationEntryPoint;
+import org.springframework.security.web.server.context.SecurityContextRepository;
+import org.springframework.security.web.server.context.ServerWebExchangeAttributeSecurityContextRepository;
 import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
@@ -40,12 +43,14 @@ public class AuthenticationWebFilter implements WebFilter {
 
 	private final ReactiveAuthenticationManager authenticationManager;
 
-	private AuthenticationSuccessHandler authenticationSuccessHandler = new DefaultAuthenticationSuccessHandler();
+	private AuthenticationSuccessHandler authenticationSuccessHandler = new WebFilterChainAuthenticationSuccessHandler();
 
 	private Function<ServerWebExchange,Mono<Authentication>> authenticationConverter = new HttpBasicAuthenticationConverter();
 
 	private AuthenticationEntryPoint entryPoint = new HttpBasicAuthenticationEntryPoint();
 
+	private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository();
+
 	public AuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) {
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
 		this.authenticationManager = authenticationManager;
@@ -56,11 +61,24 @@ public class AuthenticationWebFilter implements WebFilter {
 		return this.authenticationConverter.apply(exchange)
 			.switchIfEmpty(Mono.defer(() -> chain.filter(exchange).cast(Authentication.class)))
 			.flatMap( token -> this.authenticationManager.authenticate(token)
-				.flatMap(authentication -> this.authenticationSuccessHandler.success(authentication, exchange, chain))
+				.flatMap(authentication -> onAuthenticationSuccess(authentication, exchange, chain))
 				.onErrorResume( AuthenticationException.class, t -> this.entryPoint.commence(exchange, t))
 			);
 	}
 
+	private Mono<Void> onAuthenticationSuccess(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain) {
+		SecurityContextImpl securityContext = new SecurityContextImpl();
+		securityContext.setAuthentication(authentication);
+		return this.securityContextRepository.save(exchange, securityContext)
+			.flatMap( wrappedExchange -> this.authenticationSuccessHandler.success(authentication, wrappedExchange, chain));
+	}
+
+	public void setSecurityContextRepository(
+		SecurityContextRepository securityContextRepository) {
+		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
+		this.securityContextRepository = securityContextRepository;
+	}
+
 	public void setAuthenticationSuccessHandler(AuthenticationSuccessHandler authenticationSuccessHandler) {
 		this.authenticationSuccessHandler = authenticationSuccessHandler;
 	}

+ 0 - 56
webflux/src/main/java/org/springframework/security/web/server/authentication/DefaultAuthenticationSuccessHandler.java

@@ -1,56 +0,0 @@
-/*
- *
- *  * Copyright 2002-2017 the original author or authors.
- *  *
- *  * Licensed under the Apache License, Version 2.0 (the "License");
- *  * you may not use this file except in compliance with the License.
- *  * You may obtain a copy of the License at
- *  *
- *  *      http://www.apache.org/licenses/LICENSE-2.0
- *  *
- *  * Unless required by applicable law or agreed to in writing, software
- *  * distributed under the License is distributed on an "AS IS" BASIS,
- *  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- *  * See the License for the specific language governing permissions and
- *  * limitations under the License.
- *
- */
-
-package org.springframework.security.web.server.authentication;
-
-import org.springframework.security.core.Authentication;
-import org.springframework.security.core.context.SecurityContextImpl;
-import org.springframework.security.web.server.context.SecurityContextRepository;
-import org.springframework.security.web.server.context.ServerWebExchangeAttributeSecurityContextRepository;
-import org.springframework.util.Assert;
-import org.springframework.web.server.ServerWebExchange;
-import org.springframework.web.server.WebFilterChain;
-import reactor.core.publisher.Mono;
-
-/**
- * @author Rob Winch
- * @since 5.0
- */
-public class DefaultAuthenticationSuccessHandler implements AuthenticationSuccessHandler {
-	private SecurityContextRepository securityContextRepository = new ServerWebExchangeAttributeSecurityContextRepository();
-
-	private AuthenticationSuccessHandler delegate = new WebFilterChainAuthenticationSuccessHandler();
-
-	@Override
-	public Mono<Void> success(Authentication authentication, ServerWebExchange exchange, WebFilterChain chain) {
-		SecurityContextImpl securityContext = new SecurityContextImpl();
-		securityContext.setAuthentication(authentication);
-		return securityContextRepository.save(exchange, securityContext)
-			.flatMap( wrappedExchange -> delegate.success(authentication, wrappedExchange, chain));
-	}
-
-	public void setDelegate(AuthenticationSuccessHandler delegate) {
-		Assert.notNull(delegate, "delegate cannot be null");
-		this.delegate = delegate;
-	}
-
-	public void setSecurityContextRepository(SecurityContextRepository securityContextRepository) {
-		Assert.notNull(securityContextRepository, "securityContextRepository cannot be null");
-		this.securityContextRepository = securityContextRepository;
-	}
-}

+ 16 - 4
webflux/src/test/java/org/springframework/security/web/server/authentication/AuthenticationWebFilterTests.java

@@ -18,27 +18,30 @@
 
 package org.springframework.security.web.server.authentication;
 
+import java.util.function.Function;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.runners.MockitoJUnitRunner;
+import reactor.core.publisher.Mono;
+
 import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
 import org.springframework.security.web.server.AuthenticationEntryPoint;
+import org.springframework.security.web.server.context.SecurityContextRepository;
 import org.springframework.test.web.reactive.server.EntityExchangeResult;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.server.ServerWebExchange;
-import reactor.core.publisher.Mono;
-
-import java.util.function.Function;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
@@ -59,6 +62,8 @@ public class AuthenticationWebFilterTests {
 	private ReactiveAuthenticationManager authenticationManager;
 	@Mock
 	private AuthenticationEntryPoint entryPoint;
+	@Mock
+	private SecurityContextRepository securityContextRepository;
 
 	private AuthenticationWebFilter filter;
 
@@ -68,6 +73,7 @@ public class AuthenticationWebFilterTests {
 		this.filter.setAuthenticationSuccessHandler(this.successHandler);
 		this.filter.setAuthenticationConverter(this.authenticationConverter);
 		this.filter.setEntryPoint(this.entryPoint);
+		this.filter.setSecurityContextRepository(this.securityContextRepository);
 	}
 
 	@Test
@@ -151,6 +157,7 @@ public class AuthenticationWebFilterTests {
 			.expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok"))
 			.returnResult();
 
+		verify(this.securityContextRepository, never()).save(any(), any());
 		verifyZeroInteractions(this.authenticationManager, this.successHandler,
 			this.entryPoint);
 	}
@@ -170,16 +177,18 @@ public class AuthenticationWebFilterTests {
 			.expectStatus().is5xxServerError()
 			.expectBody().isEmpty();
 
+		verify(this.securityContextRepository, never()).save(any(), any());
 		verifyZeroInteractions(this.authenticationManager, this.successHandler,
 			this.entryPoint);
 	}
 
 	@Test
-	public void filterWhenConvertAndAuthenticationSuccessThenSuccessHandler() {
+	public void filterWhenConvertAndAuthenticationSuccessThenSuccess() {
 		Mono<Authentication> authentication = Mono.just(new TestingAuthenticationToken("test", "this", "ROLE_USER"));
 		when(this.authenticationConverter.apply(any())).thenReturn(authentication);
 		when(this.authenticationManager.authenticate(any())).thenReturn(authentication);
 		when(this.successHandler.success(any(),any(),any())).thenReturn(Mono.empty());
+		when(this.securityContextRepository.save(any(),any())).thenAnswer( a -> Mono.just(a.getArguments()[0]));
 
 		WebTestClient client = WebTestClientBuilder
 			.bindToWebFilters(this.filter)
@@ -193,6 +202,7 @@ public class AuthenticationWebFilterTests {
 			.expectBody().isEmpty();
 
 		verify(this.successHandler).success(eq(authentication.block()), any(), any());
+		verify(this.securityContextRepository).save(any(), any());
 		verifyZeroInteractions(this.entryPoint);
 	}
 
@@ -215,6 +225,7 @@ public class AuthenticationWebFilterTests {
 			.expectBody().isEmpty();
 
 		verify(this.entryPoint).commence(any(),any());
+		verify(this.securityContextRepository, never()).save(any(), any());
 		verifyZeroInteractions(this.successHandler);
 	}
 
@@ -236,6 +247,7 @@ public class AuthenticationWebFilterTests {
 			.expectStatus().is5xxServerError()
 			.expectBody().isEmpty();
 
+		verify(this.securityContextRepository, never()).save(any(), any());
 		verifyZeroInteractions(this.successHandler, this.entryPoint);
 	}
 }