소스 검색

Register OAuth2AuthorizedClientArgumentResolver for XML Config

Closes gh-8669
Joe Grandja 5 년 전
부모
커밋
951e64185b

+ 53 - 12
config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java

@@ -15,18 +15,8 @@
  */
 package org.springframework.security.config.http;
 
-import java.security.SecureRandom;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Map;
-import java.util.function.Function;
-import javax.servlet.http.HttpServletRequest;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
-import org.w3c.dom.Element;
-
 import org.springframework.beans.BeanMetadataElement;
 import org.springframework.beans.factory.config.BeanDefinition;
 import org.springframework.beans.factory.config.BeanReference;
@@ -63,8 +53,18 @@ import org.springframework.security.web.authentication.www.BasicAuthenticationEn
 import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.util.Assert;
+import org.springframework.util.ClassUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.util.xml.DomUtils;
+import org.w3c.dom.Element;
+
+import javax.servlet.http.HttpServletRequest;
+import java.security.SecureRandom;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
 
 import static org.springframework.security.config.http.SecurityFilters.ANONYMOUS_FILTER;
 import static org.springframework.security.config.http.SecurityFilters.BASIC_AUTH_FILTER;
@@ -160,12 +160,16 @@ final class AuthenticationConfigBuilder {
 
 	private String openIDLoginPage;
 
+	private boolean oauth2LoginEnabled;
+	private boolean defaultAuthorizedClientRepositoryRegistered;
 	private String oauth2LoginFilterId;
 	private BeanDefinition oauth2AuthorizationRequestRedirectFilter;
 	private BeanDefinition oauth2LoginEntryPoint;
 	private BeanReference oauth2LoginAuthenticationProviderRef;
 	private BeanReference oauth2LoginOidcAuthenticationProviderRef;
 	private BeanDefinition oauth2LoginLinks;
+
+	private boolean oauth2ClientEnabled;
 	private BeanDefinition authorizationRequestRedirectFilter;
 	private BeanDefinition authorizationCodeGrantFilter;
 	private BeanReference authorizationCodeAuthenticationProviderRef;
@@ -196,8 +200,7 @@ final class AuthenticationConfigBuilder {
 		createBasicFilter(authenticationManager);
 		createBearerTokenAuthenticationFilter(authenticationManager);
 		createFormLoginFilter(sessionStrategy, authenticationManager);
-		createOAuth2LoginFilter(sessionStrategy, authenticationManager);
-		createOAuth2ClientFilter(requestCache, authenticationManager);
+		createOAuth2ClientFilters(sessionStrategy, requestCache, authenticationManager);
 		createOpenIDLoginFilter(sessionStrategy, authenticationManager);
 		createX509Filter(authenticationManager);
 		createJeeFilter(authenticationManager);
@@ -274,15 +277,27 @@ final class AuthenticationConfigBuilder {
 		}
 	}
 
+	void createOAuth2ClientFilters(BeanReference sessionStrategy, BeanReference requestCache,
+			BeanReference authenticationManager) {
+		createOAuth2LoginFilter(sessionStrategy, authenticationManager);
+		createOAuth2ClientFilter(requestCache, authenticationManager);
+		registerOAuth2ClientPostProcessors();
+	}
+
 	void createOAuth2LoginFilter(BeanReference sessionStrategy, BeanReference authManager) {
 		Element oauth2LoginElt = DomUtils.getChildElementByTagName(this.httpElt, Elements.OAUTH2_LOGIN);
 		if (oauth2LoginElt == null) {
 			return;
 		}
+		this.oauth2LoginEnabled = true;
 
 		OAuth2LoginBeanDefinitionParser parser = new OAuth2LoginBeanDefinitionParser(requestCache, portMapper,
 				portResolver, sessionStrategy, allowSessionCreation);
 		BeanDefinition oauth2LoginFilterBean = parser.parse(oauth2LoginElt, this.pc);
+
+		BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository();
+		registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository);
+
 		oauth2LoginFilterBean.getPropertyValues().addPropertyValue("authenticationManager", authManager);
 
 		// retrieve the other bean result
@@ -319,11 +334,15 @@ final class AuthenticationConfigBuilder {
 		if (oauth2ClientElt == null) {
 			return;
 		}
+		this.oauth2ClientEnabled = true;
 
 		OAuth2ClientBeanDefinitionParser parser = new OAuth2ClientBeanDefinitionParser(
 				requestCache, authenticationManager);
 		parser.parse(oauth2ClientElt, this.pc);
 
+		BeanDefinition defaultAuthorizedClientRepository = parser.getDefaultAuthorizedClientRepository();
+		registerDefaultAuthorizedClientRepositoryIfNecessary(defaultAuthorizedClientRepository);
+
 		this.authorizationRequestRedirectFilter = parser.getAuthorizationRequestRedirectFilter();
 		String authorizationRequestRedirectFilterId = pc.getReaderContext()
 				.generateBeanName(this.authorizationRequestRedirectFilter);
@@ -344,6 +363,28 @@ final class AuthenticationConfigBuilder {
 		this.authorizationCodeAuthenticationProviderRef = new RuntimeBeanReference(authorizationCodeAuthenticationProviderId);
 	}
 
+	void registerDefaultAuthorizedClientRepositoryIfNecessary(BeanDefinition defaultAuthorizedClientRepository) {
+		if (!this.defaultAuthorizedClientRepositoryRegistered && defaultAuthorizedClientRepository != null) {
+			String authorizedClientRepositoryId = pc.getReaderContext()
+					.generateBeanName(defaultAuthorizedClientRepository);
+			this.pc.registerBeanComponent(new BeanComponentDefinition(
+					defaultAuthorizedClientRepository, authorizedClientRepositoryId));
+			this.defaultAuthorizedClientRepositoryRegistered = true;
+		}
+	}
+
+	private void registerOAuth2ClientPostProcessors() {
+		if (!this.oauth2LoginEnabled && !this.oauth2ClientEnabled) {
+			return;
+		}
+
+		boolean webmvcPresent = ClassUtils.isPresent("org.springframework.web.servlet.DispatcherServlet", getClass().getClassLoader());
+		if (webmvcPresent) {
+			this.pc.getReaderContext().registerWithGeneratedName(
+					new RootBeanDefinition(OAuth2ClientWebMvcSecurityPostProcessor.class));
+		}
+	}
+
 	void createOpenIDLoginFilter(BeanReference sessionStrategy, BeanReference authManager) {
 		Element openIDLoginElt = DomUtils.getChildElementByTagName(httpElt,
 				Elements.OPENID_LOGIN);

+ 21 - 41
config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParser.java

@@ -23,27 +23,30 @@ import org.springframework.beans.factory.support.BeanDefinitionBuilder;
 import org.springframework.beans.factory.xml.BeanDefinitionParser;
 import org.springframework.beans.factory.xml.ParserContext;
 import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationProvider;
-import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter;
 import org.springframework.util.StringUtils;
 import org.springframework.util.xml.DomUtils;
 import org.w3c.dom.Element;
 
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createAuthorizedClientRepository;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository;
+
 /**
  * @author Joe Grandja
  * @since 5.3
  */
 final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {
 	private static final String ELT_AUTHORIZATION_CODE_GRANT = "authorization-code-grant";
-	private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref";
-	private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref";
-	private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref";
 	private static final String ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF = "authorization-request-repository-ref";
 	private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref";
 	private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref";
 	private final BeanReference requestCache;
 	private final BeanReference authenticationManager;
+	private BeanDefinition defaultAuthorizedClientRepository;
 	private BeanDefinition authorizationRequestRedirectFilter;
 	private BeanDefinition authorizationCodeGrantFilter;
 	private BeanDefinition authorizationCodeAuthenticationProvider;
@@ -58,8 +61,16 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {
 		Element authorizationCodeGrantElt = DomUtils.getChildElementByTagName(element, ELT_AUTHORIZATION_CODE_GRANT);
 
 		BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element);
-		BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(
-				element, clientRegistrationRepository);
+		BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element);
+		if (authorizedClientRepository == null) {
+			BeanMetadataElement authorizedClientService = getAuthorizedClientService(element);
+			if (authorizedClientService == null) {
+				this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository(clientRegistrationRepository);
+				authorizedClientRepository = this.defaultAuthorizedClientRepository;
+			} else {
+				authorizedClientRepository = createAuthorizedClientRepository(authorizedClientService);
+			}
+		}
 		BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository(
 				authorizationCodeGrantElt);
 
@@ -95,41 +106,6 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {
 		return null;
 	}
 
-	private BeanMetadataElement getClientRegistrationRepository(Element element) {
-		BeanMetadataElement clientRegistrationRepository;
-		String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF);
-		if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) {
-			clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef);
-		} else {
-			clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class);
-		}
-		return clientRegistrationRepository;
-	}
-
-	private BeanMetadataElement getAuthorizedClientRepository(Element element,
-			BeanMetadataElement clientRegistrationRepository) {
-		BeanMetadataElement authorizedClientRepository;
-		String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF);
-		if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) {
-			authorizedClientRepository = new RuntimeBeanReference(authorizedClientRepositoryRef);
-		} else {
-			BeanMetadataElement authorizedClientService;
-			String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF);
-			if (!StringUtils.isEmpty(authorizedClientServiceRef)) {
-				authorizedClientService = new RuntimeBeanReference(authorizedClientServiceRef);
-			} else {
-				authorizedClientService = BeanDefinitionBuilder
-						.rootBeanDefinition(
-								"org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService")
-						.addConstructorArgValue(clientRegistrationRepository).getBeanDefinition();
-			}
-			authorizedClientRepository = BeanDefinitionBuilder.rootBeanDefinition(
-					"org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository")
-					.addConstructorArgValue(authorizedClientService).getBeanDefinition();
-		}
-		return authorizedClientRepository;
-	}
-
 	private BeanMetadataElement getAuthorizationRequestRepository(Element element) {
 		BeanMetadataElement authorizationRequestRepository;
 		String authorizationRequestRepositoryRef = element != null ?
@@ -158,6 +134,10 @@ final class OAuth2ClientBeanDefinitionParser implements BeanDefinitionParser {
 		return accessTokenResponseClient;
 	}
 
+	BeanDefinition getDefaultAuthorizedClientRepository() {
+		return this.defaultAuthorizedClientRepository;
+	}
+
 	BeanDefinition getAuthorizationRequestRedirectFilter() {
 		return this.authorizationRequestRedirectFilter;
 	}

+ 79 - 0
config/src/main/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserUtils.java

@@ -0,0 +1,79 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.http;
+
+import org.springframework.beans.BeanMetadataElement;
+import org.springframework.beans.factory.config.BeanDefinition;
+import org.springframework.beans.factory.config.RuntimeBeanReference;
+import org.springframework.beans.factory.support.BeanDefinitionBuilder;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.util.StringUtils;
+import org.w3c.dom.Element;
+
+/**
+ * @author Joe Grandja
+ * @since 5.4
+ */
+final class OAuth2ClientBeanDefinitionParserUtils {
+	private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref";
+	private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref";
+	private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref";
+
+	static BeanMetadataElement getClientRegistrationRepository(Element element) {
+		BeanMetadataElement clientRegistrationRepository;
+		String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF);
+		if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) {
+			clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef);
+		} else {
+			clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class);
+		}
+		return clientRegistrationRepository;
+	}
+
+	static BeanMetadataElement getAuthorizedClientRepository(Element element) {
+		String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF);
+		if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) {
+			return new RuntimeBeanReference(authorizedClientRepositoryRef);
+		}
+		return null;
+	}
+
+	static BeanMetadataElement getAuthorizedClientService(Element element) {
+		String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF);
+		if (!StringUtils.isEmpty(authorizedClientServiceRef)) {
+			return new RuntimeBeanReference(authorizedClientServiceRef);
+		}
+		return null;
+	}
+
+	static BeanMetadataElement createAuthorizedClientRepository(BeanMetadataElement authorizedClientService) {
+		return BeanDefinitionBuilder.rootBeanDefinition(
+				"org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository")
+				.addConstructorArgValue(authorizedClientService)
+				.getBeanDefinition();
+	}
+
+	static BeanDefinition createDefaultAuthorizedClientRepository(BeanMetadataElement clientRegistrationRepository) {
+		BeanDefinition authorizedClientService = BeanDefinitionBuilder.rootBeanDefinition(
+				"org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService")
+				.addConstructorArgValue(clientRegistrationRepository)
+				.getBeanDefinition();
+		return BeanDefinitionBuilder.rootBeanDefinition(
+				"org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository")
+				.addConstructorArgValue(authorizedClientService)
+				.getBeanDefinition();
+	}
+}

