浏览代码

Allow ServerOAuth2AuthorizationRequestResolver to be set on oauth2 client configuration

Closes gh-12430
Spas Poptchev 2 年之前
父节点
当前提交
919280b3e4

+ 23 - 2
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -3831,9 +3831,31 @@ public class ServerHttpSecurity {
 
 		private ServerRedirectStrategy authorizationRedirectStrategy;
 
+		private ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver;
+
 		private OAuth2ClientSpec() {
 		}
 
+		/**
+		 * Sets the resolver used for resolving {@link OAuth2AuthorizationRequest}'s.
+		 * @param authorizationRequestResolver the resolver used for resolving
+		 * {@link OAuth2AuthorizationRequest}'s
+		 * @return the {@link OAuth2ClientSpec} for further configuration
+		 * @since 6.1
+		 */
+		public OAuth2ClientSpec authorizationRequestResolver(
+				ServerOAuth2AuthorizationRequestResolver authorizationRequestResolver) {
+			this.authorizationRequestResolver = authorizationRequestResolver;
+			return this;
+		}
+
+		private OAuth2AuthorizationRequestRedirectWebFilter getRedirectWebFilter() {
+			if (this.authorizationRequestResolver != null) {
+				return new OAuth2AuthorizationRequestRedirectWebFilter(this.authorizationRequestResolver);
+			}
+			return new OAuth2AuthorizationRequestRedirectWebFilter(getClientRegistrationRepository());
+		}
+
 		/**
 		 * Sets the converter to use
 		 * @param authenticationConverter the converter to use
@@ -3960,8 +3982,7 @@ public class ServerHttpSecurity {
 				codeGrantWebFilter.setRequestCache(http.requestCache.requestCache);
 			}
 
-			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(
-					clientRegistrationRepository);
+			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = getRedirectWebFilter();
 			oauthRedirectFilter.setAuthorizationRequestRepository(getAuthorizationRequestRepository());
 			oauthRedirectFilter.setAuthorizationRedirectStrategy(getAuthorizationRedirectStrategy());
 			if (http.requestCache != null) {

+ 7 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java

@@ -40,6 +40,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
@@ -134,6 +135,7 @@ public class OAuth2ClientSpecTests {
 		ReactiveAuthenticationManager manager = config.manager;
 		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = config.authorizationRequestRepository;
 		ServerRequestCache requestCache = config.requestCache;
+		ServerOAuth2AuthorizationRequestResolver resolver = config.resolver;
 		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
 				.redirectUri("/authorize/oauth2/code/registration-id").build();
 		OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
@@ -145,6 +147,7 @@ public class OAuth2ClientSpecTests {
 				this.registration, authorizationExchange, accessToken);
 		given(authorizationRequestRepository.loadAuthorizationRequest(any()))
 				.willReturn(Mono.just(authorizationRequest));
+		given(resolver.resolve(any())).willReturn(Mono.empty());
 		given(converter.convert(any())).willReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
 		given(manager.authenticate(any())).willReturn(Mono.just(result));
 		given(requestCache.getRedirectUri(any())).willReturn(Mono.just(URI.create("/saved-request")));
@@ -162,6 +165,7 @@ public class OAuth2ClientSpecTests {
 		verify(converter).convert(any());
 		verify(manager).authenticate(any());
 		verify(requestCache).getRedirectUri(any());
+		verify(resolver).resolve(any());
 	}
 
 	@Test
@@ -266,6 +270,8 @@ public class OAuth2ClientSpecTests {
 
 		ServerRequestCache requestCache = mock(ServerRequestCache.class);
 
+		ServerOAuth2AuthorizationRequestResolver resolver = mock(ServerOAuth2AuthorizationRequestResolver.class);
+
 		@Bean
 		SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
 			// @formatter:off
@@ -274,6 +280,7 @@ public class OAuth2ClientSpecTests {
 					.authenticationConverter(this.authenticationConverter)
 					.authenticationManager(this.manager)
 					.authorizationRequestRepository(this.authorizationRequestRepository)
+					.authorizationRequestResolver(this.resolver)
 					.and()
 				.requestCache((c) -> c.requestCache(this.requestCache));
 			// @formatter:on