Bladeren bron

Resolve oauth2 client placeholders

Closes gh-8453
Evgeniy Cheban 5 jaren geleden
bovenliggende
commit
17f1540280

+ 41 - 35
config/src/main/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParser.java

@@ -42,6 +42,7 @@ import org.springframework.util.xml.DomUtils;
 
 /**
  * @author Ruby Hartono
+ * @author Evgeniy Cheban
  * @since 5.3
  */
 public final class ClientRegistrationsBeanDefinitionParser implements BeanDefinitionParser {
@@ -87,7 +88,7 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini
 		CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(),
 				parserContext.extractSource(element));
 		parserContext.pushContainingComponent(compositeDef);
-		Map<String, Map<String, String>> providers = getProviders(element);
+		Map<String, Map<String, String>> providers = getProviders(element, parserContext);
 		List<ClientRegistration> clientRegistrations = getClientRegistrations(element, parserContext, providers);
 		BeanDefinition clientRegistrationRepositoryBean = BeanDefinitionBuilder
 				.rootBeanDefinition(InMemoryClientRegistrationRepository.class)
@@ -107,9 +108,10 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini
 		for (Element clientRegistrationElt : clientRegistrationElts) {
 			String registrationId = clientRegistrationElt.getAttribute(ATT_REGISTRATION_ID);
 			String providerId = clientRegistrationElt.getAttribute(ATT_PROVIDER_ID);
-			ClientRegistration.Builder builder = getBuilderFromIssuerIfPossible(registrationId, providerId, providers);
+			ClientRegistration.Builder builder = getBuilderFromIssuerIfPossible(parserContext, registrationId,
+					providerId, providers);
 			if (builder == null) {
-				builder = getBuilder(registrationId, providerId, providers);
+				builder = getBuilder(parserContext, registrationId, providerId, providers);
 				if (builder == null) {
 					Object source = parserContext.extractSource(element);
 					parserContext.getReaderContext().error(getErrorMessage(providerId, registrationId), source);
@@ -117,50 +119,53 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini
 					continue;
 				}
 			}
-			getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_ID)).ifPresent(builder::clientId);
-			getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_SECRET))
+			getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_ID))
+					.ifPresent(builder::clientId);
+			getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_SECRET))
 					.ifPresent(builder::clientSecret);
-			getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_AUTHENTICATION_METHOD))
+			getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_AUTHENTICATION_METHOD))
 					.map(ClientAuthenticationMethod::new).ifPresent(builder::clientAuthenticationMethod);
-			getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_AUTHORIZATION_GRANT_TYPE))
+			getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_AUTHORIZATION_GRANT_TYPE))
 					.map(AuthorizationGrantType::new).ifPresent(builder::authorizationGrantType);
-			getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_REDIRECT_URI)).ifPresent(builder::redirectUri);
-			getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_SCOPE))
+			getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_REDIRECT_URI))
+					.ifPresent(builder::redirectUri);
+			getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_SCOPE))
 					.map(StringUtils::commaDelimitedListToSet).ifPresent(builder::scope);
-			getOptionalIfNotEmpty(clientRegistrationElt.getAttribute(ATT_CLIENT_NAME)).ifPresent(builder::clientName);
+			getOptionalIfNotEmpty(parserContext, clientRegistrationElt.getAttribute(ATT_CLIENT_NAME))
+					.ifPresent(builder::clientName);
 			clientRegistrations.add(builder.build());
 		}
 		return clientRegistrations;
 	}
 
-	private Map<String, Map<String, String>> getProviders(Element element) {
+	private Map<String, Map<String, String>> getProviders(Element element, ParserContext parserContext) {
 		List<Element> providerElts = DomUtils.getChildElementsByTagName(element, ELT_PROVIDER);
 		Map<String, Map<String, String>> providers = new HashMap<>();
 		for (Element providerElt : providerElts) {
 			Map<String, String> provider = new HashMap<>();
 			String providerId = providerElt.getAttribute(ATT_PROVIDER_ID);
 			provider.put(ATT_PROVIDER_ID, providerId);
-			getOptionalIfNotEmpty(providerElt.getAttribute(ATT_AUTHORIZATION_URI))
+			getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_AUTHORIZATION_URI))
 					.ifPresent((value) -> provider.put(ATT_AUTHORIZATION_URI, value));
-			getOptionalIfNotEmpty(providerElt.getAttribute(ATT_TOKEN_URI))
+			getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_TOKEN_URI))
 					.ifPresent((value) -> provider.put(ATT_TOKEN_URI, value));
