瀏覽代碼

Reactive OAuth2 DSL Customizations

Fixes: gh-5855
Rob Winch 7 年之前
父節點
當前提交
72301e548a

+ 134 - 14
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -29,6 +29,11 @@ import java.util.List;
 import java.util.Map;
 
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.oauth2.client.oidc.userinfo.OidcUserRequest;
+import org.springframework.security.oauth2.client.userinfo.OAuth2UserRequest;
+import org.springframework.security.oauth2.core.oidc.user.OidcUser;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
 import reactor.core.publisher.Mono;
 import reactor.util.context.Context;
 
@@ -512,6 +517,64 @@ public class ServerHttpSecurity {
 
 		private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 
+		private ReactiveAuthenticationManager authenticationManager;
+
+		private ServerAuthenticationConverter authenticationConverter;
+
+		/**
+		 * Configures the {@link ReactiveAuthenticationManager} to use. The default is
+		 * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager}
+		 * @param authenticationManager the manager to use
+		 * @return the {@link OAuth2LoginSpec} to customize
+		 */
+		public OAuth2LoginSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) {
+			this.authenticationManager = authenticationManager;
+			return this;
+		}
+
+		/**
+		 * Gets the {@link ReactiveAuthenticationManager} to use. First tries an explicitly configured manager, and
+		 * defaults to {@link OAuth2AuthorizationCodeReactiveAuthenticationManager}
+		 *
+		 * @return the {@link ReactiveAuthenticationManager} to use
+		 */
+		private ReactiveAuthenticationManager getAuthenticationManager() {
+			if (this.authenticationManager == null) {
+				this.authenticationManager = createDefault();
+			}
+			return this.authenticationManager;
+		}
+
+		private ReactiveAuthenticationManager createDefault() {
+			WebClientReactiveAuthorizationCodeTokenResponseClient client = new WebClientReactiveAuthorizationCodeTokenResponseClient();
+			ReactiveAuthenticationManager result = new OAuth2LoginReactiveAuthenticationManager(client, getOauth2UserService());
+
+			boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent(
+					"org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader());
+			if (oidcAuthenticationProviderEnabled) {
+				OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, getOidcUserService());
+				result = new DelegatingReactiveAuthenticationManager(oidc, result);
+			}
+			return result;
+		}
+
+		/**
+		 * Sets the converter to use
+		 * @param authenticationConverter the converter to use
+		 * @return the {@link OAuth2LoginSpec} to customize
+		 */
+		public OAuth2LoginSpec authenticationConverter(ServerAuthenticationConverter authenticationConverter) {
+			this.authenticationConverter = authenticationConverter;
+			return this;
+		}
+
+		private ServerAuthenticationConverter getAuthenticationConverter(ReactiveClientRegistrationRepository clientRegistrationRepository) {
+			if (this.authenticationConverter == null) {
+				this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository);
+			}
+			return this.authenticationConverter;
+		}
+
 		public OAuth2LoginSpec clientRegistrationRepository(ReactiveClientRegistrationRepository clientRegistrationRepository) {
 			this.clientRegistrationRepository = clientRegistrationRepository;
 			return this;
@@ -541,21 +604,11 @@ public class ServerHttpSecurity {
 			ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
 			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(clientRegistrationRepository);
 
-			WebClientReactiveAuthorizationCodeTokenResponseClient client = new WebClientReactiveAuthorizationCodeTokenResponseClient();
-			ReactiveOAuth2UserService userService = new DefaultReactiveOAuth2UserService();
-			ReactiveAuthenticationManager manager = new OAuth2LoginReactiveAuthenticationManager(client, userService);
-
-			boolean oidcAuthenticationProviderEnabled = ClassUtils.isPresent(
-					"org.springframework.security.oauth2.jwt.JwtDecoder", this.getClass().getClassLoader());
-			if (oidcAuthenticationProviderEnabled) {
-				OidcAuthorizationCodeReactiveAuthenticationManager oidc = new OidcAuthorizationCodeReactiveAuthenticationManager(client, new OidcReactiveOAuth2UserService());
-				manager = new DelegatingReactiveAuthenticationManager(oidc, manager);
-			}
+			ReactiveAuthenticationManager manager = getAuthenticationManager();
 
 			AuthenticationWebFilter authenticationFilter = new OAuth2LoginAuthenticationWebFilter(manager, authorizedClientRepository);
 			authenticationFilter.setRequiresAuthenticationMatcher(createAttemptAuthenticationRequestMatcher());
-			authenticationFilter.setServerAuthenticationConverter(new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(clientRegistrationRepository));
-
+			authenticationFilter.setServerAuthenticationConverter(getAuthenticationConverter(clientRegistrationRepository));
 			RedirectServerAuthenticationSuccessHandler redirectHandler = new RedirectServerAuthenticationSuccessHandler();
 
 			authenticationFilter.setAuthenticationSuccessHandler(redirectHandler);
@@ -589,6 +642,27 @@ public class ServerHttpSecurity {
 					.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.match());
 			return new AndServerWebExchangeMatcher(loginPathMatcher, notAuthenticatedMatcher);
 		}
+
+		private ReactiveOAuth2UserService<OidcUserRequest, OidcUser> getOidcUserService() {
+			ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2UserService.class, OidcUserRequest.class, OidcUser.class);
+			ReactiveOAuth2UserService<OidcUserRequest, OidcUser> bean = getBeanOrNull(type);
+			if (bean == null) {
+				return new OidcReactiveOAuth2UserService();
+			}
+
+			return bean;
+		}
+
+		private ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> getOauth2UserService() {
+			ResolvableType type = ResolvableType.forClassWithGenerics(ReactiveOAuth2UserService.class, OAuth2UserRequest.class, OAuth2User.class);
+			ReactiveOAuth2UserService<OAuth2UserRequest, OAuth2User> bean = getBeanOrNull(type);
+			if (bean == null) {
+				return new DefaultReactiveOAuth2UserService();
+			}
+
+			return bean;
+		}
+
 		private Map<String, String> getLinks() {
 			Iterable<ClientRegistration> registrations = getBeanOrNull(ResolvableType.forClassWithGenerics(Iterable.class, ClientRegistration.class));
 			if (registrations == null) {
@@ -662,8 +736,53 @@ public class ServerHttpSecurity {
 	public class OAuth2ClientSpec {
 		private ReactiveClientRegistrationRepository clientRegistrationRepository;
 
+		private ServerAuthenticationConverter authenticationConverter;
+
 		private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
 
+		private ReactiveAuthenticationManager authenticationManager;
+
+		/**
+		 * Sets the converter to use
+		 * @param authenticationConverter the converter to use
+		 * @return the {@link OAuth2ClientSpec} to customize
+		 */
+		public OAuth2ClientSpec authenticationConverter(ServerAuthenticationConverter authenticationConverter) {
+			this.authenticationConverter = authenticationConverter;
+			return this;
+		}
+
+		private ServerAuthenticationConverter getAuthenticationConverter() {
+			if (this.authenticationConverter == null) {
+				this.authenticationConverter = new ServerOAuth2AuthorizationCodeAuthenticationTokenConverter(getClientRegistrationRepository());
+			}
+			return this.authenticationConverter;
+		}
+
+		/**
+		 * Configures the {@link ReactiveAuthenticationManager} to use. The default is
+		 * {@link OAuth2AuthorizationCodeReactiveAuthenticationManager}
+		 * @param authenticationManager the manager to use
+		 * @return the {@link OAuth2ClientSpec} to customize
+		 */
+		public OAuth2ClientSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) {
+			this.authenticationManager = authenticationManager;
+			return this;
+		}
+
+		/**
+		 * Gets the {@link ReactiveAuthenticationManager} to use. First tries an explicitly configured manager, and
+		 * defaults to {@link OAuth2AuthorizationCodeReactiveAuthenticationManager}
+		 *
+		 * @return the {@link ReactiveAuthenticationManager} to use
+		 */
+		private ReactiveAuthenticationManager getAuthenticationManager() {
+			if (this.authenticationManager == null) {
+				this.authenticationManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager(new WebClientReactiveAuthorizationCodeTokenResponseClient());
+			}
+			return this.authenticationManager;
+		}
+
 		/**
 		 * Configures the {@link ReactiveClientRegistrationRepository}. Default is to look the value up as a Bean.
 		 * @param clientRegistrationRepository the repository to use
@@ -687,9 +806,10 @@ public class ServerHttpSecurity {
 		protected void configure(ServerHttpSecurity http) {
 			ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
 			ServerOAuth2AuthorizedClientRepository authorizedClientRepository = getAuthorizedClientRepository();
-			ReactiveAuthenticationManager authenticationManager = new OAuth2AuthorizationCodeReactiveAuthenticationManager(new WebClientReactiveAuthorizationCodeTokenResponseClient());
+			ServerAuthenticationConverter authenticationConverter = getAuthenticationConverter();
+			ReactiveAuthenticationManager authenticationManager = getAuthenticationManager();
 			OAuth2AuthorizationCodeGrantWebFilter codeGrantWebFilter = new OAuth2AuthorizationCodeGrantWebFilter(authenticationManager,
-					clientRegistrationRepository,
+					authenticationConverter,
 					authorizedClientRepository);
 
 			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(

+ 68 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2ClientSpecTests.java

@@ -22,16 +22,27 @@ import org.junit.runner.RunWith;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.security.authentication.ReactiveAuthenticationManager;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
 import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
+import org.springframework.security.oauth2.client.authentication.OAuth2AuthorizationCodeAuthenticationToken;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
 import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
 import org.springframework.security.test.context.annotation.SecurityTestExecutionListeners;
 import org.springframework.security.test.context.support.WithMockUser;
 import org.springframework.security.web.server.SecurityWebFilterChain;
+import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
 import org.springframework.test.context.junit4.SpringRunner;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
@@ -41,6 +52,7 @@ import reactor.core.publisher.Mono;
 
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
 /**
@@ -55,6 +67,8 @@ public class OAuth2ClientSpecTests {
 
 	private WebTestClient client;
 
+	private ClientRegistration registration = TestClientRegistrations.clientRegistration().build();
+
 	@Autowired
 	public void setApplicationContext(ApplicationContext context) {
 		this.client = WebTestClient.bindToApplicationContext(context).build();
@@ -117,4 +131,58 @@ public class OAuth2ClientSpecTests {
 			return "home";
 		}
 	}
+
+	@Test
+	public void oauth2ClientWhenCustomObjectsThenUsed() {
+		this.spring.register(ClientRegistrationConfig.class, OAuth2ClientCustomConfig.class, AuthorizedClientController.class).autowire();
+
+		OAuth2ClientCustomConfig config = this.spring.getContext().getBean(OAuth2ClientCustomConfig.class);
+
+		ServerAuthenticationConverter converter = config.authenticationConverter;
+		ReactiveAuthenticationManager manager = config.manager;
+
+		OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success();
+		OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes();
+
+		OAuth2AuthorizationCodeAuthenticationToken result = new OAuth2AuthorizationCodeAuthenticationToken(this.registration, exchange, accessToken);
+
+		when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
+		when(manager.authenticate(any())).thenReturn(Mono.just(result));
+
+		this.client.get()
+			.uri("/authorize/oauth2/code/registration-id")
+			.exchange()
+			.expectStatus().is3xxRedirection();
+
+		verify(converter).convert(any());
+		verify(manager).authenticate(any());
+	}
+
+	@EnableWebFlux
+	@EnableWebFluxSecurity
+	static class ClientRegistrationConfig {
+		private ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
+				.build();
+
+		@Bean
+		InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
+			return new InMemoryReactiveClientRegistrationRepository(this.clientRegistration);
+		}
+	}
+
+	@Configuration
+	static class OAuth2ClientCustomConfig {
+		ReactiveAuthenticationManager manager = mock(ReactiveAuthenticationManager.class);
+
+		ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
+
+		@Bean
+		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
+			http
+				.oauth2Client()
+					.authenticationConverter(this.authenticationConverter)
+					.authenticationManager(this.manager);
+			return http.build();
+		}
+	}
 }

+ 67 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -17,20 +17,36 @@
 package org.springframework.security.config.web.server;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import org.junit.Rule;
 import org.junit.Test;
 import org.openqa.selenium.WebDriver;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.security.authentication.ReactiveAuthenticationManager;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
 import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
+import org.springframework.security.oauth2.client.authentication.OAuth2LoginAuthenticationToken;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
+import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
+import org.springframework.security.oauth2.core.user.OAuth2User;
+import org.springframework.security.oauth2.core.user.TestOAuth2Users;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
+import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
+import org.springframework.security.web.server.authentication.ServerAuthenticationConverter;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
@@ -115,6 +131,57 @@ public class OAuth2LoginTests {
 		}
 	}
 
+	@Test
+	public void oauth2LoginWhenCustomObjectsThenUsed() {
+		this.spring.register(OAuth2LoginWithSingleClientRegistrations.class,
+				OAuth2LoginMockAuthenticationManagerConfig.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(this.springSecurity)
+				.build();
+
+		OAuth2LoginMockAuthenticationManagerConfig config = this.spring.getContext()
+				.getBean(OAuth2LoginMockAuthenticationManagerConfig.class);
+		ServerAuthenticationConverter converter = config.authenticationConverter;
+		ReactiveAuthenticationManager manager = config.manager;
+
+		OAuth2AuthorizationExchange exchange = TestOAuth2AuthorizationExchanges.success();
+		OAuth2User user = TestOAuth2Users.create();
+		OAuth2AccessToken accessToken = TestOAuth2AccessTokens.noScopes();
+
+		OAuth2LoginAuthenticationToken result = new OAuth2LoginAuthenticationToken(github, exchange, user, user.getAuthorities(), accessToken);
+
+		when(converter.convert(any())).thenReturn(Mono.just(new TestingAuthenticationToken("a", "b", "c")));
+		when(manager.authenticate(any())).thenReturn(Mono.just(result));
+
+		webTestClient.get()
+			.uri("/login/oauth2/code/github")
+			.exchange()
+			.expectStatus().is3xxRedirection();
+
+		verify(converter).convert(any());
+		verify(manager).authenticate(any());
+	}
+
+	@Configuration
+	static class OAuth2LoginMockAuthenticationManagerConfig {
+		ReactiveAuthenticationManager manager = mock(ReactiveAuthenticationManager.class);
+
+		ServerAuthenticationConverter authenticationConverter = mock(ServerAuthenticationConverter.class);
+
+		@Bean
+		public SecurityWebFilterChain springSecurityFilter(ServerHttpSecurity http) {
+			http
+				.authorizeExchange()
+					.anyExchange().authenticated()
+					.and()
+				.oauth2Login()
+					.authenticationConverter(authenticationConverter)
+					.authenticationManager(manager);
+			return http.build();
+		}
+	}
+
 	static class GitHubWebFilter implements WebFilter {
 
 		@Override

+ 38 - 0
oauth2/oauth2-core/src/test/java/org/springframework/security/oauth2/core/user/TestOAuth2Users.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.core.user;
+
+import org.springframework.security.core.GrantedAuthority;
+import org.springframework.security.core.authority.AuthorityUtils;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author Rob Winch
+ */
+public class TestOAuth2Users {
+
+	public static DefaultOAuth2User create() {
+		List<GrantedAuthority> roles = AuthorityUtils.createAuthorityList("ROLE_USER");
+		String attrName = "username";
+		Map<String, Object> attributes = new HashMap<>();
+		attributes.put(attrName, "user");
+		return new DefaultOAuth2User(roles, attributes, attrName);
+	}
+}