فهرست منبع

Single ClientRegistration redirects by default

Fixes: gh-5339
Rob Winch 7 سال پیش
والد
کامیت
32e368d9b7

+ 47 - 15
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -183,6 +183,8 @@ public class ServerHttpSecurity {
 
 	private LogoutSpec logout = new LogoutSpec();
 
+	private LoginPageSpec loginPage = new LoginPageSpec();
+
 	private ReactiveAuthenticationManager authenticationManager;
 
 	private ServerSecurityContextRepository securityContextRepository = new WebSessionServerSecurityContextRepository();
@@ -387,6 +389,16 @@ public class ServerHttpSecurity {
 			});
 			authenticationFilter.setSecurityContextRepository(new WebSessionServerSecurityContextRepository());
 
+			MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher(
+					MediaType.TEXT_HTML);
+			htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL));
+			Map<String, String> urlToText = http.oauth2Login.getLinks();
+			if (urlToText.size() == 1) {
+				http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint(urlToText.keySet().iterator().next())));
+			} else {
+				http.defaultEntryPoints.add(new DelegateEntry(htmlMatcher, new RedirectServerAuthenticationEntryPoint("/login")));
+			}
+
 			http.addFilterAt(oauthRedirectFilter, SecurityWebFiltersOrder.HTTP_BASIC);
 			http.addFilterAt(authenticationFilter, SecurityWebFiltersOrder.AUTHENTICATION);
 		}
@@ -610,31 +622,17 @@ public class ServerHttpSecurity {
 			this.httpBasic.authenticationManager(this.authenticationManager);
 			this.httpBasic.configure(this);
 		}
-		LoginPageGeneratingWebFilter loginPageFilter = null;
 		if(this.formLogin != null) {
 			this.formLogin.authenticationManager(this.authenticationManager);
 			if(this.securityContextRepository != null) {
 				this.formLogin.securityContextRepository(this.securityContextRepository);
 			}
-			if (this.authenticationEntryPoint == null) {
-				loginPageFilter = new LoginPageGeneratingWebFilter();
-				loginPageFilter.setFormLoginEnabled(true);
-				this.authenticationEntryPoint = this.formLogin.authenticationEntryPoint;
-			}
 			this.formLogin.configure(this);
 		}
 		if (this.oauth2Login != null) {
-			if (this.authenticationEntryPoint == null) {
-				loginPageFilter = new LoginPageGeneratingWebFilter();
-				loginPageFilter.setOauth2AuthenticationUrlToClientName(this.oauth2Login.getLinks());
-			}
 			this.oauth2Login.configure(this);
 		}
-		if (loginPageFilter != null) {
-			this.authenticationEntryPoint = new RedirectServerAuthenticationEntryPoint("/login");
-			this.webFilters.add(new OrderedWebFilter(loginPageFilter, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder()));
-			this.webFilters.add(new OrderedWebFilter(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING.getOrder()));
-		}
+		this.loginPage.configure(this);
 		if(this.logout != null) {
 			this.logout.configure(this);
 		}
