瀏覽代碼

Default login page supports Iterable<ClientRegistration>

Fixes gh-4596
Joe Grandja 8 年之前
父節點
當前提交
66647070ab

+ 15 - 53
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

@@ -15,7 +15,6 @@
  */
 package org.springframework.security.config.annotation.web.configurers.oauth2.client;
 
-import org.springframework.beans.factory.BeanFactoryUtils;
 import org.springframework.context.ApplicationContext;
 import org.springframework.core.ResolvableType;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
@@ -39,15 +38,12 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
 import org.springframework.util.Assert;
-import org.springframework.util.CollectionUtils;
 
 import java.net.URI;
-import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collection;
-import java.util.List;
+import java.util.HashMap;
 import java.util.Map;
-import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import static org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter.REGISTRATION_ID_URI_VARIABLE_NAME;
 
@@ -75,7 +71,6 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
 
 	public OAuth2LoginConfigurer<H> clients(ClientRegistration... clientRegistrations) {
 		Assert.notEmpty(clientRegistrations, "clientRegistrations cannot be empty");
-		this.getBuilder().setSharedObject(ClientRegistration[].class, clientRegistrations);
 		return this.clients(new InMemoryClientRegistrationRepository(Arrays.asList(clientRegistrations)));
 	}
 
@@ -230,56 +225,24 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
 	}
 
 	private static <H extends HttpSecurityBuilder<H>> ClientRegistrationRepository getDefaultClientRegistrationRepository(H http) {
-		List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
-		if (!CollectionUtils.isEmpty(clientRegistrations)) {
-			return new InMemoryClientRegistrationRepository(clientRegistrations);
-		}
 		return http.getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class);
 	}
 
-	private static <H extends HttpSecurityBuilder<H>> List<ClientRegistration> getClientRegistrations(H http) {
-		ClientRegistration[] clientRegistrations = http.getSharedObject(ClientRegistration[].class);
-		if (clientRegistrations != null) {
-			return Arrays.asList(clientRegistrations);
-		}
-
-		List<ClientRegistration> clientRegistrationsList = new ArrayList<>();
-
-		// Check context for type -> Collection<ClientRegistration>
-		ResolvableType clientRegistrationsType = ResolvableType.forClassWithGenerics(
-			Collection.class, ClientRegistration.class);
-		Map<String, ?> clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
-			http.getSharedObject(ApplicationContext.class),
-			clientRegistrationsType.resolve(Collection.class));
-		clientRegistrationsMap.values().stream()
-			.filter(col -> Collection.class.isAssignableFrom(col.getClass()))
-			.filter(col -> ((Collection) col).stream()
-				.anyMatch(e -> ClientRegistration.class.isAssignableFrom(e.getClass())))
-			.flatMap(col -> ((Collection) col).stream())
-			.forEach(e -> clientRegistrationsList.add((ClientRegistration)e));
-		if (!clientRegistrationsList.isEmpty()) {
-			return clientRegistrationsList;
-		}
-
-		// Check context for type -> ClientRegistration[]
-		clientRegistrationsType = ResolvableType.forClass(ClientRegistration[].class);
-		clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
-			http.getSharedObject(ApplicationContext.class),
-			clientRegistrationsType.resolve(ClientRegistration[].class));
-		clientRegistrationsMap.values().stream()
-			.flatMap(array -> Arrays.stream((ClientRegistration[])array))
-			.forEach(clientRegistrationsList::add);
-
-		return clientRegistrationsList;
-	}
-
 	private void initDefaultLoginFilter(H http) {
 		DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class);
 		if (loginPageGeneratingFilter == null || this.authorizationCodeAuthenticationFilterConfigurer.isCustomLoginPage()) {
 			return;
 		}
-		List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
-		if (CollectionUtils.isEmpty(clientRegistrations)) {
+
+		Iterable<ClientRegistration> clientRegistrations = null;
+		ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(http);
+		ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class);
+		if (type != ResolvableType.NONE) {
+			if (Stream.of(type.resolveGenerics()).anyMatch(ClientRegistration.class::isAssignableFrom)) {
+				clientRegistrations = (Iterable<ClientRegistration>) clientRegistrationRepository;
+			}
+		}
+		if (clientRegistrations == null) {
 			return;
 		}
 
@@ -298,10 +261,9 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
 			authorizationRequestBaseUri = AuthorizationCodeRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
 		}
 
-		Map<String, String> oauth2AuthenticationUrlToClientName = clientRegistrations.stream()
-			.collect(Collectors.toMap(
-				e -> authorizationRequestBaseUri + "/" + e.getRegistrationId(),
-				e -> e.getClientName()));
+		Map<String, String> oauth2AuthenticationUrlToClientName = new HashMap<>();
+		clientRegistrations.forEach(registration -> oauth2AuthenticationUrlToClientName.put(
+			authorizationRequestBaseUri + "/" + registration.getRegistrationId(), registration.getClientName()));
 		loginPageGeneratingFilter.setOauth2LoginEnabled(true);
 		loginPageGeneratingFilter.setOauth2AuthenticationUrlToClientName(oauth2AuthenticationUrlToClientName);
 		loginPageGeneratingFilter.setLoginPageUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginUrl());

