Răsfoiți Sursa

Add OAuth2AuthorizedClientManager Registrar

Joe Grandja 2 ani în urmă
părinte
comite
f3d90b38e2

+ 174 - 4
config/src/main/java/org/springframework/security/config/annotation/web/configuration/OAuth2ClientConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,17 +18,40 @@ package org.springframework.security.config.annotation.web.configuration;
 
 import java.util.List;
 
+import org.springframework.beans.BeanMetadataElement;
+import org.springframework.beans.BeansException;
+import org.springframework.beans.factory.BeanFactory;
+import org.springframework.beans.factory.BeanFactoryAware;
+import org.springframework.beans.factory.BeanFactoryUtils;
+import org.springframework.beans.factory.ListableBeanFactory;
 import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.beans.factory.config.BeanDefinition;
+import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
+import org.springframework.beans.factory.config.RuntimeBeanReference;
+import org.springframework.beans.factory.support.BeanDefinitionBuilder;
+import org.springframework.beans.factory.support.BeanDefinitionRegistry;
+import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
+import org.springframework.beans.factory.support.ManagedList;
+import org.springframework.context.annotation.AnnotationBeanNameGenerator;
+import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Import;
 import org.springframework.context.annotation.ImportSelector;
+import org.springframework.core.ResolvableType;
 import org.springframework.core.type.AnnotationMetadata;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.DelegatingOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientProviderBuilder;
+import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.DefaultOAuth2AuthorizedClientManager;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
@@ -48,7 +71,8 @@ import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
  * @since 5.1
  * @see OAuth2ImportSelector
  */
-@Import(OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class)
+@Import({ OAuth2ClientConfiguration.OAuth2ClientWebMvcImportSelector.class,
+		OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerConfiguration.class })
 final class OAuth2ClientConfiguration {
 
 	private static final boolean webMvcPresent;
@@ -65,8 +89,22 @@ final class OAuth2ClientConfiguration {
 			if (!webMvcPresent) {
 				return new String[0];
 			}
-			return new String[] { "org.springframework.security.config.annotation.web.configuration."
-					+ "OAuth2ClientConfiguration.OAuth2ClientWebMvcSecurityConfiguration" };
+			return new String[] {
+					OAuth2ClientConfiguration.class.getName() + ".OAuth2ClientWebMvcSecurityConfiguration" };
+		}
+
+	}
+
+	/**
+	 * @author Joe Grandja
+	 * @since 6.2.0
+	 */
+	@Configuration(proxyBeanMethods = false)
+	static class OAuth2AuthorizedClientManagerConfiguration {
+
+		@Bean
+		OAuth2AuthorizedClientManagerRegistrar authorizedClientManagerRegistrar() {
+			return new OAuth2AuthorizedClientManagerRegistrar();
 		}
 
 	}
@@ -160,4 +198,136 @@ final class OAuth2ClientConfiguration {
 
 	}
 