+ 91 - 0
config/src/main/java/org/springframework/security/config/http/OAuth2ClientWebMvcSecurityPostProcessor.java

@@ -0,0 +1,91 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.http;
+
+import org.springframework.beans.BeansException;
+import org.springframework.beans.PropertyValue;
+import org.springframework.beans.factory.BeanFactory;
+import org.springframework.beans.factory.BeanFactoryAware;
+import org.springframework.beans.factory.BeanFactoryUtils;
+import org.springframework.beans.factory.ListableBeanFactory;
+import org.springframework.beans.factory.config.BeanDefinition;
+import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
+import org.springframework.beans.factory.support.BeanDefinitionBuilder;
+import org.springframework.beans.factory.support.BeanDefinitionRegistry;
+import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor;
+import org.springframework.beans.factory.support.ManagedList;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClientManager;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.client.web.method.annotation.OAuth2AuthorizedClientArgumentResolver;
+import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter;
+
+/**
+ * @author Joe Grandja
+ * @since 5.4
+ */
+final class OAuth2ClientWebMvcSecurityPostProcessor implements BeanDefinitionRegistryPostProcessor, BeanFactoryAware {
+	private static final String ARGUMENT_RESOLVERS_PROPERTY = "argumentResolvers";
+	private BeanFactory beanFactory;
+
+	@Override
+	public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException {
+		String[] clientRegistrationRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+				(ListableBeanFactory) this.beanFactory, ClientRegistrationRepository.class, false, false);
+		String[] authorizedClientRepositoryBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+				(ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientRepository.class, false, false);
+
+		if (clientRegistrationRepositoryBeanNames.length != 1 || authorizedClientRepositoryBeanNames.length != 1) {
+			return;
+		}
+
+		for (String beanName : registry.getBeanDefinitionNames()) {
+			BeanDefinition beanDefinition = registry.getBeanDefinition(beanName);
+			if (RequestMappingHandlerAdapter.class.getName().equals(beanDefinition.getBeanClassName())) {
+				PropertyValue currentArgumentResolvers =
+						beanDefinition.getPropertyValues().getPropertyValue(ARGUMENT_RESOLVERS_PROPERTY);
+				ManagedList<Object> argumentResolvers = new ManagedList<>();
+				if (currentArgumentResolvers != null) {
+					argumentResolvers.addAll((ManagedList<?>) currentArgumentResolvers.getValue());
+				}
+
+				String[] authorizedClientManagerBeanNames = BeanFactoryUtils.beanNamesForTypeIncludingAncestors(
+						(ListableBeanFactory) this.beanFactory, OAuth2AuthorizedClientManager.class, false, false);
+
+				BeanDefinitionBuilder beanDefinitionBuilder =
+						BeanDefinitionBuilder.genericBeanDefinition(OAuth2AuthorizedClientArgumentResolver.class);
+				if (authorizedClientManagerBeanNames.length == 1) {
+					beanDefinitionBuilder.addConstructorArgReference(authorizedClientManagerBeanNames[0]);
+				} else {
+					beanDefinitionBuilder.addConstructorArgReference(clientRegistrationRepositoryBeanNames[0]);
+					beanDefinitionBuilder.addConstructorArgReference(authorizedClientRepositoryBeanNames[0]);
+				}
+				argumentResolvers.add(beanDefinitionBuilder.getBeanDefinition());
+				beanDefinition.getPropertyValues().add(ARGUMENT_RESOLVERS_PROPERTY, argumentResolvers);
+				break;
+			}
+		}
+	}
+
+	@Override
+	public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException {
+	}
+
+	@Override
+	public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
+		this.beanFactory = beanFactory;
+	}
+}

