瀏覽代碼

Add success handler modification of OAuth2LoginSpec

Add the ability to modify the success handler used in OAuth2LoginSpec. The
default success handler remains unchanged.

Closes #6863
Daniel Meier 6 年之前
父節點
當前提交
fcd8a38f0b

+ 17 - 2
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -695,6 +695,8 @@ public class ServerHttpSecurity {
 
 		private ServerWebExchangeMatcher authenticationMatcher;
 
+		private ServerAuthenticationSuccessHandler authenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler();
+
 		/**
 		 * Configures the {@link ReactiveAuthenticationManager} to use. The default is
 		 * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager}
@@ -706,6 +708,20 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
+		/**
+		 * The {@link ServerAuthenticationSuccessHandler} used after authentication success. Defaults to
+		 * {@link RedirectServerAuthenticationSuccessHandler} redirecting to "/".
+		 *
+		 * @since 5.2
+		 * @param authenticationSuccessHandler the success handler to use
+		 * @return the {@link OAuth2LoginSpec} to customize
+		 */
+		public OAuth2LoginSpec authenticationSuccessHandler(ServerAuthenticationSuccessHandler authenticationSuccessHandler) {
+			Assert.notNull(authenticationSuccessHandler, "authenticationSuccessHandler cannot be null");
+			this.authenticationSuccessHandler = authenticationSuccessHandler;
+			return this;
+		}
+
 		/**
 		 * Gets the {@link ReactiveAuthenticationManager} to use. First tries an explicitly configured manager, and
 		 * defaults to {@link OAuth2AuthorizationCodeReactiveAuthenticationManager}
@@ -821,9 +837,8 @@ public class ServerHttpSecurity {
 			AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository);
 			authenticationFilter.setRequiresAuthenticationMatcher(getAuthenticationMatcher());
 			authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
-			RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler();
 
-			authenticationFilter.setAuthenticationSuccessHandler(redirectHandler);
+			authenticationFilter.setAuthenticationSuccessHandler(this.authenticationSuccessHandler);
 			authenticationFilter.setAuthenticationFailureHandler(new ServerAuthenticationFailureHandler() {
 				@Override
 				public Mono<Void> onAuthenticationFailure(WebFilterExchange webFilterExchange,

+ 22 - 2
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -23,7 +23,11 @@ import java.util.Map;
 
 import org.junit.Rule;
 import org.junit.Test;
+import org.mockito.stubbing.Answer;
 import org.openqa.selenium.WebDriver;
+import org.springframework.security.web.server.WebFilterExchange;
+import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
+import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
 import reactor.core.publisher.Mono;
 
 import org.springframework.beans.factory.annotation.Autowired;
@@ -184,6 +188,8 @@ public class OAuth2LoginTests {
 		this.spring.register(OAuth2LoginWithSingleClientRegistrations.class,
 				OAuth2LoginMockAuthenticationManagerConfig.class).autowire();
 
+		String redirectLocation = "/custom-redirect-location";
+
 		WebTestClient webTestClient = WebTestClientBuilder
 				.bindToWebFilters(this.springSecurity)
 				.build();
@@ -194,6 +200,7 @@ public class OAuth2LoginTests {
 		ReactiveAuthenticationManager manager = config.manager;
 		ServerWebExchangeMatcher matcher = config.matcher;
 		ServerOAuth2AuthorizationRequestResolver resolver = config.resolver;
+		ServerAuthenticationSuccessHandler successHandler = config.successHandler;
 
 		OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success();
 		OAuth2User user = TestOAuth2Users.create();
@@ -205,16 +212,25 @@ public class OAuth2LoginTests {
 		when(manager.authenticate(any())).thenReturn(Mono.just(result));
 		when(matcher.matches(any())).thenReturn(ServerWebExchangeMatcher.MatchResult.match());
 		when(resolver.resolve(any())).thenReturn(Mono.empty());
+		when(successHandler.onAuthenticationSuccess(any(), any())).thenAnswer((Answer<Mono<Void>>) invocation -> {
+			WebFilterExchange webFilterExchange = invocation.getArgument(0);
+			Authentication authentication = invocation.getArgument(1);
+
+			return new RedirectServerAuthenticationSuccessHandler(redirectLocation)
+					.onAuthenticationSuccess(webFilterExchange, authentication);
+		});
 
 		webTestClient.get()
 			.uri("/login/oauth2/code/github")
 			.exchange()
-			.expectStatus().is3xxRedirection();
+			.expectStatus().is3xxRedirection()
+			.expectHeader().valueEquals("Location", redirectLocation);
 
 		verify(converter).convert(any());
 		verify(manager).authenticate(any());
 		verify(matcher).matches(any());
 		verify(resolver).resolve(any());
+		verify(successHandler).onAuthenticationSuccess(any(), any());
 	}
 
 	@Configuration
@@ -227,6 +243,8 @@ public class OAuth2LoginTests {
 
 		ServerOAuth2AuthorizationRequestResolver resolver = mock(ServerOAuth2AuthorizationRequestResolver.class);
 
+		ServerAuthenticationSuccessHandler successHandler = mock(ServerAuthenticationSuccessHandler.class);
+
 		@Bean
 		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
 			http
@@ -237,7 +255,8 @@ public class OAuth2LoginTests {
 					.authenticationConverter(authenticationConverter)
 					.authenticationManager(manager)
 					.authenticationMatcher(matcher)
-					.authorizationRequestResolver(resolver);
+					.authorizationRequestResolver(resolver)
+					.authenticationSuccessHandler(successHandler);
 			return http.build();
 		}
 	}
@@ -425,4 +444,5 @@ public class OAuth2LoginTests {
 	<T> T getBean(Class<T> beanClass) {
 		return this.spring.getContext().getBean(beanClass);
 	}
+
 }