Browse Source

OAuth2LoginConfigurer UserService Beans

Fixes gh-7232
Josh Cummings 6 năm trước cách đây
mục cha
commit
a00ad37168

+ 1 - 1
config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java

@@ -2044,7 +2044,7 @@ public final class HttpSecurity extends
 	 * @throws Exception
 	 */
 	public HttpSecurity oauth2Login(Customizer<OAuth2LoginConfigurer<HttpSecurity>> oauth2LoginCustomizer) throws Exception {
-		oauth2LoginCustomizer.customize(getOrApply(new OAuth2LoginConfigurer<>()));
+		oauth2LoginCustomizer.customize(getOrApply(new OAuth2LoginConfigurer<>(getContext())));
 		return HttpSecurity.this;
 	}
 

+ 49 - 17
config/src/main/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurer.java

@@ -135,6 +135,7 @@ import java.util.Map;
 public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> extends
 	AbstractAuthenticationFilterConfigurer<B, OAuth2LoginConfigurer<B>, OAuth2LoginAuthenticationFilter> {
 
+	private final ApplicationContext context;
 	private final AuthorizationEndpointConfig authorizationEndpointConfig = new AuthorizationEndpointConfig();
 	private final TokenEndpointConfig tokenEndpointConfig = new TokenEndpointConfig();
 	private final RedirectionEndpointConfig redirectionEndpointConfig = new RedirectionEndpointConfig();
@@ -142,6 +143,11 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 	private String loginPage;
 	private String loginProcessingUrl = OAuth2LoginAuthenticationFilter.DEFAULT_FILTER_PROCESSES_URI;
 
+	public OAuth2LoginConfigurer(ApplicationContext context) {
+		Assert.notNull(context, "context cannot be null");
+		this.context = context;
+	}
+
 	/**
 	 * Sets the repository of client registrations.
 	 *
@@ -506,18 +512,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 			accessTokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient();
 		}
 
-		OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService = this.userInfoEndpointConfig.userService;
-		if (oauth2UserService == null) {
-			if (!this.userInfoEndpointConfig.customUserTypes.isEmpty()) {
-				List<OAuth2UserService<OAuth2UserRequest, OAuth2User>> userServices = new ArrayList<>();
-				userServices.add(new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes));
-				userServices.add(new DefaultOAuth2UserService());
-				oauth2UserService = new DelegatingOAuth2UserService<>(userServices);
-			} else {
-				oauth2UserService = new DefaultOAuth2UserService();
-			}
-		}
-
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService = getOAuth2UserService();
 		OAuth2LoginAuthenticationProvider oauth2LoginAuthenticationProvider =
 			new OAuth2LoginAuthenticationProvider(accessTokenResponseClient, oauth2UserService);
 		GrantedAuthoritiesMapper userAuthoritiesMapper = this.getGrantedAuthoritiesMapper();
@@ -530,11 +525,7 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 			"org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader());
 
 		if (oidcAuthenticationProviderEnabled) {
-			OAuth2UserService<OidcUserRequest, OidcUser> oidcUserService = this.userInfoEndpointConfig.oidcUserService;
-			if (oidcUserService == null) {
-				oidcUserService = new OidcUserService();
-			}
-
+			OAuth2UserService<OidcUserRequest, OidcUser> oidcUserService = getOidcUserService();
 			OidcAuthorizationCodeAuthenticationProvider oidcAuthorizationCodeAuthenticationProvider =
 				new OidcAuthorizationCodeAuthenticationProvider(accessTokenResponseClient, oidcUserService);
 			JwtDecoderFactory<ClientRegistration> jwtDecoderFactory = this.getJwtDecoderFactoryBean();
@@ -627,6 +618,47 @@ public final class OAuth2LoginConfigurer<B extends HttpSecurityBuilder<B>> exten
 		return (!grantedAuthoritiesMapperMap.isEmpty() ? grantedAuthoritiesMapperMap.values().iterator().next() : null);
 	}
 
+	private OAuth2UserService<OidcUserRequest, OidcUser> getOidcUserService() {
+		if (this.userInfoEndpointConfig.oidcUserService != null) {
+			return this.userInfoEndpointConfig.oidcUserService;
+		}
+		ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OidcUserRequest.class, OidcUser.class);
+		OAuth2UserService<OidcUserRequest, OidcUser> bean = getBeanOrNull(type);
+		if (bean == null) {
+			return new OidcUserService();
+		}
+
+		return bean;
+	}
+
+	private OAuth2UserService<OAuth2UserRequest, OAuth2User> getOAuth2UserService() {
+		if (this.userInfoEndpointConfig.userService != null) {
+			return this.userInfoEndpointConfig.userService;
+		}
+		ResolvableType type = ResolvableType.forClassWithGenerics(OAuth2UserService.class, OAuth2UserRequest.class, OAuth2User.class);
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> bean = getBeanOrNull(type);
+		if (bean == null) {
+			if (!this.userInfoEndpointConfig.customUserTypes.isEmpty()) {
+				List<OAuth2UserService<OAuth2UserRequest, OAuth2User>> userServices = new ArrayList<>();
+				userServices.add(new CustomUserTypesOAuth2UserService(this.userInfoEndpointConfig.customUserTypes));
+				userServices.add(new DefaultOAuth2UserService());
+				return new DelegatingOAuth2UserService<>(userServices);
+			} else {
+				return new DefaultOAuth2UserService();
+			}
+		}
+
+		return bean;
+	}
+
+	private <T> T getBeanOrNull(ResolvableType type) {
+		String[] names =  this.context.getBeanNamesForType(type);
+		if (names.length == 1) {
+			return (T) this.context.getBean(names[0]);
+		}
+		return null;
+	}
+
 	private void initDefaultLoginFilter(B http) {
 		DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http.getSharedObject(DefaultLoginPageGeneratingFilter.class);
 		if (loginPageGeneratingFilter == null || this.isCustomLoginPage()) {

+ 73 - 0
config/src/test/java/org/springframework/security/config/annotation/web/configurers/oauth2/client/OAuth2LoginConfigurerTests.java

@@ -271,6 +271,32 @@ public class OAuth2LoginConfigurerTests {
 		assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OAUTH2_USER");
 	}
 
+	@Test
+	public void oauth2LoginCustomWithUserServiceBeanRegistration() throws Exception {
+		// setup application context
+		loadConfig(OAuth2LoginConfigCustomUserServiceBeanRegistration.class);
+
+		// setup authorization request
+		OAuth2AuthorizationRequest authorizationRequest = createOAuth2AuthorizationRequest();
+		this.authorizationRequestRepository.saveAuthorizationRequest(
+				authorizationRequest, this.request, this.response);
+
+		// setup authentication parameters
+		this.request.setParameter("code", "code123");
+		this.request.setParameter("state", authorizationRequest.getState());
+
+		// perform test
+		this.springSecurityFilterChain.doFilter(this.request, this.response, this.filterChain);
+
+		// assertions
+		Authentication authentication = this.securityContextRepository
+				.loadContext(new HttpRequestResponseHolder(this.request, this.response))
+				.getAuthentication();
+		assertThat(authentication.getAuthorities()).hasSize(2);
+		assertThat(authentication.getAuthorities()).first().hasToString("ROLE_USER");
+		assertThat(authentication.getAuthorities()).last().hasToString("ROLE_OAUTH2_USER");
+	}
+
 	// gh-5488
 	@Test
 	public void oauth2LoginConfigLoginProcessingUrl() throws Exception {
@@ -660,6 +686,53 @@ public class OAuth2LoginConfigurerTests {
 		}
 	}
 
+	@EnableWebSecurity
+	static class OAuth2LoginConfigCustomUserServiceBeanRegistration extends WebSecurityConfigurerAdapter {
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			http
+				.authorizeRequests()
+					.anyRequest().authenticated()
+					.and()
+				.securityContext()
+					.securityContextRepository(securityContextRepository())
+					.and()
+				.oauth2Login()
+					.tokenEndpoint()
+						.accessTokenResponseClient(createOauth2AccessTokenResponseClient());
+		}
+
+		@Bean
+		ClientRegistrationRepository clientRegistrationRepository() {
+			return new InMemoryClientRegistrationRepository(GOOGLE_CLIENT_REGISTRATION);
+		}
+
+		@Bean
+		GrantedAuthoritiesMapper grantedAuthoritiesMapper() {
+			return createGrantedAuthoritiesMapper();
+		}
+
+		@Bean
+		SecurityContextRepository securityContextRepository() {
+			return new HttpSessionSecurityContextRepository();
+		}
+
+		@Bean
+		HttpSessionOAuth2AuthorizationRequestRepository oauth2AuthorizationRequestRepository() {
+			return new HttpSessionOAuth2AuthorizationRequestRepository();
+		}
+
+		@Bean
+		OAuth2UserService<OAuth2UserRequest, OAuth2User> oauth2UserService() {
+			return createOauth2UserService();
+		}
+
+		@Bean
+		OAuth2UserService<OidcUserRequest, OidcUser> oidcUserService() {
+			return createOidcUserService();
+		}
+	}
+
 	@EnableWebSecurity
 	static class OAuth2LoginConfigLoginProcessingUrl extends CommonWebSecurityConfigurerAdapter {
 		@Override