+ 29 - 47
config/src/main/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParser.java

@@ -15,13 +15,6 @@
  */
 package org.springframework.security.config.http;
 
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-
 import org.springframework.beans.BeanMetadataElement;
 import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.config.BeanDefinition;
@@ -66,6 +59,19 @@ import org.springframework.web.accept.ContentNegotiationStrategy;
 import org.springframework.web.accept.HeaderContentNegotiationStrategy;
 import org.w3c.dom.Element;
 
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createAuthorizedClientRepository;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.createDefaultAuthorizedClientRepository;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientRepository;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getAuthorizedClientService;
+import static org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserUtils.getClientRegistrationRepository;
+
 /**
  * @author Ruby Hartono
  * @since 5.3
@@ -77,9 +83,6 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser {
 
 	private static final String ELT_CLIENT_REGISTRATION = "client-registration";
 	private static final String ATT_REGISTRATION_ID = "registration-id";
-	private static final String ATT_CLIENT_REGISTRATION_REPOSITORY_REF = "client-registration-repository-ref";
-	private static final String ATT_AUTHORIZED_CLIENT_REPOSITORY_REF = "authorized-client-repository-ref";
-	private static final String ATT_AUTHORIZED_CLIENT_SERVICE_REF = "authorized-client-service-ref";
 	private static final String ATT_AUTHORIZATION_REQUEST_REPOSITORY_REF = "authorization-request-repository-ref";
 	private static final String ATT_AUTHORIZATION_REQUEST_RESOLVER_REF = "authorization-request-resolver-ref";
 	private static final String ATT_ACCESS_TOKEN_RESPONSE_CLIENT_REF = "access-token-response-client-ref";
@@ -98,6 +101,8 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser {
 	private final BeanReference sessionStrategy;
 	private final boolean allowSessionCreation;
 
+	private BeanDefinition defaultAuthorizedClientRepository;
+
 	private BeanDefinition oauth2AuthorizationRequestRedirectFilter;
 
 	private BeanDefinition oauth2LoginAuthenticationEntryPoint;
@@ -128,8 +133,16 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser {
 
 		// configure filter
 		BeanMetadataElement clientRegistrationRepository = getClientRegistrationRepository(element);
-		BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element,
-				clientRegistrationRepository);
+		BeanMetadataElement authorizedClientRepository = getAuthorizedClientRepository(element);
+		if (authorizedClientRepository == null) {
+			BeanMetadataElement authorizedClientService = getAuthorizedClientService(element);
+			if (authorizedClientService == null) {
+				this.defaultAuthorizedClientRepository = createDefaultAuthorizedClientRepository(clientRegistrationRepository);
+				authorizedClientRepository = this.defaultAuthorizedClientRepository;
+			} else {
+				authorizedClientRepository = createAuthorizedClientRepository(authorizedClientService);
+			}
+		}
 		BeanMetadataElement accessTokenResponseClient = getAccessTokenResponseClient(element);
 		BeanMetadataElement oauth2UserService = getOAuth2UserService(element);
 		BeanMetadataElement authorizationRequestRepository = getAuthorizationRequestRepository(element);
@@ -251,41 +264,6 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser {
 		return authorizationRequestRepository;
 	}
 
-	private BeanMetadataElement getAuthorizedClientRepository(Element element,
-			BeanMetadataElement clientRegistrationRepository) {
-		BeanMetadataElement authorizedClientRepository;
-		String authorizedClientRepositoryRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_REPOSITORY_REF);
-		if (!StringUtils.isEmpty(authorizedClientRepositoryRef)) {
-			authorizedClientRepository = new RuntimeBeanReference(authorizedClientRepositoryRef);
-		} else {
-			BeanMetadataElement authorizedClientService;
-			String authorizedClientServiceRef = element.getAttribute(ATT_AUTHORIZED_CLIENT_SERVICE_REF);
-			if (!StringUtils.isEmpty(authorizedClientServiceRef)) {
-				authorizedClientService = new RuntimeBeanReference(authorizedClientServiceRef);
-			} else {
-				authorizedClientService = BeanDefinitionBuilder
-						.rootBeanDefinition(
-								"org.springframework.security.oauth2.client.InMemoryOAuth2AuthorizedClientService")
-						.addConstructorArgValue(clientRegistrationRepository).getBeanDefinition();
-			}
-			authorizedClientRepository = BeanDefinitionBuilder.rootBeanDefinition(
-					"org.springframework.security.oauth2.client.web.AuthenticatedPrincipalOAuth2AuthorizedClientRepository")
-					.addConstructorArgValue(authorizedClientService).getBeanDefinition();
-		}
-		return authorizedClientRepository;
-	}
-
-	private BeanMetadataElement getClientRegistrationRepository(Element element) {
-		BeanMetadataElement clientRegistrationRepository;
-		String clientRegistrationRepositoryRef = element.getAttribute(ATT_CLIENT_REGISTRATION_REPOSITORY_REF);
-		if (!StringUtils.isEmpty(clientRegistrationRepositoryRef)) {
-			clientRegistrationRepository = new RuntimeBeanReference(clientRegistrationRepositoryRef);
-		} else {
-			clientRegistrationRepository = new RuntimeBeanReference(ClientRegistrationRepository.class);
-		}
-		return clientRegistrationRepository;
-	}
-
 	private BeanDefinition getOidcAuthProvider(Element element,
 			BeanMetadataElement accessTokenResponseClient, String userAuthoritiesMapperRef) {
 
@@ -353,6 +331,10 @@ final class OAuth2LoginBeanDefinitionParser implements BeanDefinitionParser {
 		return accessTokenResponseClient;
 	}
 
+	BeanDefinition getDefaultAuthorizedClientRepository() {
+		return this.defaultAuthorizedClientRepository;
+	}
+
 	BeanDefinition getOAuth2AuthorizationRequestRedirectFilter() {
 		return oauth2AuthorizationRequestRedirectFilter;
 	}

+ 31 - 0
config/src/test/java/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests.java

@@ -24,6 +24,7 @@ import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -31,6 +32,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -41,6 +43,8 @@ import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RestController;
 
 import java.util.HashMap;
 import java.util.Map;
@@ -51,6 +55,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
 
@@ -200,6 +205,32 @@ public class OAuth2ClientBeanDefinitionParserTests {
 		verify(this.authorizedClientService).saveAuthorizedClient(any(), any());
 	}
 
+	@WithMockUser
+	@Test
+	public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception {
+		this.spring.configLocations(xml("AuthorizedClientArgumentResolver")).autowire();
+
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google");
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				clientRegistration, "user", TestOAuth2AccessTokens.noScopes());
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any()))
+				.thenReturn(authorizedClient);
+
+		this.mvc.perform(get("/authorized-client"))
+				.andExpect(status().isOk())
+				.andExpect(content().string("resolved"));
+	}
+
+	@RestController
+	static class AuthorizedClientController {
+
+		@GetMapping("/authorized-client")
+		String authorizedClient(@RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) {
+			return authorizedClient != null ? "resolved" : "not-resolved";
+		}
+	}
+
 	private static OAuth2AuthorizationRequest createAuthorizationRequest(ClientRegistration clientRegistration) {
 		Map<String, Object> attributes = new HashMap<>();
 		attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());

+ 42 - 4
config/src/test/java/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests.java

@@ -17,6 +17,7 @@ package org.springframework.security.config.http;
 
 import org.junit.Rule;
 import org.junit.Test;
+import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationListener;
@@ -28,7 +29,9 @@ import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMapper;
+import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService;
+import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AccessTokenResponseClient;
 import org.springframework.security.oauth2.client.endpoint.OAuth2AuthorizationCodeGrantRequest;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
@@ -40,6 +43,7 @@ import org.springframework.security.oauth2.client.web.AuthorizationRequestReposi
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver;
 import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
 import org.springframework.security.oauth2.core.OAuth2AuthenticationException;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
 import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
@@ -50,13 +54,22 @@ import org.springframework.security.oauth2.core.user.TestOAuth2Users;
 import org.springframework.security.oauth2.jwt.Jwt;
 import org.springframework.security.oauth2.jwt.JwtDecoderFactory;
 import org.springframework.security.oauth2.jwt.TestJwts;
+import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
+import org.springframework.security.test.context.support.WithMockUser;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.AuthenticationSuccessHandler;
 import org.springframework.security.web.savedrequest.RequestCache;
+import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.util.LinkedMultiValueMap;
 import org.springframework.util.MultiValueMap;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RestController;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
@@ -66,18 +79,17 @@ import static org.mockito.Mockito.when;
 import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.accessTokenResponse;
 import static org.springframework.security.oauth2.core.endpoint.TestOAuth2AccessTokenResponses.oidcAccessTokenResponse;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
 
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-
 /**
  * Tests for {@link OAuth2LoginBeanDefinitionParser}.
  *
  * @author Ruby Hartono
  */
