Procházet zdrojové kódy

Allow to set default securityContextRepository for each authentication mechanisms

Fixes gh-7249
Eddú Meléndez před 6 roky
rodič
revize
8773c7994f

+ 35 - 17
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -268,7 +268,7 @@ public class ServerHttpSecurity {
 
 	private ReactiveAuthenticationManager authenticationManager;
 
-	private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
+	private ServerSecurityContextRepository securityContextRepository;
 
 	private ServerAuthenticationEntryPoint authenticationEntryPoint;
 
@@ -346,7 +346,7 @@ public class ServerHttpSecurity {
 	}
 
 	/**
-	 * The strategy used with {@code ReactorContextWebFilter}. It does not impact how the {@code SecurityContext} is
+	 * The strategy used with {@code ReactorContextWebFilter}. It does impact how the {@code SecurityContext} is
 	 * saved which is configured on a per {@link AuthenticationWebFilter} basis.
 	 * @param securityContextRepository the repository to use
 	 * @return the {@link ServerHttpSecurity} to continue configuring
@@ -971,7 +971,7 @@ public class ServerHttpSecurity {
 
 		private ReactiveAuthenticationManager authenticationManager;
 
-		private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
+		private ServerSecurityContextRepository securityContextRepository;
 
 		private ServerAuthenticationConverter authenticationConverter;
 
@@ -2254,9 +2254,7 @@ public class ServerHttpSecurity {
 			this.headers.configure(this);
 		}
 		WebFilter securityContextRepositoryWebFilter = securityContextRepositoryWebFilter();
-		if (securityContextRepositoryWebFilter != null) {
-			this.webFilters.add(securityContextRepositoryWebFilter);
-		}
+		this.webFilters.add(securityContextRepositoryWebFilter);
 		if (this.httpsRedirectSpec != null) {
 			this.httpsRedirectSpec.configure(this);
 		}
@@ -2273,18 +2271,42 @@ public class ServerHttpSecurity {
 			if (this.httpBasic.authenticationManager == null) {
 				this.httpBasic.authenticationManager(this.authenticationManager);
 			}
+			if (this.httpBasic.securityContextRepository != null) {
+				this.httpBasic.securityContextRepository(this.httpBasic.securityContextRepository);
+			}
+			else if (this.securityContextRepository != null) {
+				this.httpBasic.securityContextRepository(this.securityContextRepository);
+			}
+			else {
+				this.httpBasic.securityContextRepository(NoOpServerSecurityContextRepository.getInstance());
+			}
 			this.httpBasic.configure(this);
 		}
 		if (this.formLogin != null) {
 			if (this.formLogin.authenticationManager == null) {
 				this.formLogin.authenticationManager(this.authenticationManager);
 			}
-			if (this.securityContextRepository != null) {
+			if (this.formLogin.securityContextRepository != null) {
+				this.formLogin.securityContextRepository(this.formLogin.securityContextRepository);
+			}
+			else if (this.securityContextRepository != null) {
 				this.formLogin.securityContextRepository(this.securityContextRepository);
 			}
+			else {
+				this.formLogin.securityContextRepository(new WebSessionServerSecurityContextRepository());
+			}
 			this.formLogin.configure(this);
 		}
 		if (this.oauth2Login != null) {
+			if (this.oauth2Login.securityContextRepository != null) {
+				this.oauth2Login.securityContextRepository(this.oauth2Login.securityContextRepository);
+			}
+			else if (this.securityContextRepository != null) {
+				this.oauth2Login.securityContextRepository(this.securityContextRepository);
+			}
+			else {
+				this.oauth2Login.securityContextRepository(new WebSessionServerSecurityContextRepository());
+			}
 			this.oauth2Login.configure(this);
 		}
 		if (this.resourceServer != null) {
@@ -2379,10 +2401,8 @@ public class ServerHttpSecurity {
 	}
 
 	private WebFilter securityContextRepositoryWebFilter() {
-		ServerSecurityContextRepository repository = this.securityContextRepository;
-		if (repository == null) {
-			return null;
-		}
+		ServerSecurityContextRepository repository = this.securityContextRepository == null ?
+				new WebSessionServerSecurityContextRepository() : this.securityContextRepository;
 		WebFilter result = new ReactorContextWebFilter(repository);
 		return new OrderedWebFilter(result, SecurityWebFiltersOrder.REACTOR_CONTEXT.getOrder());
 	}
@@ -2774,7 +2794,7 @@ public class ServerHttpSecurity {
 	public class HttpBasicSpec {
 		private ReactiveAuthenticationManager authenticationManager;
 
-		private ServerSecurityContextRepository securityContextRepository = NoOpServerSecurityContextRepository.getInstance();
+		private ServerSecurityContextRepository securityContextRepository;
 
 		private ServerAuthenticationEntryPoint entryPoint = new HttpBasicServerAuthenticationEntryPoint();
 
@@ -2846,9 +2866,7 @@ public class ServerHttpSecurity {
 				this.authenticationManager);
 			authenticationFilter.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(this.entryPoint));
 			authenticationFilter.setAuthenticationConverter(new ServerHttpBasicAuthenticationConverter());
-			if (this.securityContextRepository != null) {
-				authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
-			}
+			authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
 			http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC);
 		}
 
@@ -2869,7 +2887,7 @@ public class ServerHttpSecurity {
 
 		private ReactiveAuthenticationManager authenticationManager;
 
-		private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
+		private ServerSecurityContextRepository securityContextRepository;
 
 		private ServerAuthenticationEntryPoint authenticationEntryPoint;
 
@@ -2966,7 +2984,7 @@ public class ServerHttpSecurity {
 
 		/**
 		 * The {@link ServerSecurityContextRepository} used to save the {@code Authentication}. Defaults to
-		 * {@link NoOpServerSecurityContextRepository}. For the {@code SecurityContext} to be loaded on subsequent
+		 * {@link WebSessionServerSecurityContextRepository}. For the {@code SecurityContext} to be loaded on subsequent
 		 * requests the {@link ReactorContextWebFilter} must be configured to be able to load the value (they are not
 		 * implicitly linked).
 		 *

+ 57 - 0
config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java

@@ -26,11 +26,15 @@ import org.openqa.selenium.support.PageFactory;
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
 import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
+import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.csrf.CsrfToken;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.web.reactive.server.WebTestClient;
@@ -44,12 +48,15 @@ import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.springframework.security.config.Customizer.withDefaults;
 
 /**
  * @author Rob Winch
+ * @author Eddú Meléndez
  * @since 5.0
  */
 public class FormLoginTests {
@@ -272,6 +279,50 @@ public class FormLoginTests {
 		verifyZeroInteractions(defaultAuthenticationManager);
 	}
 
+	@Test
+	public void formLoginSecurityContextRepository() {
+		ServerSecurityContextRepository defaultSecContextRepository = mock(ServerSecurityContextRepository.class);
+		ServerSecurityContextRepository formLoginSecContextRepository = mock(ServerSecurityContextRepository.class);
+
+		TestingAuthenticationToken token = new TestingAuthenticationToken("rob", "rob", "ROLE_USER");
+
+		given(defaultSecContextRepository.save(any(), any())).willReturn(Mono.empty());
+		given(defaultSecContextRepository.load(any())).willReturn(authentication(token));
+		given(formLoginSecContextRepository.save(any(), any())).willReturn(Mono.empty());
+		given(formLoginSecContextRepository.load(any())).willReturn(authentication(token));
+
+		SecurityWebFilterChain securityWebFilter = this.http
+				.authorizeExchange()
+					.anyExchange().authenticated()
+					.and()
+				.securityContextRepository(defaultSecContextRepository)
+				.formLogin()
+					.securityContextRepository(formLoginSecContextRepository)
+					.and()
+				.build();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(securityWebFilter)
+				.build();
+
+		WebDriver driver = WebTestClientHtmlUnitDriverBuilder
+				.webTestClientSetup(webTestClient)
+				.build();
+
+		DefaultLoginPage loginPage = DefaultLoginPage.to(driver)
+				.assertAt();
+
+		HomePage homePage = loginPage.loginForm()
+				.username("user")
+				.password("password")
+				.submit(HomePage.class);
+
+		homePage.assertAt();
+
+		verify(defaultSecContextRepository, atLeastOnce()).load(any());
+		verify(formLoginSecContextRepository).save(any(), any());
+	}
+
 	public static class CustomLoginPage {
 
 		private WebDriver driver;
@@ -501,4 +552,10 @@ public class FormLoginTests {
 				+ "</html>");
 		}
 	}
+
+	Mono<SecurityContext> authentication(Authentication authentication) {
+		SecurityContext context = new SecurityContextImpl();
+		context.setAuthentication(authentication);
+		return Mono.just(context);
+	}
 }

