Преглед на файлове

Simplify reactive OAuth2 Client configuration

Closes gh-13763
Steve Riesenberg преди 1 година
родител
ревизия
80a8d3831a

+ 413 - 0
config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientConfiguration.java

@@ -0,0 +1,413 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.web.reactive;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Consumer;
+
+import org.springframework.beans.BeansException;
+import org.springframework.beans.factory.BeanFactory;
+import org.springframework.beans.factory.BeanFactoryAware;
+import org.springframework.beans.factory.BeanFactoryUtils;
+import org.springframework.beans.factory.BeanInitializationException;
+import org.springframework.beans.factory.ListableBeanFactory;
+import org.springframework.beans.factory.NoSuchBeanDefinitionException;
+import org.springframework.beans.factory.ObjectProvider;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.config.BeanDefinition;
+import org.springframework.beans.factory.support.BeanDefinitionBuilder;
+import org.springframework.beans.factory.support.BeanDefinitionRegistry;
+import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
+import org.springframework.context.annotation.AnnotationBeanNameGenerator;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.context.annotation.Import;
+import org.springframework.core.ResolvableType;
+import org.springframework.security.oauth2.client.AuthorizationCodeReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.DelegatingReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.PasswordReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.TokenExchangeReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
+import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.web.reactive.config.WebFluxConfigurer;
+import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
+
+/**
+ * {@link Configuration} for OAuth 2.0 Client support.
+ *
+ * <p>
+ * This {@code Configuration} is conditionally imported by
+ * {@link ReactiveOAuth2ClientImportSelector} when the
+ * {@code spring-security-oauth2-client} module is present on the classpath.
+ *
+ * @author Steve Riesenberg
+ * @since 6.3
+ * @see ReactiveOAuth2ClientImportSelector
+ */
+@Import({ ReactiveOAuth2ClientConfiguration.ReactiveOAuth2AuthorizedClientManagerConfiguration.class,
+		ReactiveOAuth2ClientConfiguration.OAuth2ClientWebFluxSecurityConfiguration.class })
+final class ReactiveOAuth2ClientConfiguration {
+
+	@Configuration
+	static class ReactiveOAuth2AuthorizedClientManagerConfiguration {
+
+		@Bean(name = ReactiveOAuth2AuthorizedClientManagerRegistrar.BEAN_NAME)
+		ReactiveOAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar() {
+			return new ReactiveOAuth2AuthorizedClientManagerRegistrar();
+		}
+
+	}
+
+	@Configuration(proxyBeanMethods = false)
+	static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer {
+
+		private ReactiveOAuth2AuthorizedClientManager authorizedClientManager;
+
+		private ReactiveOAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar;
+
+		@Override
+		public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
+			ReactiveOAuth2AuthorizedClientManager authorizedClientManager = getAuthorizedClientManager();
+			if (authorizedClientManager != null) {
+				configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager));
+			}
+		}
+
+		@Autowired(required = false)
+		void setAuthorizedClientManager(List<ReactiveOAuth2AuthorizedClientManager> authorizedClientManager) {
+			if (authorizedClientManager.size() == 1) {
+				this.authorizedClientManager = authorizedClientManager.get(0);
+			}
+		}
+
+		@Autowired
+		void setAuthorizedClientManagerRegistrar(
+				ReactiveOAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar) {
+			this.authorizedClientManagerRegistrar = authorizedClientManagerRegistrar;
+		}
+
+		private ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManager() {
+			if (this.authorizedClientManager != null) {
+				return this.authorizedClientManager;
+			}
+			return this.authorizedClientManagerRegistrar.getAuthorizedClientManagerIfAvailable();
+		}
+
+	}
+
+	/**
+	 * A registrar for registering the default
+	 * {@link ReactiveOAuth2AuthorizedClientManager} bean definition, if not already
+	 * present.
+	 */
+	static final class ReactiveOAuth2AuthorizedClientManagerRegistrar
+			implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
+
+		static final String BEAN_NAME = "authorizedClientManagerRegistrar";
+
+		static final String FACTORY_METHOD_NAME = "getAuthorizedClientManager";
+
+		// @formatter:off
+		private static final Set<Class<?>> KNOWN_AUTHORIZED_CLIENT_PROVIDERS = Set.of(
+				AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class,
+				RefreshTokenReactiveOAuth2AuthorizedClientProvider.class,
+				ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class,
+				PasswordReactiveOAuth2AuthorizedClientProvider.class,
+				JwtBearerReactiveOAuth2AuthorizedClientProvider.class,
+				TokenExchangeReactiveOAuth2AuthorizedClientProvider.class
+		);
+		// @formatter:on
+
+		private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();
+
+		private ListableBeanFactory beanFactory;
+
+		@Override
+		public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
+			if (getBeanNamesForType(ReactiveOAuth2AuthorizedClientManager.class).length != 0
+					|| getBeanNamesForType(ReactiveClientRegistrationRepository.class).length != 1
+					|| getBeanNamesForType(ServerOAuth2AuthorizedClientRepository.class).length != 1
+							&& getBeanNamesForType(ReactiveOAuth2AuthorizedClientService.class).length != 1) {
+				return;
+			}
+
+			BeanDefinition beanDefinition = BeanDefinitionBuilder
+				.rootBeanDefinition(ReactiveOAuth2AuthorizedClientManager.class)
+				.setFactoryMethodOnBean(FACTORY_METHOD_NAME, BEAN_NAME)
+				.getBeanDefinition();
+
+			registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry),
+					beanDefinition);
+		}
+
+		@Override
+		public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
+			this.beanFactory = (ListableBeanFactory) beanFactory;
+		}
+
+		ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManagerIfAvailable() {
+			if (getBeanNamesForType(ReactiveClientRegistrationRepository.class).length != 1
+					|| getBeanNamesForType(ServerOAuth2AuthorizedClientRepository.class).length != 1
+							&& getBeanNamesForType(ReactiveOAuth2AuthorizedClientService.class).length != 1) {
+				return null;
+			}
+			return getAuthorizedClientManager();
+		}
+
+		ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManager() {
+			ReactiveClientRegistrationRepository clientRegistrationRepository = BeanFactoryUtils
+				.beanOfTypeIncludingAncestors(this.beanFactory, ReactiveClientRegistrationRepository.class, true, true);
+
+			ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+			try {
+				authorizedClientRepository = BeanFactoryUtils.beanOfTypeIncludingAncestors(this.beanFactory,
+						ServerOAuth2AuthorizedClientRepository.class, true, true);
+			}
+			catch (NoSuchBeanDefinitionException ex) {
+				ReactiveOAuth2AuthorizedClientService authorizedClientService = BeanFactoryUtils
+					.beanOfTypeIncludingAncestors(this.beanFactory, ReactiveOAuth2AuthorizedClientService.class, true,
+							true);
+				authorizedClientRepository = new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(
+						authorizedClientService);
+			}
+
+			Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviderBeans = BeanFactoryUtils
+				.beansOfTypeIncludingAncestors(this.beanFactory, ReactiveOAuth2AuthorizedClientProvider.class, true,
+						true)
+				.values();
+
+			ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
+			if (hasDelegatingAuthorizedClientProvider(authorizedClientProviderBeans)) {
+				authorizedClientProvider = authorizedClientProviderBeans.iterator().next();
+			}
+			else {
+				List<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders = new ArrayList<>();
+				authorizedClientProviders
+					.add(getAuthorizationCodeAuthorizedClientProvider(authorizedClientProviderBeans));
+				authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider(authorizedClientProviderBeans));
+				authorizedClientProviders
+					.add(getClientCredentialsAuthorizedClientProvider(authorizedClientProviderBeans));
+				authorizedClientProviders.add(getPasswordAuthorizedClientProvider(authorizedClientProviderBeans));
+
+				ReactiveOAuth2AuthorizedClientProvider jwtBearerAuthorizedClientProvider = getJwtBearerAuthorizedClientProvider(
+						authorizedClientProviderBeans);
+				if (jwtBearerAuthorizedClientProvider != null) {
+					authorizedClientProviders.add(jwtBearerAuthorizedClientProvider);
+				}
+
+				ReactiveOAuth2AuthorizedClientProvider tokenExchangeAuthorizedClientProvider = getTokenExchangeAuthorizedClientProvider(
+						authorizedClientProviderBeans);
+				if (tokenExchangeAuthorizedClientProvider != null) {
+					authorizedClientProviders.add(tokenExchangeAuthorizedClientProvider);
+				}
+
+				authorizedClientProviders.addAll(getAdditionalAuthorizedClientProviders(authorizedClientProviderBeans));
+				authorizedClientProvider = new DelegatingReactiveOAuth2AuthorizedClientProvider(
+						authorizedClientProviders);
+			}
+
+			DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
+					clientRegistrationRepository, authorizedClientRepository);
+			authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
+
+			Consumer<DefaultReactiveOAuth2AuthorizedClientManager> authorizedClientManagerConsumer = getBeanOfType(
+					ResolvableType.forClassWithGenerics(Consumer.class,
+							DefaultReactiveOAuth2AuthorizedClientManager.class));
+			if (authorizedClientManagerConsumer != null) {
+				authorizedClientManagerConsumer.accept(authorizedClientManager);
+			}
+
+			return authorizedClientManager;
+		}
+
+		private boolean hasDelegatingAuthorizedClientProvider(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			if (authorizedClientProviders.size() != 1) {
+				return false;
+			}
+			return authorizedClientProviders.iterator()
+				.next() instanceof DelegatingReactiveOAuth2AuthorizedClientProvider;
+		}
+
+		private ReactiveOAuth2AuthorizedClientProvider getAuthorizationCodeAuthorizedClientProvider(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+					authorizedClientProviders, AuthorizationCodeReactiveOAuth2AuthorizedClientProvider.class);
+			if (authorizedClientProvider == null) {
+				authorizedClientProvider = new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider();
+			}
+
+			return authorizedClientProvider;
+		}
+
+		private ReactiveOAuth2AuthorizedClientProvider getRefreshTokenAuthorizedClientProvider(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+					authorizedClientProviders, RefreshTokenReactiveOAuth2AuthorizedClientProvider.class);
+			if (authorizedClientProvider == null) {
+				authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider();
+			}
+
+			ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> accessTokenResponseClient = getBeanOfType(
+					ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class,
+							OAuth2RefreshTokenGrantRequest.class));
+			if (accessTokenResponseClient != null) {
+				authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
+			}
+
+			return authorizedClientProvider;
+		}
+
+		private ReactiveOAuth2AuthorizedClientProvider getClientCredentialsAuthorizedClientProvider(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+					authorizedClientProviders, ClientCredentialsReactiveOAuth2AuthorizedClientProvider.class);
+			if (authorizedClientProvider == null) {
+				authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider();
+			}
+
+			ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> accessTokenResponseClient = getBeanOfType(
+					ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class,
+							OAuth2ClientCredentialsGrantRequest.class));
+			if (accessTokenResponseClient != null) {
+				authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
+			}
+
+			return authorizedClientProvider;
+		}
+
+		private ReactiveOAuth2AuthorizedClientProvider getPasswordAuthorizedClientProvider(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			PasswordReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+					authorizedClientProviders, PasswordReactiveOAuth2AuthorizedClientProvider.class);
+			if (authorizedClientProvider == null) {
+				authorizedClientProvider = new PasswordReactiveOAuth2AuthorizedClientProvider();
+			}
+
+			ReactiveOAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> accessTokenResponseClient = getBeanOfType(
+					ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class,
+							OAuth2PasswordGrantRequest.class));
+			if (accessTokenResponseClient != null) {
+				authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
+			}
+
+			return authorizedClientProvider;
+		}
+
+		private ReactiveOAuth2AuthorizedClientProvider getJwtBearerAuthorizedClientProvider(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			JwtBearerReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+					authorizedClientProviders, JwtBearerReactiveOAuth2AuthorizedClientProvider.class);
+
+			ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> accessTokenResponseClient = getBeanOfType(
+					ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class,
+							JwtBearerGrantRequest.class));
+			if (accessTokenResponseClient != null) {
+				if (authorizedClientProvider == null) {
+					authorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider();
+				}
+
+				authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
+			}
+
+			return authorizedClientProvider;
+		}
+
+		private ReactiveOAuth2AuthorizedClientProvider getTokenExchangeAuthorizedClientProvider(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			TokenExchangeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = getAuthorizedClientProviderByType(
+					authorizedClientProviders, TokenExchangeReactiveOAuth2AuthorizedClientProvider.class);
+
+			ReactiveOAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> accessTokenResponseClient = getBeanOfType(
+					ResolvableType.forClassWithGenerics(ReactiveOAuth2AccessTokenResponseClient.class,
+							TokenExchangeGrantRequest.class));
+			if (accessTokenResponseClient != null) {
+				if (authorizedClientProvider == null) {
+					authorizedClientProvider = new TokenExchangeReactiveOAuth2AuthorizedClientProvider();
+				}
+
+				authorizedClientProvider.setAccessTokenResponseClient(accessTokenResponseClient);
+			}
+
+			return authorizedClientProvider;
+		}
+
+		private List<ReactiveOAuth2AuthorizedClientProvider> getAdditionalAuthorizedClientProviders(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders) {
+			List<ReactiveOAuth2AuthorizedClientProvider> additionalAuthorizedClientProviders = new ArrayList<>(
+					authorizedClientProviders);
+			additionalAuthorizedClientProviders
+				.removeIf((provider) -> KNOWN_AUTHORIZED_CLIENT_PROVIDERS.contains(provider.getClass()));
+			return additionalAuthorizedClientProviders;
+		}
+
+		private <T extends ReactiveOAuth2AuthorizedClientProvider> T getAuthorizedClientProviderByType(
+				Collection<ReactiveOAuth2AuthorizedClientProvider> authorizedClientProviders, Class<T> providerClass) {
+			T authorizedClientProvider = null;
+			for (ReactiveOAuth2AuthorizedClientProvider current : authorizedClientProviders) {
+				if (providerClass.isInstance(current)) {
+					assertAuthorizedClientProviderIsNull(authorizedClientProvider);
+					authorizedClientProvider = providerClass.cast(current);
+				}
+			}
+			return authorizedClientProvider;
+		}
+
+		private static void assertAuthorizedClientProviderIsNull(
+				ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
+			if (authorizedClientProvider != null) {
+				// @formatter:off
+				throw new BeanInitializationException(String.format(
+						"Unable to create a %s bean. Expected one bean of type %s, but found multiple. " +
+						"Please consider defining only a single bean of this type, or define a %s bean yourself.",
+						ReactiveOAuth2AuthorizedClientManager.class.getName(),
+						authorizedClientProvider.getClass().getName(),
+						ReactiveOAuth2AuthorizedClientManager.class.getName()));
+				// @formatter:on
+			}
+		}
+
+		private <T> String[] getBeanNamesForType(Class<T> beanClass) {
+			return BeanFactoryUtils.beanNamesForTypeIncludingAncestors(this.beanFactory, beanClass, true, true);
+		}
+
+		private <T> T getBeanOfType(ResolvableType resolvableType) {
+			ObjectProvider<T> objectProvider = this.beanFactory.getBeanProvider(resolvableType, true);
+			return objectProvider.getIfAvailable();
+		}
+
+	}
+
+}