+@RunWith(SpringJUnit4ClassRunner.class)
+@SecurityTestExecutionListeners
 public class OAuth2LoginBeanDefinitionParserTests {
 	private static final String CONFIG_LOCATION_PREFIX = "classpath:org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests";
 
@@ -489,6 +501,32 @@ public class OAuth2LoginBeanDefinitionParserTests {
 		verify(authorizedClientService).saveAuthorizedClient(any(), any());
 	}
 
+	@WithMockUser
+	@Test
+	public void requestWhenAuthorizedClientFoundThenMethodArgumentResolved() throws Exception {
+		this.spring.configLocations(xml("AuthorizedClientArgumentResolver")).autowire();
+
+		ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId("google-login");
+
+		OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
+				clientRegistration, "user", TestOAuth2AccessTokens.noScopes());
+		when(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any()))
+				.thenReturn(authorizedClient);
+
+		this.mvc.perform(get("/authorized-client"))
+				.andExpect(status().isOk())
+				.andExpect(content().string("resolved"));
+	}
+
+	@RestController
+	static class AuthorizedClientController {
+
+		@GetMapping("/authorized-client")
+		String authorizedClient(@RegisteredOAuth2AuthorizedClient("google") OAuth2AuthorizedClient authorizedClient) {
+			return authorizedClient != null ? "resolved" : "not-resolved";
+		}
+	}
+
 	private String xml(String configName) {
 		return CONFIG_LOCATION_PREFIX + "-" + configName + ".xml";
 	}