-			getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_URI))
+			getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_URI))
 					.ifPresent((value) -> provider.put(ATT_USER_INFO_URI, value));
-			getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_AUTHENTICATION_METHOD))
+			getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_AUTHENTICATION_METHOD))
 					.ifPresent((value) -> provider.put(ATT_USER_INFO_AUTHENTICATION_METHOD, value));
-			getOptionalIfNotEmpty(providerElt.getAttribute(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
+			getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
 					.ifPresent((value) -> provider.put(ATT_USER_INFO_USER_NAME_ATTRIBUTE, value));
-			getOptionalIfNotEmpty(providerElt.getAttribute(ATT_JWK_SET_URI))
+			getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_JWK_SET_URI))
 					.ifPresent((value) -> provider.put(ATT_JWK_SET_URI, value));
-			getOptionalIfNotEmpty(providerElt.getAttribute(ATT_ISSUER_URI))
+			getOptionalIfNotEmpty(parserContext, providerElt.getAttribute(ATT_ISSUER_URI))
 					.ifPresent((value) -> provider.put(ATT_ISSUER_URI, value));
 			providers.put(providerId, provider);
 		}
 		return providers;
 	}
 
-	private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(String registrationId,
-			String configuredProviderId, Map<String, Map<String, String>> providers) {
+	private static ClientRegistration.Builder getBuilderFromIssuerIfPossible(ParserContext parserContext,
+			String registrationId, String configuredProviderId, Map<String, Map<String, String>> providers) {
 		String providerId = (configuredProviderId != null) ? configuredProviderId : registrationId;
 		if (providers.containsKey(providerId)) {
 			Map<String, String> provider = providers.get(providerId);
@@ -168,14 +173,14 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini
 			if (!StringUtils.isEmpty(issuer)) {
 				ClientRegistration.Builder builder = ClientRegistrations.fromIssuerLocation(issuer)
 						.registrationId(registrationId);
-				return getBuilder(builder, provider);
+				return getBuilder(parserContext, builder, provider);
 			}
 		}
 		return null;
 	}
 
-	private static ClientRegistration.Builder getBuilder(String registrationId, String configuredProviderId,
-			Map<String, Map<String, String>> providers) {
+	private static ClientRegistration.Builder getBuilder(ParserContext parserContext, String registrationId,
+			String configuredProviderId, Map<String, Map<String, String>> providers) {
 		String providerId = (configuredProviderId != null) ? configuredProviderId : registrationId;
 		CommonOAuth2Provider provider = getCommonProvider(providerId);
 		if (provider == null && !providers.containsKey(providerId)) {
@@ -184,26 +189,27 @@ public final class ClientRegistrationsBeanDefinitionParser implements BeanDefini
 		ClientRegistration.Builder builder = (provider != null) ? provider.getBuilder(registrationId)
 				: ClientRegistration.withRegistrationId(registrationId);
 		if (providers.containsKey(providerId)) {
-			return getBuilder(builder, providers.get(providerId));
+			return getBuilder(parserContext, builder, providers.get(providerId));
 		}
 		return builder;
 	}
 
-	private static ClientRegistration.Builder getBuilder(ClientRegistration.Builder builder,
-			Map<String, String> provider) {
-		getOptionalIfNotEmpty(provider.get(ATT_AUTHORIZATION_URI)).ifPresent(builder::authorizationUri);
-		getOptionalIfNotEmpty(provider.get(ATT_TOKEN_URI)).ifPresent(builder::tokenUri);
-		getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_URI)).ifPresent(builder::userInfoUri);
-		getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD)).map(AuthenticationMethod::new)
-				.ifPresent(builder::userInfoAuthenticationMethod);
-		getOptionalIfNotEmpty(provider.get(ATT_JWK_SET_URI)).ifPresent(builder::jwkSetUri);
-		getOptionalIfNotEmpty(provider.get(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
+	private static ClientRegistration.Builder getBuilder(ParserContext parserContext,
+			ClientRegistration.Builder builder, Map<String, String> provider) {
+		getOptionalIfNotEmpty(parserContext, provider.get(ATT_AUTHORIZATION_URI)).ifPresent(builder::authorizationUri);
+		getOptionalIfNotEmpty(parserContext, provider.get(ATT_TOKEN_URI)).ifPresent(builder::tokenUri);
+		getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_URI)).ifPresent(builder::userInfoUri);
+		getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_AUTHENTICATION_METHOD))
+				.map(AuthenticationMethod::new).ifPresent(builder::userInfoAuthenticationMethod);
+		getOptionalIfNotEmpty(parserContext, provider.get(ATT_JWK_SET_URI)).ifPresent(builder::jwkSetUri);
+		getOptionalIfNotEmpty(parserContext, provider.get(ATT_USER_INFO_USER_NAME_ATTRIBUTE))
 				.ifPresent(builder::userNameAttributeName);
 		return builder;
 	}
 
