瀏覽代碼

Look up ReactiveOAuth2AccessTokenResponseClient as a bean

Closes gh-11097
Steve Riesenberg 11 月之前
父節點
當前提交
cd7f6e09b0

+ 12 - 1
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -4813,11 +4813,22 @@ public class ServerHttpSecurity {
 		private ReactiveAuthenticationManager getAuthenticationManager() {
 			if (this.authenticationManager == null) {
 				this.authenticationManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager(
-						new WebClientReactiveAuthorizationCodeTokenResponseClient());
+						getAuthorizationCodeTokenResponseClient());
 			}
 			return this.authenticationManager;
 		}
 
+		private ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> getAuthorizationCodeTokenResponseClient() {
+			ResolvableType resolvableType = ResolvableType.forClassWithGenerics(
+					ReactiveOAuth2AccessTokenResponseClient.class, OAuth2AuthorizationCodeGrantRequest.class);
+			ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient = getBeanOrNull(
+					resolvableType);
+			if (accessTokenResponseClient == null) {
+				accessTokenResponseClient = new WebClientReactiveAuthorizationCodeTokenResponseClient();
+			}
+			return accessTokenResponseClient;
+		}
+
 		/**
 		 * Configures the {@link ReactiveClientRegistrationRepository}. Default is to look
 		 * the value up as a Bean.

+ 106 - 1
config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,9 +17,11 @@
 package org.springframework.security.config.web.server;
 
 import java.net.URI;
+import java.util.Set;
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
 import reactor.core.publisher.Mono;
 
 import org.springframework.beans.factory.annotation.Autowired;
@@ -31,9 +33,12 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
 import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
@@ -41,8 +46,10 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr
 import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
@@ -59,7 +66,9 @@ import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.reactive.config.EnableWebFlux;
+import org.springframework.web.server.ServerWebExchange;
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
@@ -215,6 +224,62 @@ public class OAuth2ClientSpecTests {
 		verify(requestCache).getRedirectUri(any());
 	}
 
+	@Test
+	@SuppressWarnings("unchecked")
+	public void oauth2ClientWhenCustomAccessTokenResponseClientThenUsed() {
+		this.spring.register(OAuth2ClientBeanConfig.class, AuthorizedClientController.class).autowire();
+		ReactiveClientRegistrationRepository clientRegistrationRepository = this.spring.getContext()
+			.getBean(ReactiveClientRegistrationRepository.class);
+		given(clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration));
+		ServerOAuth2AuthorizedClientRepository authorizedClientRepository = this.spring.getContext()
+			.getBean(ServerOAuth2AuthorizedClientRepository.class);
+		given(authorizedClientRepository.saveAuthorizedClient(any(OAuth2AuthorizedClient.class),
+				any(Authentication.class), any(ServerWebExchange.class)))
+			.willReturn(Mono.empty());
+		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = this.spring
+			.getContext()
+			.getBean(ServerAuthorizationRequestRepository.class);
+		OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
+			.redirectUri("/authorize/oauth2/code/registration-id")
+			.build();
+		given(authorizationRequestRepository.loadAuthorizationRequest(any(ServerWebExchange.class)))
+			.willReturn(Mono.just(authorizationRequest));
+		given(authorizationRequestRepository.removeAuthorizationRequest(any(ServerWebExchange.class)))
+			.willReturn(Mono.just(authorizationRequest));
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> accessTokenResponseClient = this.spring
+			.getContext()
+			.getBean(ReactiveOAuth2AccessTokenResponseClient.class);
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("token")
+			.tokenType(OAuth2AccessToken.TokenType.BEARER)
+			.scopes(Set.of())
+			.expiresIn(300)
+			.build();
+		given(accessTokenResponseClient.getTokenResponse(any(OAuth2AuthorizationCodeGrantRequest.class)))
+			.willReturn(Mono.just(accessTokenResponse));
+		// @formatter:off
+		this.client.get()
+			.uri((uriBuilder) -> uriBuilder
+				.path("/authorize/oauth2/code/registration-id")
+				.queryParam(OAuth2ParameterNames.CODE, "code")
+				.queryParam(OAuth2ParameterNames.STATE, "state")
+				.build()
+			)
+			.exchange()
+			.expectStatus().is3xxRedirection();
+		// @formatter:on
+		ArgumentCaptor<OAuth2AuthorizationCodeGrantRequest> grantRequestArgumentCaptor = ArgumentCaptor
+			.forClass(OAuth2AuthorizationCodeGrantRequest.class);
+		verify(accessTokenResponseClient).getTokenResponse(grantRequestArgumentCaptor.capture());
+		OAuth2AuthorizationCodeGrantRequest grantRequest = grantRequestArgumentCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration()).isEqualTo(this.registration);
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE);
+		assertThat(grantRequest.getAuthorizationExchange().getAuthorizationRequest()).isEqualTo(authorizationRequest);
+		assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getCode()).isEqualTo("code");
+		assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getState()).isEqualTo("state");
+		assertThat(grantRequest.getAuthorizationExchange().getAuthorizationResponse().getRedirectUri())
+			.startsWith("/authorize/oauth2/code/registration-id");
+	}
+
 	@Configuration
 	@EnableWebFlux
 	@EnableWebFluxSecurity
@@ -324,4 +389,44 @@ public class OAuth2ClientSpecTests {
 
 	}
 
+	@Configuration
+	@EnableWebFlux
+	@EnableWebFluxSecurity
+	static class OAuth2ClientBeanConfig {
+
+		@Bean
+		SecurityWebFilterChain securityWebFilterChain(ServerHttpSecurity http) {
+			// @formatter:off
+			http
+				.oauth2Client((oauth2Client) -> oauth2Client
+					.authorizationRequestRepository(authorizationRequestRepository())
+				);
+			// @formatter:on
+			return http.build();
+		}
+
+		@Bean
+		@SuppressWarnings("unchecked")
+		ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository() {
+			return mock(ServerAuthorizationRequestRepository.class);
+		}
+
+		@Bean
+		@SuppressWarnings("unchecked")
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> authorizationCodeAccessTokenResponseClient() {
+			return mock(ReactiveOAuth2AccessTokenResponseClient.class);
+		}
+
+		@Bean
+		ReactiveClientRegistrationRepository clientRegistrationRepository() {
+			return mock(ReactiveClientRegistrationRepository.class);
+		}
+
+		@Bean
+		ServerOAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return mock(ServerOAuth2AuthorizedClientRepository.class);
+		}
+
+	}
+
 }