+ 52 - 0
config/src/test/resources/org/springframework/security/config/http/OAuth2ClientBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml

@@ -0,0 +1,52 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Copyright 2002-2020 the original author or authors.
+  ~
+  ~ Licensed under the Apache License, Version 2.0 (the "License");
+  ~ you may not use this file except in compliance with the License.
+  ~ You may obtain a copy of the License at
+  ~
+  ~       https://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<b:beans xmlns:b="http://www.springframework.org/schema/beans"
+		 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+		 xmlns:mvc="http://www.springframework.org/schema/mvc"
+		 xmlns="http://www.springframework.org/schema/security"
+		xsi:schemaLocation="
+			http://www.springframework.org/schema/security
+			https://www.springframework.org/schema/security/spring-security.xsd
+			http://www.springframework.org/schema/beans
+			https://www.springframework.org/schema/beans/spring-beans.xsd
+			http://www.springframework.org/schema/mvc
+			https://www.springframework.org/schema/mvc/spring-mvc.xsd">
+
+	<http auto-config="true">
+		<oauth2-client authorized-client-repository-ref="authorizedClientRepository" />
+	</http>
+
+	<mvc:annotation-driven />
+
+	<client-registrations>
+		<client-registration registration-id="google"
+							 client-id="google-client-id"
+							 client-secret="google-client-secret"
+							 redirect-uri="http://localhost/callback/google"
+							 scope="scope1,scope2"
+							 provider-id="google"/>
+	</client-registrations>
+
+	<b:bean id="authorizedClientRepository" class="org.mockito.Mockito" factory-method="mock">
+		<b:constructor-arg value="org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository"/>
+	</b:bean>
+
+	<b:bean name="authorizedClientController" class="org.springframework.security.config.http.OAuth2ClientBeanDefinitionParserTests.AuthorizedClientController" />
+
+	<b:import resource="userservice.xml"/>
+</b:beans>