-	private static Optional<String> getOptionalIfNotEmpty(String str) {
-		return Optional.ofNullable(str).filter((s) -> !s.isEmpty());
+	private static Optional<String> getOptionalIfNotEmpty(ParserContext parserContext, String str) {
+		return Optional.ofNullable(str).filter((s) -> !s.isEmpty())
+				.map(parserContext.getReaderContext().getEnvironment()::resolvePlaceholders);
 	}
 
 	private static CommonOAuth2Provider getCommonProvider(String providerId) {

+ 15 - 0
config/src/test/java/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests.java

@@ -41,6 +41,7 @@ import static org.assertj.core.api.Assertions.assertThat;
  * Tests for {@link ClientRegistrationsBeanDefinitionParser}.
  *
  * @author Ruby Hartono
+ * @author Evgeniy Cheban
  */
 public class ClientRegistrationsBeanDefinitionParserTests {
 
@@ -218,6 +219,20 @@ public class ClientRegistrationsBeanDefinitionParserTests {
 		assertThat(githubProviderDetails.getUserInfoEndpoint().getUserNameAttributeName()).isEqualTo("id");
 	}
 
+	@Test
+	public void parseWhenClientPlaceholdersThenResolvePlaceholders() {
+		System.setProperty("oauth2.client.id", "github-client-id");
+		System.setProperty("oauth2.client.secret", "github-client-secret");
+
+		this.spring.configLocations(xml("ClientPlaceholders")).autowire();
+
+		assertThat(this.clientRegistrationRepository).isInstanceOf(InMemoryClientRegistrationRepository.class);
+
+		ClientRegistration githubRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
+		assertThat(githubRegistration.getClientId()).isEqualTo("github-client-id");
+		assertThat(githubRegistration.getClientSecret()).isEqualTo("github-client-secret");
+	}
+
 	private static MockResponse jsonResponse(String json) {
 		return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
 	}

+ 32 - 0
config/src/test/resources/org/springframework/security/config/oauth2/client/ClientRegistrationsBeanDefinitionParserTests-ClientPlaceholders.xml

@@ -0,0 +1,32 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Copyright 2002-2020 the original author or authors.
+  ~
+  ~ Licensed under the Apache License, Version 2.0 (the "License");
+  ~ you may not use this file except in compliance with the License.
+  ~ You may obtain a copy of the License at
+  ~
+  ~       https://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+
+<b:beans xmlns:b="http://www.springframework.org/schema/beans"
+		 xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+		 xmlns="http://www.springframework.org/schema/security"
+		 xsi:schemaLocation="
+			http://www.springframework.org/schema/security
+			https://www.springframework.org/schema/security/spring-security.xsd
+			http://www.springframework.org/schema/beans
+			https://www.springframework.org/schema/beans/spring-beans.xsd">
+	<client-registrations>
+		<client-registration registration-id="github"
+							 client-id="${oauth2.client.id}"
+							 client-secret="${oauth2.client.secret}"
+							 provider-id="github"/>
+	</client-registrations>
+</b:beans>