+	/**
+	 * A registrar for registering the default {@link OAuth2AuthorizedClientManager} bean
+	 * definition, if not already present.
+	 *
+	 * @author Joe Grandja
+	 * @since 6.2.0
+	 */
+	static class OAuth2AuthorizedClientManagerRegistrar
+			implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
+
+		private final AnnotationBeanNameGenerator beanNameGenerator = new AnnotationBeanNameGenerator();
+
+		private BeanFactory beanFactory;
+
+		@Override
+		public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
+			String[] authorizedClientManagerBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+					(ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientManager.class, true, true);
+			if (authorizedClientManagerBeanNames.length != 0) {
+				return;
+			}
+
+			String[] clientRegistrationRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+					(ListableBeanFactory) this.beanFactory, ClientRegistrationRepository.class, true, true);
+			String[] authorizedClientRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+					(ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientRepository.class, true, true);
+			if (clientRegistrationRepositoryBeanNames.length != 1 || authorizedClientRepositoryBeanNames.length != 1) {
+				return;
+			}
+
+			BeanDefinition beanDefinition = BeanDefinitionBuilder
+					.genericBeanDefinition(DefaultOAuth2AuthorizedClientManager.class)
+					.addConstructorArgReference(clientRegistrationRepositoryBeanNames[0])
+					.addConstructorArgReference(authorizedClientRepositoryBeanNames[0])
+					.addPropertyValue("authorizedClientProvider", getAuthorizedClientProvider()).getBeanDefinition();
+
+			registry.registerBeanDefinition(this.beanNameGenerator.generateBeanName(beanDefinition, registry),
+					beanDefinition);
+		}
+
+		@Override
+		public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
+		}
+
+		private BeanDefinition getAuthorizedClientProvider() {
+			ManagedList<Object> authorizedClientProviders = new ManagedList<>();
+			authorizedClientProviders.add(getAuthorizationCodeAuthorizedClientProvider());
+			authorizedClientProviders.add(getRefreshTokenAuthorizedClientProvider());
+			authorizedClientProviders.add(getClientCredentialsAuthorizedClientProvider());
+			authorizedClientProviders.add(getPasswordAuthorizedClientProvider());
+			return BeanDefinitionBuilder.genericBeanDefinition(DelegatingOAuth2AuthorizedClientProvider.class)
+					.addConstructorArgValue(authorizedClientProviders).getBeanDefinition();
+		}
+
+		private BeanMetadataElement getAuthorizationCodeAuthorizedClientProvider() {
+			String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+					(ListableBeanFactory) this.beanFactory, AuthorizationCodeOAuth2AuthorizedClientProvider.class, true,
+					true);
+			if (beanNames.length == 1) {
+				return new RuntimeBeanReference(beanNames[0]);
+			}
+
+			return BeanDefinitionBuilder.genericBeanDefinition(AuthorizationCodeOAuth2AuthorizedClientProvider.class)
+					.getBeanDefinition();
+		}
+
+		private BeanMetadataElement getRefreshTokenAuthorizedClientProvider() {
+			String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+					(ListableBeanFactory) this.beanFactory, RefreshTokenOAuth2AuthorizedClientProvider.class, true,
+					true);
+			if (beanNames.length == 1) {
+				return new RuntimeBeanReference(beanNames[0]);
+			}
+
+			BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder
+					.genericBeanDefinition(RefreshTokenOAuth2AuthorizedClientProvider.class);
+			ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
+					OAuth2RefreshTokenGrantRequest.class);
+			beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory,
+					resolvableType, true, true);
+			if (beanNames.length == 1) {
+				beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]);
+			}
+			return beanDefinitionBuilder.getBeanDefinition();
+		}
+
+		private BeanMetadataElement getClientCredentialsAuthorizedClientProvider() {
+			String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+					(ListableBeanFactory) this.beanFactory, ClientCredentialsOAuth2AuthorizedClientProvider.class, true,
+					true);
+			if (beanNames.length == 1) {
+				return new RuntimeBeanReference(beanNames[0]);
+			}
+
+			BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder
+					.genericBeanDefinition(ClientCredentialsOAuth2AuthorizedClientProvider.class);
+			ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
+					OAuth2ClientCredentialsGrantRequest.class);
+			beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory,
+					resolvableType, true, true);
+			if (beanNames.length == 1) {
+				beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]);
+			}
+			return beanDefinitionBuilder.getBeanDefinition();
+		}
+
+		private BeanMetadataElement getPasswordAuthorizedClientProvider() {
+			String[] beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+					(ListableBeanFactory) this.beanFactory, PasswordOAuth2AuthorizedClientProvider.class, true, true);
+			if (beanNames.length == 1) {
+				return new RuntimeBeanReference(beanNames[0]);
+			}
+
+			BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder
+					.genericBeanDefinition(PasswordOAuth2AuthorizedClientProvider.class);
+			ResolvableType resolvableType = ResolvableType.forClassWithGenerics(OAuth2AccessTokenResponseClient.class,
+					OAuth2PasswordGrantRequest.class);
+			beanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors((ListableBeanFactory) this.beanFactory,
+					resolvableType, true, true);
+			if (beanNames.length == 1) {
+				beanDefinitionBuilder.addPropertyReference("accessTokenResponseClient", beanNames[0]);
+			}
+			return beanDefinitionBuilder.getBeanDefinition();
+		}
+
+		@Override
+		public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
+			this.beanFactory = beanFactory;
+		}
+
+	}
+
 }