+ 7 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java

@@ -19,6 +19,7 @@ import org.springframework.util.Assert;
 
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 
@@ -29,7 +30,7 @@ import java.util.Map;
  * @since 5.0
  * @see ClientRegistration
  */
-public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository {
+public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository, Iterable<ClientRegistration> {
 	private final ClientRegistrationIdentifierStrategy<String> identifierStrategy = new RegistrationIdIdentifierStrategy();
 	private final Map<String, ClientRegistration> registrations;
 
@@ -54,4 +55,9 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr
 			.findFirst()
 			.orElse(null);
 	}
+
+	@Override
+	public Iterator<ClientRegistration> iterator() {
+		return Collections.unmodifiableCollection(this.registrations.values()).iterator();
+	}
 }

+ 0 - 5
samples/boot/oauth2login/src/integration-test/java/org/springframework/security/samples/OAuth2LoginApplicationTests.java

@@ -39,7 +39,6 @@ import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.oauth2.client.authentication.AuthorizationCodeAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
-import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
 import org.springframework.security.oauth2.client.user.OAuth2UserService;
 import org.springframework.security.oauth2.client.web.AuthorizationCodeAuthenticationProcessingFilter;
 import org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter;
@@ -58,7 +57,6 @@ import org.springframework.web.util.UriComponentsBuilder;
 import java.net.URI;
 import java.net.URL;
 import java.net.URLDecoder;
-import java.util.Arrays;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
@@ -91,8 +89,6 @@ public class OAuth2LoginApplicationTests {
 	private WebClient webClient;
 
 	@Autowired
-	private ClientRegistration[] clientRegistrations;
-
 	private ClientRegistrationRepository clientRegistrationRepository;
 
 	private ClientRegistration googleClientRegistration;
@@ -103,7 +99,6 @@ public class OAuth2LoginApplicationTests {
 	@Before
 	public void setup() {
 		this.webClient.getCookieManager().clearCookies();
-		this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(Arrays.asList(this.clientRegistrations));
 		this.googleClientRegistration = this.clientRegistrationRepository.findByRegistrationId("google");
 		this.githubClientRegistration = this.clientRegistrationRepository.findByRegistrationId("github");
 		this.facebookClientRegistration = this.clientRegistrationRepository.findByRegistrationId("facebook");

+ 6 - 4
samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/ClientRegistrationAutoConfiguration.java

@@ -40,6 +40,8 @@ import org.springframework.core.io.ClassPathResource;
 import org.springframework.core.type.AnnotatedTypeMetadata;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationProperties;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.InMemoryClientRegistrationRepository;
 import org.springframework.util.CollectionUtils;
 
 import java.util.ArrayList;
@@ -54,8 +56,8 @@ import java.util.stream.Collectors;
  */
 @Configuration
 @ConditionalOnWebApplication
-@ConditionalOnClass(ClientRegistration.class)
-@ConditionalOnMissingBean(ClientRegistration.class)
+@ConditionalOnClass(ClientRegistrationRepository.class)
+@ConditionalOnMissingBean(ClientRegistrationRepository.class)
 @AutoConfigureBefore(SecurityAutoConfiguration.class)
 public class ClientRegistrationAutoConfiguration {
 	private static final String CLIENTS_DEFAULTS_RESOURCE = "META-INF/oauth2-clients-defaults.yml";
@@ -72,7 +74,7 @@ public class ClientRegistrationAutoConfiguration {
 		}
 
 		@Bean
-		public ClientRegistration[] clientRegistrations() {
+		public ClientRegistrationRepository clientRegistrations() {
 			MutablePropertySources propertySources = ((ConfigurableEnvironment) this.environment).getPropertySources();
 			Properties clientsDefaultProperties = this.getClientsDefaultProperties();
 			if (clientsDefaultProperties != null) {
@@ -93,7 +95,7 @@ public class ClientRegistrationAutoConfiguration {
 				clientRegistrations.add(clientRegistration);
 			}
 
-			return clientRegistrations.toArray(new ClientRegistration[0]);
+			return new InMemoryClientRegistrationRepository(clientRegistrations);
 		}
 
 		private Properties getClientsDefaultProperties() {

+ 2 - 2
samples/boot/oauth2login/src/main/java/org/springframework/boot/autoconfigure/security/oauth2/client/OAuth2LoginAutoConfiguration.java

@@ -27,7 +27,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfiguration;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
-import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 
 /**
  * @author Joe Grandja
@@ -36,7 +36,7 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 @ConditionalOnWebApplication
 @ConditionalOnClass(EnableWebSecurity.class)
 @ConditionalOnMissingBean(WebSecurityConfiguration.class)
-@ConditionalOnBean(ClientRegistration[].class)
+@ConditionalOnBean(ClientRegistrationRepository.class)
 @AutoConfigureBefore(SecurityAutoConfiguration.class)
 @AutoConfigureAfter(ClientRegistrationAutoConfiguration.class)
 public class OAuth2LoginAutoConfiguration {