소스 검색

Polish gh-7466

Joe Grandja 6 년 전
부모
커밋
08d2c93713

+ 7 - 10
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -76,7 +76,6 @@ import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserSer
 import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationCodeGrantWebFilter;
 import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
-import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationCodeAuthenticationTokenConverter;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
@@ -1106,13 +1105,14 @@ public class ServerHttpSecurity {
 		}
 
 		/**
-		 * Sets authorization request repository for {@link OAuth2AuthorizationRequestRedirectWebFilter}.
+		 * Sets the repository to use for storing {@link OAuth2AuthorizationRequest}'s.
 		 *
-		 * @param authorizationRequestRepository authorization request repository, must not be null
+		 * @since 5.2
+		 * @param authorizationRequestRepository the repository to use for storing {@link OAuth2AuthorizationRequest}'s
 		 * @return the {@link OAuth2LoginSpec} for further configuration
 		 */
-		public OAuth2LoginSpec authorizationRequestRepository(ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
-			Assert.notNull(authorizationRequestRepository, "authorizationRequestRepository cannot be null");
+		public OAuth2LoginSpec authorizationRequestRepository(
+				ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository) {
 			this.authorizationRequestRepository = authorizationRequestRepository;
 			return this;
 		}
@@ -1163,9 +1163,7 @@ public class ServerHttpSecurity {
 			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter();
 			ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
 					getAuthorizationRequestRepository();
-			if (authorizationRequestRepository != null) {
-				oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository);
-			}
+			oauthRedirectFilter.setAuthorizationRequestRepository(authorizationRequestRepository);
 			oauthRedirectFilter.setRequestCache(http.requestCache.requestCache);
 
 			ReactiveAuthenticationManager manager = getAuthenticationManager();
@@ -1267,10 +1265,9 @@ public class ServerHttpSecurity {
 			return result;
 		}
 
-		@SuppressWarnings("unchecked")
 		private ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> getAuthorizationRequestRepository() {
 			if (this.authorizationRequestRepository == null) {
-				this.authorizationRequestRepository = getBeanOrNull(ServerAuthorizationRequestRepository.class);
+				this.authorizationRequestRepository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
 			}
 			return this.authorizationRequestRepository;
 		}

+ 75 - 10
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -16,16 +16,10 @@
 
 package org.springframework.security.config.web.server;
 
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-
 import org.junit.Rule;
 import org.junit.Test;
 import org.mockito.stubbing.Answer;
 import org.openqa.selenium.WebDriver;
-import reactor.core.publisher.Mono;
-
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.annotation.Bean;
@@ -41,6 +35,8 @@ import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
@@ -53,7 +49,9 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.userinfo.ReactiveOAuth2UserService;
+import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
 import org.springframework.security.oauth2.core.OAuth2Error;
@@ -84,20 +82,25 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati
 import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler;
 import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
+import org.springframework.security.web.server.savedrequest.ServerRequestCache;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.test.web.reactive.server.WebTestClient;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.reactive.config.EnableWebFlux;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
 import org.springframework.web.server.WebHandler;
+import reactor.core.publisher.Mono;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.spy;
-import static org.mockito.Mockito.verify;
-import static org.mockito.Mockito.when;
+import static org.mockito.Mockito.*;
 import static org.springframework.security.oauth2.jwt.TestJwts.jwt;
 
 /**
@@ -189,6 +192,68 @@ public class OAuth2LoginTests {
 		}
 	}
 
+	@Test
+	public void oauth2AuthorizeWhenCustomObjectsThenUsed() {
+		this.spring.register(OAuth2LoginWithSingleClientRegistrations.class,
+				OAuth2AuthorizeWithMockObjectsConfig.class,
+				AuthorizedClientController.class).autowire();
+
+		OAuth2AuthorizeWithMockObjectsConfig config = this.spring.getContext().getBean(OAuth2AuthorizeWithMockObjectsConfig.class);
+
+		ServerOAuth2AuthorizedClientRepository authorizedClientRepository = config.authorizedClientRepository;
+		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository;
+		ServerRequestCache requestCache = config.requestCache;
+
+		when(authorizedClientRepository.loadAuthorizedClient(any(), any(), any())).thenReturn(Mono.empty());
+		when(authorizationRequestRepository.saveAuthorizationRequest(any(), any())).thenReturn(Mono.empty());
+		when(requestCache.removeMatchingRequest(any())).thenReturn(Mono.empty());
+		when(requestCache.saveRequest(any())).thenReturn(Mono.empty());
+
+		this.client.get()
+				.uri("/")
+				.exchange()
+				.expectStatus().is3xxRedirection();
+
+		verify(authorizedClientRepository).loadAuthorizedClient(any(), any(), any());
+		verify(authorizationRequestRepository).saveAuthorizationRequest(any(), any());
+		verify(requestCache).saveRequest(any());
+	}
+
+	@EnableWebFlux
+	static class OAuth2AuthorizeWithMockObjectsConfig {
+		ServerOAuth2AuthorizedClientRepository authorizedClientRepository =
+				mock(ServerOAuth2AuthorizedClientRepository.class);
+
+		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
+				mock(ServerAuthorizationRequestRepository.class);
+
+		ServerRequestCache requestCache = mock(ServerRequestCache.class);
+
+		@Bean
+		SecurityWebFilterChain springSecurity(ServerHttpSecurity http) {
+			http
+				.requestCache()
+					.requestCache(this.requestCache)
+					.and()
+				.oauth2Login()
+					.authorizationRequestRepository(this.authorizationRequestRepository);
+			return http.build();
+		}
+
+		@Bean
+		ServerOAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return this.authorizedClientRepository;
+		}
+	}
+
+	@RestController
+	static class AuthorizedClientController {
+		@GetMapping("/")
+		String home(@RegisteredOAuth2AuthorizedClient("github") OAuth2AuthorizedClient authorizedClient) {
+			return "home";
+		}
+	}
+
 	@Test
 	public void oauth2LoginWhenCustomObjectsThenUsed() {
 		this.spring.register(OAuth2LoginWithSingleClientRegistrations.class,

+ 3 - 79
config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

@@ -20,14 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.config.Customizer.withDefaults;
 
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -43,7 +41,6 @@ import org.mockito.junit.MockitoJUnitRunner;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor;
 import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter;
-import org.springframework.web.server.handler.FilteringWebHandler;
 import reactor.core.publisher.Mono;
 import reactor.test.publisher.TestPublisher;
 
@@ -51,29 +48,18 @@ import org.springframework.security.authentication.ReactiveAuthenticationManager
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
 import org.springframework.security.core.context.SecurityContext;
-import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
-import org.springframework.security.oauth2.client.registration.ClientRegistration;
-import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
-import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
-import org.springframework.security.oauth2.client.web.server.OAuth2AuthorizationRequestRedirectWebFilter;
-import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
-import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
-import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
-import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
 import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
 import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
-import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
 import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
 import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
-import org.springframework.security.web.server.savedrequest.ServerRequestCache;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.reactive.server.EntityExchangeResult;
 import org.springframework.test.web.reactive.server.FluxExchangeResult;
@@ -82,7 +68,10 @@ import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
+import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.web.server.WebFilterChain;
+import org.springframework.security.web.server.authentication.AnonymousAuthenticationWebFilterTests;
+import org.springframework.security.web.server.authentication.HttpBasicServerAuthenticationEntryPoint;
 
 /**
  * @author Rob Winch
@@ -486,71 +475,6 @@ public class ServerHttpSecurityTests {
 		verify(customServerCsrfTokenRepository).loadToken(any());
 	}
 
-	@SuppressWarnings("UnassignedFluxMonoInstance")
-	@Test
-	public void configureOAuth2LoginUsingCustomCommonServerRequestCache() {
-		ServerRequestCache requestCacheMock = mock(ServerRequestCache.class);
-		when(requestCacheMock.saveRequest(any(ServerWebExchange.class))).thenReturn(Mono.empty());
-
-		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
-		String registrationId = clientRegistration.getRegistrationId();
-
-		ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
-				mock(ReactiveClientRegistrationRepository.class);
-		when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
-				.thenReturn(Mono.just(clientRegistration));
-
-		SecurityWebFilterChain filterChain = http.requestCache().requestCache(requestCacheMock)
-				.and().oauth2Login().clientRegistrationRepository(clientRegistrationRepositoryMock)
-				.and().build();
-
-		Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
-				getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
-		assertThat(redirectWebFilter.isPresent()).isTrue();
-
-		FilteringWebHandler webHandler = new FilteringWebHandler(
-				e -> Mono.error(new ClientAuthorizationRequiredException(registrationId)),
-				Collections.singletonList(redirectWebFilter.get())
-		);
-		WebTestClient client = WebTestClient.bindToWebHandler(webHandler).build();
-		client.get().uri("/foo/bar").exchange();
-		verify(requestCacheMock, times(1)).saveRequest(any(ServerWebExchange.class));
-	}
-
-	@Test(expected = IllegalArgumentException.class)
-	public void throwExceptionWhenNullPassedForOAuth2LoginAuthorizationRequestRepository() {
-		http.oauth2Login().authorizationRequestRepository(null).and().build();
-	}
-
-	@SuppressWarnings({"UnassignedFluxMonoInstance", "unchecked"})
-	@Test
-	public void configureOAuth2LoginUsingCustomAuthorizationRequestRepository() {
-		ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
-		String registrationId = clientRegistration.getRegistrationId();
-
-		ReactiveClientRegistrationRepository clientRegistrationRepositoryMock =
-				mock(ReactiveClientRegistrationRepository.class);
-		when(clientRegistrationRepositoryMock.findByRegistrationId(registrationId))
-				.thenReturn(Mono.just(clientRegistration));
-
-		ServerAuthorizationRequestRepository requestRepositoryMock = mock(ServerAuthorizationRequestRepository.class);
-		SecurityWebFilterChain filterChain = http.oauth2Login()
-				.clientRegistrationRepository(clientRegistrationRepositoryMock)
-				.authorizationRequestRepository(requestRepositoryMock)
-				.and().build();
-
-		Optional<OAuth2AuthorizationRequestRedirectWebFilter> redirectWebFilter =
-				getWebFilter(filterChain, OAuth2AuthorizationRequestRedirectWebFilter.class);
-		assertThat(redirectWebFilter.isPresent()).isTrue();
-
-		WebTestClient client = WebTestClient.bindToController(new SubscriberContextController())
-				.webFilter(redirectWebFilter.get())
-				.build();
-		client.get().uri("/oauth2/authorization/" + registrationId).exchange();
-		verify(requestRepositoryMock, times(1)).saveAuthorizationRequest(any(OAuth2AuthorizationRequest.class),
-				any(ServerWebExchange.class));
-	}
-
 	private boolean isX509Filter(WebFilter filter) {
 		try {
 			Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter");