|
@@ -15,7 +15,6 @@
|
|
|
*/
|
|
|
package org.springframework.security.config.annotation.web.configurers.oauth2.client;
|
|
|
|
|
|
-import org.springframework.beans.factory.BeanFactoryUtils;
|
|
|
import org.springframework.context.ApplicationContext;
|
|
|
import org.springframework.core.ResolvableType;
|
|
|
import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
|
|
@@ -39,15 +38,12 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
|
import org.springframework.security.web.util.matcher.RequestMatcher;
|
|
|
import org.springframework.security.web.util.matcher.RequestVariablesExtractor;
|
|
|
import org.springframework.util.Assert;
|
|
|
-import org.springframework.util.CollectionUtils;
|
|
|
|
|
|
import java.net.URI;
|
|
|
-import java.util.ArrayList;
|
|
|
import java.util.Arrays;
|
|
|
-import java.util.Collection;
|
|
|
-import java.util.List;
|
|
|
+import java.util.HashMap;
|
|
|
import java.util.Map;
|
|
|
-import java.util.stream.Collectors;
|
|
|
+import java.util.stream.Stream;
|
|
|
|
|
|
import static org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter.REGISTRATION_ID_URI_VARIABLE_NAME;
|
|
|
|
|
@@ -75,7 +71,6 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
|
|
|
|
|
|
public OAuth2LoginConfigurer<H> clients(ClientRegistration... clientRegistrations) {
|
|
|
Assert.notEmpty(clientRegistrations, "clientRegistrations cannot be empty");
|
|
|
- this.getBuilder().setSharedObject(ClientRegistration[].class, clientRegistrations);
|
|
|
return this.clients(new InMemoryClientRegistrationRepository(Arrays.asList(clientRegistrations)));
|
|
|
}
|
|
|
|
|
@@ -230,56 +225,24 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
|
|
|
}
|
|
|
|
|
|
private static <H extends HttpSecurityBuilder<H>> ClientRegistrationRepository getDefaultClientRegistrationRepository(H http) {
|
|
|
- List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
|
|
|
- if (!CollectionUtils.isEmpty(clientRegistrations)) {
|
|
|
- return new InMemoryClientRegistrationRepository(clientRegistrations);
|
|
|
- }
|
|
|
return http.getSharedObject(ApplicationContext.class).getBean(ClientRegistrationRepository.class);
|
|
|
}
|
|
|
|
|
|
- private static <H extends HttpSecurityBuilder<H>> List<ClientRegistration> getClientRegistrations(H http) {
|
|
|
- ClientRegistration[] clientRegistrations = http.getSharedObject(ClientRegistration[].class);
|
|
|
- if (clientRegistrations != null) {
|
|
|
- return Arrays.asList(clientRegistrations);
|
|
|
- }
|
|
|
-
|
|
|
- List<ClientRegistration> clientRegistrationsList = new ArrayList<>();
|
|
|
-
|
|
|
- // Check context for type -> Collection<ClientRegistration>
|
|
|
- ResolvableType clientRegistrationsType = ResolvableType.forClassWithGenerics(
|
|
|
- Collection.class, ClientRegistration.class);
|
|
|
- Map<String, ?> clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
|
|
|
- http.getSharedObject(ApplicationContext.class),
|
|
|
- clientRegistrationsType.resolve(Collection.class));
|
|
|
- clientRegistrationsMap.values().stream()
|
|
|
- .filter(col -> Collection.class.isAssignableFrom(col.getClass()))
|
|
|
- .filter(col -> ((Collection) col).stream()
|
|
|
- .anyMatch(e -> ClientRegistration.class.isAssignableFrom(e.getClass())))
|
|
|
- .flatMap(col -> ((Collection) col).stream())
|
|
|
- .forEach(e -> clientRegistrationsList.add((ClientRegistration)e));
|
|
|
- if (!clientRegistrationsList.isEmpty()) {
|
|
|
- return clientRegistrationsList;
|
|
|
- }
|
|
|
-
|
|
|
- // Check context for type -> ClientRegistration[]
|
|
|
- clientRegistrationsType = ResolvableType.forClass(ClientRegistration[].class);
|
|
|
- clientRegistrationsMap = BeanFactoryUtils.beansOfTypeIncludingAncestors(
|
|
|
- http.getSharedObject(ApplicationContext.class),
|
|
|
- clientRegistrationsType.resolve(ClientRegistration[].class));
|
|
|
- clientRegistrationsMap.values().stream()
|
|
|
- .flatMap(array -> Arrays.stream((ClientRegistration[])array))
|
|
|
- .forEach(clientRegistrationsList::add);
|
|
|
-
|
|
|
- return clientRegistrationsList;
|
|
|
- }
|
|
|
-
|
|
|
private void initDefaultLoginFilter(H http) {
|
|
|
DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class);
|
|
|
if (loginPageGeneratingFilter == null || this.authorizationCodeAuthenticationFilterConfigurer.isCustomLoginPage()) {
|
|
|
return;
|
|
|
}
|
|
|
- List<ClientRegistration> clientRegistrations = getClientRegistrations(http);
|
|
|
- if (CollectionUtils.isEmpty(clientRegistrations)) {
|
|
|
+
|
|
|
+ Iterable<ClientRegistration> clientRegistrations = null;
|
|
|
+ ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository(http);
|
|
|
+ ResolvableType type = ResolvableType.forInstance(clientRegistrationRepository).as(Iterable.class);
|
|
|
+ if (type != ResolvableType.NONE) {
|
|
|
+ if (Stream.of(type.resolveGenerics()).anyMatch(ClientRegistration.class::isAssignableFrom)) {
|
|
|
+ clientRegistrations = (Iterable<ClientRegistration>) clientRegistrationRepository;
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (clientRegistrations == null) {
|
|
|
return;
|
|
|
}
|
|
|
|
|
@@ -298,10 +261,9 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
|
|
|
authorizationRequestBaseUri = AuthorizationCodeRequestRedirectFilter.DEFAULT_AUTHORIZATION_REQUEST_BASE_URI;
|
|
|
}
|
|
|
|
|
|
- Map<String, String> oauth2AuthenticationUrlToClientName = clientRegistrations.stream()
|
|
|
- .collect(Collectors.toMap(
|
|
|
- e -> authorizationRequestBaseUri + "/" + e.getRegistrationId(),
|
|
|
- e -> e.getClientName()));
|
|
|
+ Map<String, String> oauth2AuthenticationUrlToClientName = new HashMap<>();
|
|
|
+ clientRegistrations.forEach(registration -> oauth2AuthenticationUrlToClientName.put(
|
|
|
+ authorizationRequestBaseUri + "/" + registration.getRegistrationId(), registration.getClientName()));
|
|
|
loginPageGeneratingFilter.setOauth2LoginEnabled(true);
|
|
|
loginPageGeneratingFilter.setOauth2AuthenticationUrlToClientName(oauth2AuthenticationUrlToClientName);
|
|
|
loginPageGeneratingFilter.setLoginPageUrl(this.authorizationCodeAuthenticationFilterConfigurer.getLoginUrl());
|