+ 1 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -428,6 +428,7 @@ public class OAuth2LoginTests {
 
 		ServerSecurityContextRepository securityContextRepository = config.securityContextRepository;
 		when(securityContextRepository.save(any(), any())).thenReturn(Mono.empty());
+		when(securityContextRepository.load(any())).thenReturn(authentication(token));
 
 		Map<String, Object> additionalParameters = new HashMap<>();
 		additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");

+ 25 - 2
config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

@@ -75,6 +75,7 @@ import org.springframework.security.web.server.authentication.HttpBasicServerAut
 
 /**
  * @author Rob Winch
+ * @author Eddú Meléndez
  * @since 5.0
  */
 @RunWith(MockitoJUnitRunner.class)
@@ -117,7 +118,6 @@ public class ServerHttpSecurityTests {
 	public void basic() {
 		given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN")));
 
-		this.http.securityContextRepository(new WebSessionServerSecurityContextRepository());
 		this.http.httpBasic();
 		this.http.authenticationManager(this.authenticationManager);
 		ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange();
@@ -137,6 +137,30 @@ public class ServerHttpSecurityTests {
 		assertThat(result.getResponseCookies().getFirst("SESSION")).isNull();
 	}
 
+	@Test
+	public void basicWithGlobalWebSessionServerSecurityContextRepository() {
+		given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN")));
+
+		this.http.securityContextRepository(new WebSessionServerSecurityContextRepository());
+		this.http.httpBasic();
+		this.http.authenticationManager(this.authenticationManager);
+		ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange();
+		authorize.anyExchange().authenticated();
+
+		WebTestClient client = buildClient();
+
+		EntityExchangeResult<String> result = client.get()
+				.uri("/")
+				.headers(headers -> headers.setBasicAuth("rob", "rob"))
+				.exchange()
+				.expectStatus().isOk()
+				.expectHeader().valueMatches(HttpHeaders.CACHE_CONTROL, ".+")
+				.expectBody(String.class).consumeWith(b -> assertThat(b.getResponseBody()).isEqualTo("ok"))
+				.returnResult();
+
+		assertThat(result.getResponseCookies().getFirst("SESSION")).isNotNull();
+	}
+
 	@Test
 	public void basicWhenNoCredentialsThenUnauthorized() {
 		this.http.authorizeExchange().anyExchange().authenticated();
@@ -256,7 +280,6 @@ public class ServerHttpSecurityTests {
 	public void basicWithAnonymous() {
 		given(this.authenticationManager.authenticate(any())).willReturn(Mono.just(new TestingAuthenticationToken("rob", "rob", "ROLE_USER", "ROLE_ADMIN")));
 
-		this.http.securityContextRepository(new WebSessionServerSecurityContextRepository());
 		this.http.httpBasic().and().anonymous();
 		this.http.authenticationManager(this.authenticationManager);
 		ServerHttpSecurity.AuthorizeExchangeSpec authorize = this.http.authorizeExchange();