@@ -1084,6 +1082,8 @@ public class ServerHttpSecurity {
 
 		private ServerAuthenticationEntryPoint authenticationEntryPoint;
 
+		private boolean isEntryPointExplicit;
+
 		private ServerWebExchangeMatcher requiresAuthenticationMatcher;
 
 		private ServerAuthenticationFailureHandler authenticationFailureHandler;
@@ -1206,7 +1206,10 @@ public class ServerHttpSecurity {
 
 		protected void configure(ServerHttpSecurity http) {
 			if(this.authenticationEntryPoint == null) {
+				this.isEntryPointExplicit = false;
 				loginPage("/login");
+			} else {
+				this.isEntryPointExplicit = true;
 			}
 			if(http.requestCache != null) {
 				ServerRequestCache requestCache = http.requestCache.requestCache;
@@ -1233,6 +1236,35 @@ public class ServerHttpSecurity {
 		}
 	}
 
+	private class LoginPageSpec {
+		protected void configure(ServerHttpSecurity http) {
+			if (http.authenticationEntryPoint != null) {
+				return;
+			}
+			if (http.formLogin != null && http.formLogin.isEntryPointExplicit) {
+				return;
+			}
+			LoginPageGeneratingWebFilter loginPage = null;
+			if (http.formLogin != null && !http.formLogin.isEntryPointExplicit) {
+				loginPage = new LoginPageGeneratingWebFilter();
+				loginPage.setFormLoginEnabled(true);
+			}
+			if (http.oauth2Login != null) {
+				Map<String, String> urlToText = http.oauth2Login.getLinks();
+				if (loginPage == null) {
+					loginPage = new LoginPageGeneratingWebFilter();
+				}
+				loginPage.setOauth2AuthenticationUrlToClientName(urlToText);
+			}
+			if (loginPage != null) {
+				http.addFilterAt(loginPage, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING);
+				http.addFilterAt(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING);
+			}
+		}
+
+		private LoginPageSpec() {}
+	}
+
 	/**
 	 * Configures HTTP Response Headers.
 	 *

+ 39 - 17
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -21,29 +21,20 @@ import static org.assertj.core.api.Assertions.assertThat;
 import org.junit.Rule;
 import org.junit.Test;
 import org.openqa.selenium.WebDriver;
-import org.openqa.selenium.WebElement;
-import org.openqa.selenium.support.FindBy;
-import org.openqa.selenium.support.PageFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
-import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
 import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
-import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
-import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
-import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
-import org.springframework.security.web.server.csrf.CsrfToken;
-import org.springframework.stereotype.Controller;
 import org.springframework.test.web.reactive.server.WebTestClient;
-import org.springframework.web.bind.annotation.GetMapping;
-import org.springframework.web.bind.annotation.ResponseBody;
 import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilter;
+import org.springframework.web.server.WebFilterChain;
 
 import reactor.core.publisher.Mono;
 
@@ -59,7 +50,7 @@ public class OAuth2LoginTests {
 	@Autowired
 	private WebFilterChainProxy springSecurity;
 
-	private ClientRegistration github = CommonOAuth2Provider.GITHUB
+	private static ClientRegistration github = CommonOAuth2Provider.GITHUB
 			.getBuilder("github")
 			.clientId("client")
 			.clientSecret("secret")
@@ -90,11 +81,6 @@ public class OAuth2LoginTests {
 	static class OAuth2LoginWithMulitpleClientRegistrations {
 		@Bean
 		InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
-			ClientRegistration github = CommonOAuth2Provider.GITHUB
-					.getBuilder("github")
-					.clientId("client")
-					.clientSecret("secret")
-					.build();
 			ClientRegistration google = CommonOAuth2Provider.GOOGLE
 					.getBuilder("google")
 					.clientId("client")
@@ -103,4 +89,40 @@ public class OAuth2LoginTests {
 			return new InMemoryReactiveClientRegistrationRepository(github, google);
 		}
 	}
+
+	@Test
+	public void defaultLoginPageWithSingleClientRegistrationThenRedirect() {
+		this.spring.register(OAuth2LoginWithSingleClientRegistrations.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(new GitHubWebFilter(), this.springSecurity)
+				.build();
+
+		WebDriver driver = WebTestClientHtmlUnitDriverBuilder
+				.webTestClientSetup(webTestClient)
+				.build();
+
+		driver.get("http://localhost/");
+
+		assertThat(driver.getCurrentUrl()).startsWith("https://github.com/login/oauth/authorize");
+	}
+
+	@EnableWebFluxSecurity
+	static class OAuth2LoginWithSingleClientRegistrations {
+		@Bean
+		InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
+			return new InMemoryReactiveClientRegistrationRepository(github);
+		}
+	}
+
+	static class GitHubWebFilter implements WebFilter {
+
+		@Override
+		public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
+			if (exchange.getRequest().getURI().getHost().equals("github.com")) {
+				return exchange.getResponse().setComplete();
+			}
+			return chain.filter(exchange);
+		}
+	}
 }

+ 7 - 1
config/src/test/java/org/springframework/security/htmlunit/server/HtmlUnitWebTestClient.java

@@ -151,8 +151,14 @@ final class HtmlUnitWebTestClient {
 
 		private Mono<ClientResponse> redirectIfNecessary(ClientRequest request, ExchangeFunction next, ClientResponse response) {
 			URI location = response.headers().asHttpHeaders().getLocation();
+			String host = request.url().getHost();
+			String scheme = request.url().getScheme();
 			if(location != null) {
-				ClientRequest redirect = ClientRequest.method(HttpMethod.GET, URI.create("http://localhost" + location.toASCIIString()))
+				String redirectUrl = location.toASCIIString();
+				if (location.getHost() == null) {
+					redirectUrl = scheme+ "://" + host + location.toASCIIString();
+				}
+				ClientRequest redirect = ClientRequest.method(HttpMethod.GET, URI.create(redirectUrl))
 					.headers(headers -> headers.addAll(request.headers()))
 					.cookies(cookies -> cookies.addAll(request.cookies()))
 					.attributes(attributes -> attributes.putAll(request.attributes()))