+ 5 - 96
config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -16,27 +16,13 @@
 
 
 package org.springframework.security.config.annotation.web.reactive;
 package org.springframework.security.config.annotation.web.reactive;
 
 
-import java.util.List;
-
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.context.annotation.ImportSelector;
 import org.springframework.core.type.AnnotationMetadata;
 import org.springframework.core.type.AnnotationMetadata;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProvider;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientProviderBuilder;
-import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
-import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
-import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
-import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
-import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
-import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
 import org.springframework.util.ClassUtils;
 import org.springframework.util.ClassUtils;
-import org.springframework.web.reactive.config.WebFluxConfigurer;
-import org.springframework.web.reactive.result.method.annotation.ArgumentResolverConfigurer;
 
 
 /**
 /**
- * {@link Configuration} for OAuth 2.0 Client support.
+ * Used by {@link EnableWebFluxSecurity} to conditionally import
+ * {@link ReactiveOAuth2ClientConfiguration}.
  *
  *
  * <p>
  * <p>
  * This {@code Configuration} is imported by {@link EnableWebFluxSecurity}
  * This {@code Configuration} is imported by {@link EnableWebFluxSecurity}
@@ -60,85 +46,8 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector {
 		if (!oauth2ClientPresent) {
 		if (!oauth2ClientPresent) {
 			return new String[0];
 			return new String[0];
 		}
 		}
-		return new String[] { "org.springframework.security.config.annotation.web.reactive."
-				+ "ReactiveOAuth2ClientImportSelector$OAuth2ClientWebFluxSecurityConfiguration" };
-	}
-
-	@Configuration(proxyBeanMethods = false)
-	static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer {
-
-		private ReactiveClientRegistrationRepository clientRegistrationRepository;
-
-		private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
-
-		private ReactiveOAuth2AuthorizedClientService authorizedClientService;
-
-		private ReactiveOAuth2AuthorizedClientManager authorizedClientManager;
-
-		@Override
-		public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
-			ReactiveOAuth2AuthorizedClientManager authorizedClientManager = getAuthorizedClientManager();
-			if (authorizedClientManager != null) {
-				configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(authorizedClientManager));
-			}
-		}
-
-		@Autowired(required = false)
-		void setClientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) {
-			this.clientRegistrationRepository = clientRegistrationRepository;
-		}
-
-		@Autowired(required = false)
-		void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
-			this.authorizedClientRepository = authorizedClientRepository;
-		}
-
-		@Autowired(required = false)
-		void setAuthorizedClientService(List<ReactiveOAuth2AuthorizedClientService> authorizedClientService) {
-			if (authorizedClientService.size() == 1) {
-				this.authorizedClientService = authorizedClientService.get(0);
-			}
-		}
-
-		@Autowired(required = false)
-		void setAuthorizedClientManager(List<ReactiveOAuth2AuthorizedClientManager> authorizedClientManager) {
-			if (authorizedClientManager.size() == 1) {
-				this.authorizedClientManager = authorizedClientManager.get(0);
-			}
-		}
-
-		private ServerOAuth2AuthorizedClientRepository getAuthorizedClientRepository() {
-			if (this.authorizedClientRepository != null) {
-				return this.authorizedClientRepository;
-			}
-			if (this.authorizedClientService != null) {
-				return new AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository(this.authorizedClientService);
-			}
-			return null;
-		}
-
-		private ReactiveOAuth2AuthorizedClientManager getAuthorizedClientManager() {
-			if (this.authorizedClientManager != null) {
-				return this.authorizedClientManager;
-			}
-			ReactiveOAuth2AuthorizedClientManager authorizedClientManager = null;
-			if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) {
-				ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
-					.builder()
-					.authorizationCode()
-					.refreshToken()
-					.clientCredentials()
-					.password()
-					.build();
-				DefaultReactiveOAuth2AuthorizedClientManager defaultReactiveOAuth2AuthorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
-						this.clientRegistrationRepository, getAuthorizedClientRepository());
-				defaultReactiveOAuth2AuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
-				authorizedClientManager = defaultReactiveOAuth2AuthorizedClientManager;
-			}
-
-			return authorizedClientManager;
-		}
-
+		return new String[] {
+				"org.springframework.security.config.annotation.web.reactive.ReactiveOAuth2ClientConfiguration" };
 	}
 	}
 
 
 }
 }

+ 589 - 0
config/src/test/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2AuthorizedClientManagerConfigurationTests.java

@@ -0,0 +1,589 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.web.reactive;
+
+import java.time.Duration;
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+import java.util.Objects;
+import java.util.function.Consumer;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.ArgumentCaptor;
+import reactor.core.publisher.Mono;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.http.MediaType;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
+import org.springframework.security.config.test.SpringTestContext;
+import org.springframework.security.oauth2.client.AuthorizationCodeReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
+import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.JwtBearerReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
+import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.PasswordReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.RefreshTokenReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.TokenExchangeReactiveOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.endpoint.AbstractOAuth2AuthorizationGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.JwtBearerGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.TokenExchangeGrantRequest;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.server.WebSessionServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
+import org.springframework.security.oauth2.core.OAuth2Error;
+import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses;
+import org.springframework.security.oauth2.jwt.JoseHeaderNames;
+import org.springframework.security.oauth2.jwt.Jwt;
+import org.springframework.security.oauth2.jwt.JwtClaimNames;
+import org.springframework.security.oauth2.server.resource.authentication.JwtAuthenticationToken;
+import org.springframework.util.StringUtils;
+import org.springframework.web.server.ServerWebExchange;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyString;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for
+ * {@link ReactiveOAuth2ClientConfiguration.ReactiveOAuth2AuthorizedClientManagerConfiguration}.
+ *
+ * @author Steve Riesenberg
+ */
+public class ReactiveOAuth2AuthorizedClientManagerConfigurationTests {
+
+	private static ReactiveOAuth2AccessTokenResponseClient<? super AbstractOAuth2AuthorizationGrantRequest> MOCK_RESPONSE_CLIENT;
+
+	public final SpringTestContext spring = new SpringTestContext(this);
+
+	@Autowired
+	private ReactiveOAuth2AuthorizedClientManager authorizedClientManager;
+
+	@Autowired
+	private ReactiveClientRegistrationRepository clientRegistrationRepository;
+
+	@Autowired(required = false)
+	private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
+
+	@Autowired(required = false)
+	private ReactiveOAuth2AuthorizedClientService authorizedClientService;
+
+	@Autowired(required = false)
+	private AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider;
+
+	private MockServerWebExchange exchange;
+
+	@BeforeEach
+	@SuppressWarnings("unchecked")
+	public void setUp() {
+		MOCK_RESPONSE_CLIENT = mock(ReactiveOAuth2AccessTokenResponseClient.class);
+		MockServerHttpRequest request = MockServerHttpRequest.get("/").build();
+		this.exchange = MockServerWebExchange.builder(request).build();
+	}
+
+	@Test
+	public void loadContextWhenOAuth2ClientEnabledThenConfigured() {
+		this.spring.register(MinimalOAuth2ClientConfig.class).autowire();
+		assertThat(this.authorizedClientManager).isNotNull();
+	}
+
+	@Test
+	public void authorizeWhenAuthorizationCodeAuthorizedClientProviderBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER");
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId("google")
+				.principal(authentication)
+				.attribute(ServerWebExchange.class.getName(), this.exchange)
+				.build();
+		assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
+				.isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+				.extracting(OAuth2AuthorizationException::getError)
+				.extracting(OAuth2Error::getErrorCode)
+				.isEqualTo("client_authorization_required");
+		// @formatter:on
+
+		verify(this.authorizationCodeAuthorizedClientProvider).authorize(any(OAuth2AuthorizationContext.class));
+	}
+
+	@Test
+	public void authorizeWhenAuthorizedClientServiceBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientServiceConfig.class).autowire();
+
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER");
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId("google")
+				.principal(authentication)
+				.attribute(ServerWebExchange.class.getName(), this.exchange)
+				.build();
+		assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
+				.isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest).block())
+				.extracting(OAuth2AuthorizationException::getError)
+				.extracting(OAuth2Error::getErrorCode)
+				.isEqualTo("client_authorization_required");
+		// @formatter:on
+
+		verify(this.authorizedClientService).loadAuthorizedClient(authorizeRequest.getClientRegistrationId(),
+				authentication.getName());
+	}
+
+	@Test
+	public void authorizeWhenRefreshTokenAccessTokenResponseClientBeanThenUsed() {
+		this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
+		testRefreshTokenGrant();
+	}
+
+	@Test
+	public void authorizeWhenRefreshTokenAuthorizedClientProviderBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+		testRefreshTokenGrant();
+	}
+
+	private void testRefreshTokenGrant() {
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2RefreshTokenGrantRequest.class)))
+			.willReturn(Mono.just(accessTokenResponse));
+
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER");
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google")
+			.block();
+		assertThat(clientRegistration).isNotNull();
+		OAuth2AuthorizedClient existingAuthorizedClient = new OAuth2AuthorizedClient(clientRegistration,
+				authentication.getName(), getExpiredAccessToken(), TestOAuth2RefreshTokens.refreshToken());
+		this.authorizedClientRepository.saveAuthorizedClient(existingAuthorizedClient, authentication, this.exchange)
+			.block();
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId(clientRegistration.getRegistrationId())
+				.principal(authentication)
+				.attribute(ServerWebExchange.class.getName(), this.exchange)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		assertThat(authorizedClient).isNotNull();
+
+		ArgumentCaptor<OAuth2RefreshTokenGrantRequest> grantRequestCaptor = ArgumentCaptor
+			.forClass(OAuth2RefreshTokenGrantRequest.class);
+		verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
+
+		OAuth2RefreshTokenGrantRequest grantRequest = grantRequestCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration().getRegistrationId())
+			.isEqualTo(clientRegistration.getRegistrationId());
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.REFRESH_TOKEN);
+		assertThat(grantRequest.getAccessToken()).isEqualTo(existingAuthorizedClient.getAccessToken());
+		assertThat(grantRequest.getRefreshToken()).isEqualTo(existingAuthorizedClient.getRefreshToken());
+	}
+
+	@Test
+	public void authorizeWhenClientCredentialsAccessTokenResponseClientBeanThenUsed() {
+		this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
+		testClientCredentialsGrant();
+	}
+
+	@Test
+	public void authorizeWhenClientCredentialsAuthorizedClientProviderBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+		testClientCredentialsGrant();
+	}
+
+	private void testClientCredentialsGrant() {
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2ClientCredentialsGrantRequest.class)))
+			.willReturn(Mono.just(accessTokenResponse));
+
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", null, "ROLE_USER");
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("github")
+			.block();
+		assertThat(clientRegistration).isNotNull();
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId(clientRegistration.getRegistrationId())
+				.principal(authentication)
+				.attribute(ServerWebExchange.class.getName(), this.exchange)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		assertThat(authorizedClient).isNotNull();
+
+		ArgumentCaptor<OAuth2ClientCredentialsGrantRequest> grantRequestCaptor = ArgumentCaptor
+			.forClass(OAuth2ClientCredentialsGrantRequest.class);
+		verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
+
+		OAuth2ClientCredentialsGrantRequest grantRequest = grantRequestCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration().getRegistrationId())
+			.isEqualTo(clientRegistration.getRegistrationId());
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS);
+	}
+
+	@Test
+	public void authorizeWhenPasswordAccessTokenResponseClientBeanThenUsed() {
+		this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
+		testPasswordGrant();
+	}
+
+	@Test
+	public void authorizeWhenPasswordAuthorizedClientProviderBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+		testPasswordGrant();
+	}
+
+	private void testPasswordGrant() {
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(OAuth2PasswordGrantRequest.class)))
+			.willReturn(Mono.just(accessTokenResponse));
+
+		TestingAuthenticationToken authentication = new TestingAuthenticationToken("user", "password", "ROLE_USER");
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook")
+			.block();
+		assertThat(clientRegistration).isNotNull();
+		MockServerHttpRequest request = MockServerHttpRequest.post("/")
+			.contentType(MediaType.APPLICATION_FORM_URLENCODED)
+			.body("username=user&password=password");
+		this.exchange = MockServerWebExchange.builder(request).build();
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId(clientRegistration.getRegistrationId())
+				.principal(authentication)
+				.attribute(ServerWebExchange.class.getName(), this.exchange)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		assertThat(authorizedClient).isNotNull();
+
+		ArgumentCaptor<OAuth2PasswordGrantRequest> grantRequestCaptor = ArgumentCaptor
+			.forClass(OAuth2PasswordGrantRequest.class);
+		verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
+
+		OAuth2PasswordGrantRequest grantRequest = grantRequestCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration().getRegistrationId())
+			.isEqualTo(clientRegistration.getRegistrationId());
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.PASSWORD);
+		assertThat(grantRequest.getUsername()).isEqualTo("user");
+		assertThat(grantRequest.getPassword()).isEqualTo("password");
+	}
+
+	@Test
+	public void authorizeWhenJwtBearerAccessTokenResponseClientBeanThenUsed() {
+		this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
+		testJwtBearerGrant();
+	}
+
+	@Test
+	public void authorizeWhenJwtBearerAuthorizedClientProviderBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+		testJwtBearerGrant();
+	}
+
+	private void testJwtBearerGrant() {
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(JwtBearerGrantRequest.class)))
+			.willReturn(Mono.just(accessTokenResponse));
+
+		JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("okta").block();
+		assertThat(clientRegistration).isNotNull();
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId(clientRegistration.getRegistrationId())
+				.principal(authentication)
+				.attribute(ServerWebExchange.class.getName(), this.exchange)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		assertThat(authorizedClient).isNotNull();
+
+		ArgumentCaptor<JwtBearerGrantRequest> grantRequestCaptor = ArgumentCaptor.forClass(JwtBearerGrantRequest.class);
+		verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
+
+		JwtBearerGrantRequest grantRequest = grantRequestCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration().getRegistrationId())
+			.isEqualTo(clientRegistration.getRegistrationId());
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.JWT_BEARER);
+		assertThat(grantRequest.getJwt().getSubject()).isEqualTo("user");
+	}
+
+	@Test
+	public void authorizeWhenTokenExchangeAccessTokenResponseClientBeanThenUsed() {
+		this.spring.register(CustomAccessTokenResponseClientsConfig.class).autowire();
+		testTokenExchangeGrant();
+	}
+
+	@Test
+	public void authorizeWhenTokenExchangeAuthorizedClientProviderBeanThenUsed() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+		testTokenExchangeGrant();
+	}
+
+	private void testTokenExchangeGrant() {
+		OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
+		given(MOCK_RESPONSE_CLIENT.getTokenResponse(any(TokenExchangeGrantRequest.class)))
+			.willReturn(Mono.just(accessTokenResponse));
+
+		JwtAuthenticationToken authentication = new JwtAuthenticationToken(getJwt());
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("auth0").block();
+		assertThat(clientRegistration).isNotNull();
+		// @formatter:off
+		OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
+				.withClientRegistrationId(clientRegistration.getRegistrationId())
+				.principal(authentication)
+				.attribute(ServerWebExchange.class.getName(), this.exchange)
+				.build();
+		// @formatter:on
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest).block();
+		assertThat(authorizedClient).isNotNull();
+
+		ArgumentCaptor<TokenExchangeGrantRequest> grantRequestCaptor = ArgumentCaptor
+			.forClass(TokenExchangeGrantRequest.class);
+		verify(MOCK_RESPONSE_CLIENT).getTokenResponse(grantRequestCaptor.capture());
+
+		TokenExchangeGrantRequest grantRequest = grantRequestCaptor.getValue();
+		assertThat(grantRequest.getClientRegistration().getRegistrationId())
+			.isEqualTo(clientRegistration.getRegistrationId());
+		assertThat(grantRequest.getGrantType()).isEqualTo(AuthorizationGrantType.TOKEN_EXCHANGE);
+		assertThat(grantRequest.getSubjectToken()).isEqualTo(authentication.getToken());
+	}
+
+	private static OAuth2AccessToken getExpiredAccessToken() {
+		Instant expiresAt = Instant.now().minusSeconds(60);
+		Instant issuedAt = expiresAt.minus(Duration.ofDays(1));
+		return new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "scopes", issuedAt, expiresAt,
+				new HashSet<>(Arrays.asList("read", "write")));
+	}
+
+	private static Jwt getJwt() {
+		Instant issuedAt = Instant.now();
+		return new Jwt("token", issuedAt, issuedAt.plusSeconds(300),
+				Collections.singletonMap(JoseHeaderNames.ALG, "RS256"),
+				Collections.singletonMap(JwtClaimNames.SUB, "user"));
+	}
+
+	@Configuration
+	@EnableWebFluxSecurity
+	static class MinimalOAuth2ClientConfig extends OAuth2ClientBaseConfig {
+
+		@Bean
+		ServerOAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return new WebSessionServerOAuth2AuthorizedClientRepository();
+		}
+
+	}
+
+	@Configuration
+	@EnableWebFluxSecurity
+	static class CustomAuthorizedClientServiceConfig extends OAuth2ClientBaseConfig {
+
+		@Bean
+		ReactiveOAuth2AuthorizedClientService authorizedClientService() {
+			ReactiveOAuth2AuthorizedClientService authorizedClientService = mock(
+					ReactiveOAuth2AuthorizedClientService.class);
+			given(authorizedClientService.loadAuthorizedClient(anyString(), anyString())).willReturn(Mono.empty());
+			return authorizedClientService;
+		}
+
+	}
+
+	@Configuration
+	@EnableWebFluxSecurity
+	static class CustomAccessTokenResponseClientsConfig extends MinimalOAuth2ClientConfig {
+
+		@Bean
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> authorizationCodeAccessTokenResponseClient() {
+			return new MockAccessTokenResponseClient<>();
+		}
+
+		@Bean
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> refreshTokenTokenAccessResponseClient() {
+			return new MockAccessTokenResponseClient<>();
+		}
+
+		@Bean
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsAccessTokenResponseClient() {
+			return new MockAccessTokenResponseClient<>();
+		}
+
+		@Bean
+		ReactiveOAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> passwordAccessTokenResponseClient() {
+			return new MockAccessTokenResponseClient<>();
+		}
+
+		@Bean
+		ReactiveOAuth2AccessTokenResponseClient<JwtBearerGrantRequest> jwtBearerAccessTokenResponseClient() {
+			return new MockAccessTokenResponseClient<>();
+		}
+
+		@Bean
+		ReactiveOAuth2AccessTokenResponseClient<TokenExchangeGrantRequest> tokenExchangeAccessTokenResponseClient() {
+			return new MockAccessTokenResponseClient<>();
+		}
+
+	}
+
+	@Configuration
+	@EnableWebFluxSecurity
+	static class CustomAuthorizedClientProvidersConfig extends MinimalOAuth2ClientConfig {
+
+		@Bean
+		AuthorizationCodeReactiveOAuth2AuthorizedClientProvider authorizationCode() {
+			return spy(new AuthorizationCodeReactiveOAuth2AuthorizedClientProvider());
+		}
+
+		@Bean
+		RefreshTokenReactiveOAuth2AuthorizedClientProvider refreshToken() {
+			RefreshTokenReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new RefreshTokenReactiveOAuth2AuthorizedClientProvider();
+			authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>());
+			return authorizedClientProvider;
+		}
+
+		@Bean
+		ClientCredentialsReactiveOAuth2AuthorizedClientProvider clientCredentials() {
+			ClientCredentialsReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new ClientCredentialsReactiveOAuth2AuthorizedClientProvider();
+			authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>());
+			return authorizedClientProvider;
+		}
+
+		@Bean
+		PasswordReactiveOAuth2AuthorizedClientProvider password() {
+			PasswordReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new PasswordReactiveOAuth2AuthorizedClientProvider();
+			authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>());
+			return authorizedClientProvider;
+		}
+
+		@Bean
+		JwtBearerReactiveOAuth2AuthorizedClientProvider jwtBearer() {
+			JwtBearerReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new JwtBearerReactiveOAuth2AuthorizedClientProvider();
+			authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>());
+			return authorizedClientProvider;
+		}
+
+		@Bean
+		TokenExchangeReactiveOAuth2AuthorizedClientProvider tokenExchange() {
+			TokenExchangeReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = new TokenExchangeReactiveOAuth2AuthorizedClientProvider();
+			authorizedClientProvider.setAccessTokenResponseClient(new MockAccessTokenResponseClient<>());
+			return authorizedClientProvider;
+		}
+
+	}
+
+	abstract static class OAuth2ClientBaseConfig {
+
+		@Bean
+		ReactiveClientRegistrationRepository clientRegistrationRepository() {
+			// @formatter:off
+			return new InMemoryReactiveClientRegistrationRepository(
+					CommonOAuth2Provider.GOOGLE.getBuilder("google")
+						.clientId("google-client-id")
+						.clientSecret("google-client-secret")
+						.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+						.build(),
+					CommonOAuth2Provider.GITHUB.getBuilder("github")
+						.clientId("github-client-id")
+						.clientSecret("github-client-secret")
+						.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
+						.build(),
+					CommonOAuth2Provider.FACEBOOK.getBuilder("facebook")
+						.clientId("facebook-client-id")
+						.clientSecret("facebook-client-secret")
+						.authorizationGrantType(AuthorizationGrantType.PASSWORD)
+						.build(),
+					CommonOAuth2Provider.OKTA.getBuilder("okta")
+						.clientId("okta-client-id")
+						.clientSecret("okta-client-secret")
+						.authorizationGrantType(AuthorizationGrantType.JWT_BEARER)
+						.build(),
+					ClientRegistration.withRegistrationId("auth0")
+						.clientName("Auth0")
+						.clientId("auth0-client-id")
+						.clientSecret("auth0-client-secret")
+						.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
+						.authorizationGrantType(AuthorizationGrantType.TOKEN_EXCHANGE)
+						.scope("user.read", "user.write")
+						.build());
+			// @formatter:on
+		}
+
+		@Bean
+		Consumer<DefaultReactiveOAuth2AuthorizedClientManager> authorizedClientManagerConsumer() {
+			return (authorizedClientManager) -> authorizedClientManager
+				.setContextAttributesMapper((authorizeRequest) -> {
+					ServerWebExchange exchange = Objects
+						.requireNonNull(authorizeRequest.getAttribute(ServerWebExchange.class.getName()));
+					return exchange.getFormData().map((parameters) -> {
+						String username = parameters.getFirst(OAuth2ParameterNames.USERNAME);
+						String password = parameters.getFirst(OAuth2ParameterNames.PASSWORD);
+
+						Map<String, Object> attributes = Collections.emptyMap();
+						if (StringUtils.hasText(username) && StringUtils.hasText(password)) {
+							attributes = new HashMap<>();
+							attributes.put(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, username);
+							attributes.put(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, password);
+						}
+
+						return attributes;
+					});
+				});
+
+		}
+
+	}
+
+	private static class MockAccessTokenResponseClient<T extends AbstractOAuth2AuthorizationGrantRequest>
+			implements ReactiveOAuth2AccessTokenResponseClient<T> {
+
+		@Override
+		public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
+			return MOCK_RESPONSE_CLIENT.getTokenResponse(grantRequest);
+		}
+
+	}
+
+}