Quellcode durchsuchen

OAuth2LoginSpec discovers ReactiveOAuth2AccessTokenResponseClient @Bean

Fixes: gh-6477
Aanuoluwapo Otitoola vor 6 Jahren
Ursprung
Commit
ad9dc49d55

+ 12 - 1
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -31,6 +31,8 @@ import java.util.Optional;
 import java.util.UUID;
 import java.util.function.Function;
 
+import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
 import reactor.core.publisher.Mono;
 import reactor.util.context.Context;
 
@@ -621,7 +623,7 @@ public class ServerHttpSecurity {
 		}
 
 		private ReactiveAuthenticationManager createDefault() {
-			WebClientReactiveAuthorizationCodeTokenResponseClient client = new WebClientReactiveAuthorizationCodeTokenResponseClient();
+			ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> client = getAccessTokenResponseClient();
 			ReactiveAuthenticationManager result = new OAuth2LoginReactiveAuthenticationManager(client, getOauth2UserService());
 
 			boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent(
@@ -788,6 +790,15 @@ public class ServerHttpSecurity {
 			return result;
 		}
 
+		private ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> getAccessTokenResponseClient() {
+			ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class, OAuth2AuthorizationCodeGrantRequest.class);
+			ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> bean = getBeanOrNull(type);
+			if (bean == null) {
+				return new WebClientReactiveAuthorizationCodeTokenResponseClient();
+			}
+			return bean;
+		}
+
 		private ReactiveClientRegistrationRepository getClientRegistrationRepository() {
 			if (this.clientRegistrationRepository == null) {
 				this.clientRegistrationRepository = getBeanOrNull(ReactiveClientRegistrationRepository.class);

+ 11 - 5
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -218,16 +218,16 @@ public class OAuth2LoginTests {
 	}
 
 	@Test
-	public void oauth2LoginWhenCustomJwtDecoderFactoryThenUsed() {
+	public void oauth2LoginWhenCustomBeansThenUsed() {
 		this.spring.register(OAuth2LoginWithMultipleClientRegistrations.class,
-				OAuth2LoginWithJwtDecoderFactoryBeanConfig.class).autowire();
+				OAuth2LoginWithCustomBeansConfig.class).autowire();
 
 		WebTestClient webTestClient = WebTestClientBuilder
 				.bindToWebFilters(this.springSecurity)
 				.build();
 
-		OAuth2LoginWithJwtDecoderFactoryBeanConfig config = this.spring.getContext()
-				.getBean(OAuth2LoginWithJwtDecoderFactoryBeanConfig.class);
+		OAuth2LoginWithCustomBeansConfig config = this.spring.getContext()
+				.getBean(OAuth2LoginWithCustomBeansConfig.class);
 
 		OAuth2AuthorizationRequest request = TestOAuth2AuthorizationRequests.request().scope("openid").build();
 		OAuth2AuthorizationResponse response = TestOAuth2AuthorizationResponses.success().build();
@@ -258,10 +258,11 @@ public class OAuth2LoginTests {
 				.expectStatus().is3xxRedirection();
 
 		verify(config.jwtDecoderFactory).createDecoder(any());
+		verify(tokenResponseClient).getTokenResponse(any());
 	}
 
 	@Configuration
-	static class OAuth2LoginWithJwtDecoderFactoryBeanConfig {
+	static class OAuth2LoginWithCustomBeansConfig {
 
 		ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
 
@@ -298,6 +299,11 @@ public class OAuth2LoginTests {
 			return jwtDecoderFactory;
 		}
 
+		@Bean
+		public ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient() {
+			return tokenResponseClient;
+		}
+
 		private static class JwtDecoderFactory implements ReactiveJwtDecoderFactory<ClientRegistration> {
 
 			@Override