+ 45 - 0
config/src/test/resources/org/springframework/security/config/http/OAuth2LoginBeanDefinitionParserTests-AuthorizedClientArgumentResolver.xml

@@ -0,0 +1,45 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Copyright 2002-2020 the original author or authors.
+  ~
+  ~ Licensed under the Apache License, Version 2.0 (the "License");
+  ~ you may not use this file except in compliance with the License.
+  ~ You may obtain a copy of the License at
+  ~
+  ~       https://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<b:beans xmlns:b="http://www.springframework.org/schema/beans"
+		 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+		 xmlns:mvc="http://www.springframework.org/schema/mvc"
+		 xmlns="http://www.springframework.org/schema/security"
+		xsi:schemaLocation="
+			http://www.springframework.org/schema/security
+			https://www.springframework.org/schema/security/spring-security.xsd
+			http://www.springframework.org/schema/beans
+			https://www.springframework.org/schema/beans/spring-beans.xsd
+			http://www.springframework.org/schema/mvc
+			https://www.springframework.org/schema/mvc/spring-mvc.xsd">
+
+	<http auto-config="true">
+		<intercept-url pattern="/**" access="authenticated"/>
+		<oauth2-login authorized-client-repository-ref="authorizedClientRepository" />
+	</http>
+
+	<mvc:annotation-driven />
+
+	<b:bean id="authorizedClientRepository" class="org.mockito.Mockito" factory-method="mock">
+		<b:constructor-arg value="org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository"/>
+	</b:bean>
+
+	<b:bean name="authorizedClientController" class="org.springframework.security.config.http.OAuth2LoginBeanDefinitionParserTests.AuthorizedClientController" />
+
+	<b:import resource="../oauth2/client/google-github-registration.xml"/>
+	<b:import resource="userservice.xml"/>
+</b:beans>