浏览代码

Use the custom ServerRequestCache that the user configures

on for the default authentication entry point and authentication
success handler

Fixes gh-7721

https://github.com/spring-projects/spring-security/issues/7721

Set RequestCache on the Oauth2LoginSpec default authentication success handler

import static ReflectionTestUtils.getField

Feedback incorporated per

https://github.com/spring-projects/spring-security/pull/7734#pullrequestreview-332150359
Filip Hanik 5 年之前
父节点
当前提交
9aa333ca4d

+ 19 - 4
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -76,9 +76,11 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
 import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
+import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.WebSessionOAuth2ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.oidc.user.OidcUser;
@@ -984,7 +986,7 @@ public class ServerHttpSecurity {
 
 		private ServerWebExchangeMatcher authenticationMatcher;
 
-		private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
+		private ServerAuthenticationSuccessHandler authenticationSuccessHandler;
 
 		private ServerAuthenticationFailureHandler authenticationFailureHandler;
 
@@ -1175,7 +1177,7 @@ public class ServerHttpSecurity {
 			authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
 			authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
 
-			authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
+			authenticationFilter.setAuthenticationSuccessHandler(getAuthenticationSuccessHandler(http));
 			authenticationFilter.setAuthenticationFailureHandler(getAuthenticationFailureHandler());
 			authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
 
@@ -1183,16 +1185,29 @@ public class ServerHttpSecurity {
 					MediaType.TEXT_HTML);
 			htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
 			Map<String, String> urlToText = http.oauth2Login.getLinks();
+			String authenticationEntryPointRedirectPath;
 			if (urlToText.size() == 1) {
-				http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
+				authenticationEntryPointRedirectPath = urlToText.keySet().iterator().next();
 			} else {
-				http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
+				authenticationEntryPointRedirectPath = "/login";
 			}
+			RedirectServerAuthenticationEntryPoint entryPoint = new RedirectServerAuthenticationEntryPoint(authenticationEntryPointRedirectPath);
+			entryPoint.setRequestCache(http.requestCache.requestCache);
+			http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, entryPoint));
 
 			http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
 			http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
 		}
 
+		private ServerAuthenticationSuccessHandler getAuthenticationSuccessHandler(ServerHttpSecurity http) {
+			if (this.authenticationSuccessHandler == null) {
+				RedirectServerAuthenticationSuccessHandler handler = new RedirectServerAuthenticationSuccessHandler();
+				handler.setRequestCache(http.requestCache.requestCache);
+				this.authenticationSuccessHandler = handler;
+			}
+			return this.authenticationSuccessHandler;
+		}
+
 		private ServerAuthenticationFailureHandler getAuthenticationFailureHandler() {
 			if (this.authenticationFailureHandler == null) {
 				this.authenticationFailureHandler = new RedirectServerAuthenticationFailureHandler("/login?error");

+ 38 - 6
config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

@@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.config.Customizer.withDefaults;
+import static org.springframework.test.util.ReflectionTestUtils.getField;
 
 import java.util.Arrays;
 import java.util.List;
@@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
+import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
 import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
 import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
+import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache;
 import reactor.core.publisher.Mono;
 import reactor.test.publisher.TestPublisher;
 
@@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
 import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
 import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
-import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.reactive.server.EntityExchangeResult;
 import org.springframework.test.web.reactive.server.FluxExchangeResult;
 import org.springframework.test.web.reactive.server.WebTestClient;
@@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
 				.isNotPresent();
 
 		Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
-				.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
+				.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
 
 		assertThat(logoutHandler)
 				.get()
@@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
 
 		assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
 				.get()
-				.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
+				.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository"))
 				.isEqualTo(this.csrfTokenRepository);
 
 		Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
-				.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
+				.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
 
 		assertThat(logoutHandler)
 				.get()
 				.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
 				.extracting(delegatingLogoutHandler ->
-						((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
+						((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
 								.map(ServerLogoutHandler::getClass)
 								.collect(Collectors.toList()))
 				.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
@@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
 		verify(customServerCsrfTokenRepository).loadToken(any());
 	}
 
+	@Test
+	public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() {
+		ServerRequestCache requestCache = spy(new WebSessionServerRequestCache());
+		ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class);
+
+		SecurityWebFilterChain securityFilterChain = this.http
+				.oauth2Login()
+				.clientRegistrationRepository(clientRegistrationRepository)
+				.and()
+				.authorizeExchange().anyExchange().authenticated()
+				.and()
+				.requestCache(c -> c.requestCache(requestCache))
+				.build();
+
+		WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build();
+		client.get().uri("/test").exchange();
+		ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class);
+		verify(requestCache).saveRequest(captor.capture());
+		assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test");
+
+
+		OAuth2LoginAuthenticationWebFilter authenticationWebFilter =
+				getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get();
+		Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler");
+		assertThat(getField(handler, "requestCache")).isSameAs(requestCache);
+	}
+
 	@Test
 	public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() {
 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class);
@@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
 
 	private boolean isX509Filter(WebFilter filter) {
 		try {
-			Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");
+			Object converter = getField(filter, "authenticationConverter");
 			return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class);
 		} catch (IllegalArgumentException e) {
 			// field doesn't exist