+ 218 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configuration/OAuth2AuthorizedClientManagerConfigurationTests.java

@@ -0,0 +1,218 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.web.configuration;
+
+import java.util.Arrays;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.http.converter.FormHttpMessageConverter;
+import org.springframework.http.converter.json.MappingJackson2HttpMessageConverter;
+import org.springframework.security.config.Customizer;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.config.test.SpringTestContext;
+import org.springframework.security.oauth2.client.AuthorizationCodeOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.ClientCredentialsOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.PasswordOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.RefreshTokenOAuth2AuthorizedClientProvider;
+import org.springframework.security.oauth2.client.endpoint.DefaultAuthorizationCodeTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.DefaultClientCredentialsTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.DefaultPasswordTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.DefaultRefreshTokenTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
+import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2PasswordGrantRequest;
+import org.springframework.security.oauth2.client.endpoint.OAuth2RefreshTokenGrantRequest;
+import org.springframework.security.oauth2.client.http.OAuth2ErrorResponseErrorHandler;
+import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
+import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserService;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.userinfo.DefaultOAuth2UserService;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserService;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.http.converter.OAuth2AccessTokenResponseHttpMessageConverter;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.security.web.SecurityFilterChain;
+import org.springframework.web.client.RestOperations;
+import org.springframework.web.client.RestTemplate;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+
+/**
+ * Tests for {@link OAuth2ClientConfiguration.OAuth2AuthorizedClientManagerConfiguration}.
+ *
+ * @author Joe Grandja
+ */
+public class OAuth2AuthorizedClientManagerConfigurationTests {
+
+	public final SpringTestContext spring = new SpringTestContext(this);
+
+	@Autowired
+	private OAuth2AuthorizedClientManager authorizedClientManager;
+
+	@Autowired(required = false)
+	private AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeAuthorizedClientProvider;
+
+	@Autowired(required = false)
+	private RefreshTokenOAuth2AuthorizedClientProvider refreshTokenAuthorizedClientProvider;
+
+	@Autowired(required = false)
+	private ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsAuthorizedClientProvider;
+
+	@Autowired(required = false)
+	private PasswordOAuth2AuthorizedClientProvider passwordAuthorizedClientProvider;
+
+	@Test
+	public void loadContextWhenCustomRestOperationsThenConfigured() {
+		this.spring.register(CustomRestOperationsConfig.class).autowire();
+		assertThat(this.authorizedClientManager).isNotNull();
+	}
+
+	@Test
+	public void loadContextWhenCustomAuthorizedClientProvidersThenConfigured() {
+		this.spring.register(CustomAuthorizedClientProvidersConfig.class).autowire();
+		assertThat(this.authorizedClientManager).isNotNull();
+	}
+
+	@Configuration
+	@EnableWebSecurity
+	static class CustomRestOperationsConfig extends OAuth2ClientBaseConfig {
+
+		// TODO This needs to be autoconfigured in OAuth2LoginConfigurer and
+		// OAuth2ClientConfigurer
+		@Bean
+		OAuth2AccessTokenResponseClient<OAuth2AuthorizationCodeGrantRequest> authorizationCodeTokenResponseClient() {
+			DefaultAuthorizationCodeTokenResponseClient tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient();
+			tokenResponseClient.setRestOperations(restOperations());
+			return spy(tokenResponseClient);
+		}
+
+		@Bean
+		OAuth2AccessTokenResponseClient<OAuth2RefreshTokenGrantRequest> refreshTokenTokenResponseClient() {
+			DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient();
+			tokenResponseClient.setRestOperations(restOperations());
+			return spy(tokenResponseClient);
+		}
+
+		@Bean
+		OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient() {
+			DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient();
+			tokenResponseClient.setRestOperations(restOperations());
+			return spy(tokenResponseClient);
+		}
+
+		@Bean
+		OAuth2AccessTokenResponseClient<OAuth2PasswordGrantRequest> passwordTokenResponseClient() {
+			DefaultPasswordTokenResponseClient tokenResponseClient = new DefaultPasswordTokenResponseClient();
+			tokenResponseClient.setRestOperations(restOperations());
+			return spy(tokenResponseClient);
+		}
+
+		// NOTE: This is autoconfigured in OAuth2LoginConfigurer and
+		// OAuth2ClientConfigurer
+		@Bean
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService() {
+			DefaultOAuth2UserService userService = new DefaultOAuth2UserService();
+			userService.setRestOperations(restOperations());
+			return spy(userService);
+		}
+
+		// NOTE: This is autoconfigured in OAuth2LoginConfigurer and
+		// OAuth2ClientConfigurer
+		@Bean
+		OAuth2UserService<OidcUserRequest, OidcUser> oidcUserService() {
+			OidcUserService userService = new OidcUserService();
+			userService.setOauth2UserService(oauth2UserService());
+			return spy(userService);
+		}
+
+		@Bean
+		RestOperations restOperations() {
+			// Minimum required configuration
+			RestTemplate restTemplate = new RestTemplate(Arrays.asList(new FormHttpMessageConverter(),
+					new OAuth2AccessTokenResponseHttpMessageConverter(), new MappingJackson2HttpMessageConverter()));
+			restTemplate.setErrorHandler(new OAuth2ErrorResponseErrorHandler());
+
+			// TODO Add custom configuration, eg. Proxy, TLS, etc
+
+			return spy(restTemplate);
+		}
+
+	}
+
+	@Configuration
+	@EnableWebSecurity
+	static class CustomAuthorizedClientProvidersConfig extends OAuth2ClientBaseConfig {
+
+		@Bean
+		AuthorizationCodeOAuth2AuthorizedClientProvider authorizationCodeProvider() {
+			return mock(AuthorizationCodeOAuth2AuthorizedClientProvider.class);
+		}
+
+		@Bean
+		RefreshTokenOAuth2AuthorizedClientProvider refreshTokenProvider() {
+			return mock(RefreshTokenOAuth2AuthorizedClientProvider.class);
+		}
+
+		@Bean
+		ClientCredentialsOAuth2AuthorizedClientProvider clientCredentialsProvider() {
+			return mock(ClientCredentialsOAuth2AuthorizedClientProvider.class);
+		}
+
+		@Bean
+		PasswordOAuth2AuthorizedClientProvider passwordProvider() {
+			return mock(PasswordOAuth2AuthorizedClientProvider.class);
+		}
+
+	}
+
+	abstract static class OAuth2ClientBaseConfig {
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+				.authorizeHttpRequests(authorize ->
+					authorize.anyRequest().authenticated())
+				.oauth2Login(Customizer.withDefaults())
+				.oauth2Client(Customizer.withDefaults());
+			return http.build();
+			// @formatter:on
+		}
+
+		@Bean
+		ClientRegistrationRepository clientRegistrationRepository() {
+			return mock(ClientRegistrationRepository.class);
+		}
+
+		@Bean
+		OAuth2AuthorizedClientRepository authorizedClientRepository() {
+			return mock(OAuth2AuthorizedClientRepository.class);
+		}
+
+	}
+
+}