浏览代码

Allow to customize OAuth2AuthorizationRequestRedirectWebFilter in OAuth2LoginSpec

Fixes gh-7466
Roman Chigvintsev 6 年之前
父节点
当前提交
9bae0a4dbd

+ 29 - 0
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -76,6 +76,7 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
 import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
+import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
@@ -972,6 +973,8 @@ public class ServerHttpSecurity {
 
 		private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 
+		private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository;
+
 		private ReactiveAuthenticationManager authenticationManager;
 
 		private ServerSecurityContextRepository securityContextRepository;
@@ -1102,6 +1105,18 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
+		/**
+		 * Sets authorization request repository for {@link OAuth2AuthorizationRequestRedirectWebFilter}.
+		 *
+		 * @param authorizationRequestRepository authorization request repository, must not be null
+		 * @return the {@link OAuth2LoginSpec} for further configuration
+		 */
+		public OAuth2LoginSpec authorizationRequestRepository(ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
+			Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null");
+			this.authorizationRequestRepository = authorizationRequestRepository;
+			return this;
+		}
+
 		/**
 		 * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
 		 *
@@ -1146,6 +1161,12 @@ public class ServerHttpSecurity {
 			ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
 			ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
 			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter();
+			ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
+					getAuthorizationRequestRepository();
+			if (authorizationRequestRepository != null) {
+				oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository);
+			}
+			oauthRedirectFilter.setRequestCache(http.requestCache.requestCache);
 
 			ReactiveAuthenticationManager manager = getAuthenticationManager();
 
@@ -1246,6 +1267,14 @@ public class ServerHttpSecurity {
 			return result;
 		}
 
+		@SuppressWarnings("unchecked")
+		private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> getAuthorizationRequestRepository() {
+			if (this.authorizationRequestRepository == null) {
+				this.authorizationRequestRepository = getBeanOrNull(ServerAuthorizationRequestRepository.class);
+			}
+			return this.authorizationRequestRepository;
+		}
+
 		private ReactiveOAuth2AuthorizedClientService getAuthorizedClientService() {
 			ReactiveOAuth2AuthorizedClientService service = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
 			if (service == null) {

+ 79 - 3
config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

@@ -20,12 +20,14 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.config.Customizer.withDefaults;
 
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -41,6 +43,7 @@ import org.mockito.junit.MockitoJUnitRunner;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
 import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
+import org.springframework.web.server.handler.FilteringWebHandler;
 import reactor.core.publisher.Mono;
 import reactor.test.publisher.TestPublisher;
 
@@ -48,18 +51,29 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
+import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
+import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
+import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
 import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
 import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
+import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
 import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
 import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.reactive.server.EntityExchangeResult;
 import org.springframework.test.web.reactive.server.FluxExchangeResult;
@@ -68,10 +82,7 @@ import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
-import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.web.server.WebFilterChain;
-import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
-import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
 
 /**
  * @author Rob Winch
@@ -475,6 +486,71 @@ public class ServerHttpSecurityTests {
 		verify(customServerCsrfTokenRepository).loadToken(any());
 	}
 
+	@SuppressWarnings("UnassignedFluxMonoInstance")
+	@Test
+	public void configureOAuth2LoginUsingCustomCommonServerRequestCache() {
+		ServerRequestCache requestCacheMock = mock(ServerRequestCache.class);
+		when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty());
+
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		String registrationId = clientRegistration.getRegistrationId();
+
+		ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
+				mock(ReactiveClientRegistrationRepository.class);
+		when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
+				.thenReturn(Mono.just(clientRegistration));
+
+		SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock)
+				.and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock)
+				.and().build();
+
+		Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
+				getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
+		assertThat(redirectWebFilter.isPresent()).isTrue();
+
+		FilteringWebHandler webHandler = new FilteringWebHandler(
+				e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)),
+				Collections.singletonList(redirectWebFilter.get())
+		);
+		WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build();
+		client.get().uri("/foo/bar").exchange();
+		verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class));
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() {
+		http.oauth2Login().authorizationRequestRepository(null).and().build();
+	}
+
+	@SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"})
+	@Test
+	public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() {
+		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
+		String registrationId = clientRegistration.getRegistrationId();
+
+		ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
+				mock(ReactiveClientRegistrationRepository.class);
+		when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
+				.thenReturn(Mono.just(clientRegistration));
+
+		ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class);
+		SecurityWebFilterChain filterChain = http.oauth2Login()
+				.clientRegistrationRepository(clientRegistrationRepositoryMock)
+				.authorizationRequestRepository(requestRepositoryMock)
+				.and().build();
+
+		Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
+				getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
+		assertThat(redirectWebFilter.isPresent()).isTrue();
+
+		WebTestClient client = WebTestClient.bindToController(new SubscriberContextController())
+				.webFilter(redirectWebFilter.get())
+				.build();
+		client.get().uri("/oauth2/authorization/" + registrationId).exchange();
+		verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class),
+				any(ServerWebExchange.class));
+	}
+
 	private boolean isX509Filter(WebFilter filter) {
 		try {
 			Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");