浏览代码

ServerOAuth2AuthorizedClientExchangeFilterFunction works with UnAuthenticatedServerOAuth2AuthorizedClientRepository

Fixes gh-7544
Joe Grandja 5 年之前
父节点
当前提交
80f256e425

+ 66 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

@@ -22,6 +22,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
@@ -35,6 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
 import org.springframework.util.Assert;
 import org.springframework.web.reactive.function.client.ClientRequest;
 import org.springframework.web.reactive.function.client.ClientResponse;
@@ -124,6 +126,17 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 						.clientCredentials()
 						.password()
 						.build();
+
+		// gh-7544
+		if (authorizedClientRepository instanceof UnAuthenticatedServerOAuth2AuthorizedClientRepository) {
+			UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager =
+					new UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
+							clientRegistrationRepository,
+							(UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository);
+			unauthenticatedAuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
+			return unauthenticatedAuthorizedClientManager;
+		}
+
 		DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
 				clientRegistrationRepository, authorizedClientRepository);
 		authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
@@ -266,7 +279,11 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 						.clientCredentials(this::updateClientCredentialsProvider)
 						.password(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew))
 						.build();
-		((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
+		if (this.authorizedClientManager instanceof UnAuthenticatedReactiveOAuth2AuthorizedClientManager) {
+			((UnAuthenticatedReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
+		} else {
+			((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
+		}
 	}
 
 	private void updateClientCredentialsProvider(ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) {
@@ -376,4 +393,52 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
 					.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
 					.build();
 	}
+
+	private static class UnAuthenticatedReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
+		private final ReactiveClientRegistrationRepository clientRegistrationRepository;
+		private final UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+		private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
+
+		private UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
+				ReactiveClientRegistrationRepository clientRegistrationRepository,
+				UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
+			this.clientRegistrationRepository = clientRegistrationRepository;
+			this.authorizedClientRepository = authorizedClientRepository;
+		}
+
+		@Override
+		public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
+			Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
+
+			String clientRegistrationId = authorizeRequest.getClientRegistrationId();
+			Authentication principal = authorizeRequest.getPrincipal();
+
+			return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
+					.switchIfEmpty(Mono.defer(() -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, null)))
+					.flatMap(authorizedClient -> {
+						// Re-authorize
+						return Mono.just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal).build())
+								.flatMap(this.authorizedClientProvider::authorize)
+								.flatMap(reauthorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, null).thenReturn(reauthorizedClient))
+								// Default to the existing authorizedClient if the client was not re-authorized
+								.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
+										authorizeRequest.getAuthorizedClient() : authorizedClient);
+					})
+					.switchIfEmpty(Mono.deferWithContext(context ->
+						// Authorize
+						this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
+								.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
+										"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
+								.flatMap(clientRegistration -> Mono.just(OAuth2AuthorizationContext.withClientRegistration(clientRegistration).principal(principal).build()))
+								.flatMap(this.authorizedClientProvider::authorize)
+								.flatMap(authorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null).thenReturn(authorizedClient))
+								.subscriberContext(context)
+					));
+		}
+
+		private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
+			Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
+			this.authorizedClientProvider = authorizedClientProvider;
+		}
+	}
 }

+ 38 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

@@ -57,6 +57,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -587,6 +588,43 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
 		verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
 	}
 
+	// gh-7544
+	@Test
+	public void filterWhenClientCredentialsClientNotAuthorizedAndOutsideRequestContextThenGetNewToken() {
+		// Use UnAuthenticatedServerOAuth2AuthorizedClientRepository when operating outside of a request context
+		ServerOAuth2AuthorizedClientRepository unauthenticatedAuthorizedClientRepository = spy(new UnAuthenticatedServerOAuth2AuthorizedClientRepository());
+		this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(
+				this.clientRegistrationRepository, unauthenticatedAuthorizedClientRepository);
+		this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
+
+		OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token")
+				.tokenType(OAuth2AccessToken.TokenType.BEARER)
+				.expiresIn(360)
+				.build();
+		when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
+
+		ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
+		when(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))).thenReturn(Mono.just(registration));
+
+		ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
+				.attributes(clientRegistrationId(registration.getRegistrationId()))
+				.build();
+
+		this.function.filter(request, this.exchange).block();
+
+		verify(unauthenticatedAuthorizedClientRepository).loadAuthorizedClient(any(), any(), any());
+		verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any());
+		verify(unauthenticatedAuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
+
+		List<ClientRequest> requests = this.exchange.getRequests();
+		assertThat(requests).hasSize(1);
+		ClientRequest request1 = requests.get(0);
+		assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
+		assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
+		assertThat(request1.method()).isEqualTo(HttpMethod.GET);
+		assertThat(getBody(request1)).isEmpty();
+	}
+
 	private Context serverWebExchange() {
 		return Context.of(ServerWebExchange.class, this.serverWebExchange);
 	}