Selaa lähdekoodia

LoginPageGeneratingWebFilter conditionally renders formLogin

Issue: gh-4807
Rob Winch 7 vuotta sitten
vanhempi
commit
f29e4cf91f

+ 1 - 0
config/spring-security-config.gradle

@@ -35,6 +35,7 @@ dependencies {
 	testCompile powerMock2Dependencies
 	testCompile spockDependencies
 	testCompile 'ch.qos.logback:logback-classic'
+	testCompile 'io.projectreactor.ipc:reactor-netty'
 	testCompile 'javax.annotation:jsr250-api:1.0'
 	testCompile 'javax.xml.bind:jaxb-api'
 	testCompile 'ldapsdk:ldapsdk:4.1'

+ 11 - 8
config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java

@@ -87,13 +87,14 @@ class WebFluxSecurityConfiguration {
 	private SecurityWebFilterChain springSecurityFilterChain(ServerHttpSecurity http) {
 		http
 			.authorizeExchange()
-				.anyExchange().authenticated()
-				.and()
-			.httpBasic().and()
-			.formLogin();
+				.anyExchange().authenticated();
 
-		if (isOAuth2Present) {
+		if (isOAuth2Present && OAuth2ClasspathGuard.shouldConfigure(this.context)) {
 			OAuth2ClasspathGuard.configure(this.context, http);
+		} else {
+			http
+				.httpBasic().and()
+				.formLogin();
 		}
 
 		SecurityWebFilterChain result = http.build();
@@ -102,11 +103,13 @@ class WebFluxSecurityConfiguration {
 
 	private static class OAuth2ClasspathGuard {
 		static void configure(ApplicationContext context, ServerHttpSecurity http) {
+			http.oauth2Login();
+		}
+
+		static boolean shouldConfigure(ApplicationContext context) {
 			ClassLoader loader = context.getClassLoader();
 			Class<?> reactiveClientRegistrationRepositoryClass = ClassUtils.resolveClassName(REACTIVE_CLIENT_REGISTRATION_REPOSITORY_CLASSNAME, loader);
-			if (context.getBeanNamesForType(reactiveClientRegistrationRepositoryClass).length == 1) {
-				http.oauth2Login();
-			}
+			return context.getBeanNamesForType(reactiveClientRegistrationRepositoryClass).length == 1;
 		}
 	}
 }

+ 20 - 11
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -41,6 +41,7 @@ import org.springframework.security.authorization.AuthorityReactiveAuthorization
 import org.springframework.security.authorization.AuthorizationDecision;
 import org.springframework.security.authorization.ReactiveAuthorizationManager;
 import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.oauth2.client.InMemoryReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
 import org.springframework.security.oauth2.client.authentication.OAuth2LoginReactiveAuthenticationManager;
 import org.springframework.security.oauth2.client.endpoint.NimbusReactiveAuthorizationCodeTokenResponseClient;
@@ -361,11 +362,7 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
-		protected void configure(LoginPageGeneratingWebFilter loginPageFilter, ServerHttpSecurity http) {
-			if (loginPageFilter != null) {
-				loginPageFilter.setOauth2AuthenticationUrlToClientName(getLinks());
-			}
-
+		protected void configure(ServerHttpSecurity http) {
 			ReactiveClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository();
 			ReactiveOAuth2AuthorizedClientService authorizedClientService = getAuthorizedClientService();
 			OAuth2AuthorizationRequestRedirectWebFilter oauthRedirectFilter = new OAuth2AuthorizationRequestRedirectWebFilter(clientRegistrationRepository);
@@ -417,6 +414,9 @@ public class ServerHttpSecurity {
 			if (this.authorizedClientService == null) {
 				this.authorizedClientService = getBeanOrNull(ReactiveOAuth2AuthorizedClientService.class);
 			}
+			if (this.authorizedClientService == null) {
+				this.authorizedClientService = new InMemoryReactiveOAuth2AuthorizedClientService(getClientRegistrationRepository());
+			}
 			return this.authorizedClientService;
 		}
 
@@ -616,15 +616,24 @@ public class ServerHttpSecurity {
 			if(this.securityContextRepository != null) {
 				this.formLogin.securityContextRepository(this.securityContextRepository);
 			}
-			if(this.formLogin.authenticationEntryPoint == null) {
+			if (this.authenticationEntryPoint == null) {
 				loginPageFilter = new LoginPageGeneratingWebFilter();
-				this.webFilters.add(new OrderedWebFilter(loginPageFilter, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder()));
-				this.webFilters.add(new OrderedWebFilter(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING.getOrder()));
+				loginPageFilter.setFormLoginEnabled(true);
+				this.authenticationEntryPoint = this.formLogin.authenticationEntryPoint;
 			}
 			this.formLogin.configure(this);
 		}
 		if (this.oauth2Login != null) {
-			this.oauth2Login.configure(loginPageFilter, this);
+			if (this.authenticationEntryPoint == null) {
+				loginPageFilter = new LoginPageGeneratingWebFilter();
+				loginPageFilter.setOauth2AuthenticationUrlToClientName(this.oauth2Login.getLinks());
+			}
+			this.oauth2Login.configure(this);
+		}
+		if (loginPageFilter != null) {
+			this.authenticationEntryPoint = new RedirectServerAuthenticationEntryPoint("/login");
+			this.webFilters.add(new OrderedWebFilter(loginPageFilter, SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder()));
+			this.webFilters.add(new OrderedWebFilter(new LogoutPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGOUT_PAGE_GENERATING.getOrder()));
 		}
 		if(this.logout != null) {
 			this.logout.configure(this);
@@ -638,8 +647,8 @@ public class ServerHttpSecurity {
 				exceptionTranslationWebFilter.setAuthenticationEntryPoint(
 					authenticationEntryPoint);
 			}
-			if(accessDeniedHandler != null) {
-				exceptionTranslationWebFilter.setAccessDeniedHandler(accessDeniedHandler);
+			if(this.accessDeniedHandler != null) {
+				exceptionTranslationWebFilter.setAccessDeniedHandler(this.accessDeniedHandler);
 			}
 			this.addFilterAt(exceptionTranslationWebFilter, SecurityWebFiltersOrder.EXCEPTION_TRANSLATION);
 			this.authorizeExchange.configure(this);

+ 35 - 1
config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java

@@ -17,6 +17,8 @@
 package org.springframework.security.config.web.server;
 
 import org.junit.Test;
+import org.openqa.selenium.By;
+import org.openqa.selenium.NoSuchElementException;
 import org.openqa.selenium.WebDriver;
 import org.openqa.selenium.WebElement;
 import org.openqa.selenium.support.FindBy;
@@ -36,6 +38,8 @@ import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
 /**
  * @author Rob Winch
@@ -204,9 +208,10 @@ public class FormLoginTests {
 
 		private LoginForm loginForm;
 
+		private OAuth2Login oauth2Login = new OAuth2Login();
+
 		public DefaultLoginPage(WebDriver webDriver) {
 			this.driver = webDriver;
-			this.loginForm = PageFactory.initElements(webDriver, LoginForm.class);
 		}
 
 		static DefaultLoginPage create(WebDriver driver) {
@@ -228,10 +233,23 @@ public class FormLoginTests {
 			return this;
 		}
 
+		public DefaultLoginPage assertLoginFormNotPresent() {
+			assertThatThrownBy(() -> loginForm().username(""))
+					.isInstanceOf(NoSuchElementException.class);
+			return this;
+		}
+
 		public LoginForm loginForm() {
+			if (this.loginForm == null) {
+				this.loginForm = PageFactory.initElements(this.driver, LoginForm.class);
+			}
 			return this.loginForm;
 		}
 
+		public OAuth2Login oauth2Login() {
+			return this.oauth2Login;
+		}
+
 		static DefaultLoginPage to(WebDriver driver) {
 			driver.get("http://localhost/login");
 			return PageFactory.initElements(driver, DefaultLoginPage.class);
@@ -263,6 +281,22 @@ public class FormLoginTests {
 				return PageFactory.initElements(this.driver, page);
 			}
 		}
+
+		public class OAuth2Login {
+			public WebElement findClientRegistrationByName(String clientName) {
+				return DefaultLoginPage.this.driver.findElement(By.linkText(clientName));
+			}
+
+			public OAuth2Login assertClientRegistrationByName(String clientName) {
+				assertThatCode(() -> findClientRegistrationByName(clientName))
+						.doesNotThrowAnyException();
+				return this;
+			}
+
+			public DefaultLoginPage and() {
+				return DefaultLoginPage.this;
+			}
+		}
 	}
 
 	public static class DefaultLogoutPage {

+ 106 - 0
config/src/test/java/org/springframework/security/config/web/server/OAuth2LoginTests.java

@@ -0,0 +1,106 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.web.server;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import org.junit.Rule;
+import org.junit.Test;
+import org.openqa.selenium.WebDriver;
+import org.openqa.selenium.WebElement;
+import org.openqa.selenium.support.FindBy;
+import org.openqa.selenium.support.PageFactory;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
+import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
+import org.springframework.security.config.oauth2.client.CommonOAuth2Provider;
+import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.InMemoryReactiveClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
+import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
+import org.springframework.security.web.server.SecurityWebFilterChain;
+import org.springframework.security.web.server.WebFilterChainProxy;
+import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler;
+import org.springframework.security.web.server.csrf.CsrfToken;
+import org.springframework.stereotype.Controller;
+import org.springframework.test.web.reactive.server.WebTestClient;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.ResponseBody;
+import org.springframework.web.server.ServerWebExchange;
+
+import reactor.core.publisher.Mono;
+
+/**
+ * @author Rob Winch
+ * @since 5.1
+ */
+public class OAuth2LoginTests {
+
+	@Rule
+	public final SpringTestRule spring = new SpringTestRule();
+
+	@Autowired
+	private WebFilterChainProxy springSecurity;
+
+	private ClientRegistration github = CommonOAuth2Provider.GITHUB
+			.getBuilder("github")
+			.clientId("client")
+			.clientSecret("secret")
+			.build();
+
+	@Test
+	public void defaultLoginPageWithMultipleClientRegistrationsThenLinks() {
+		this.spring.register(OAuth2LoginWithMulitpleClientRegistrations.class).autowire();
+
+		WebTestClient webTestClient = WebTestClientBuilder
+				.bindToWebFilters(this.springSecurity)
+				.build();
+
+		WebDriver driver = WebTestClientHtmlUnitDriverBuilder
+				.webTestClientSetup(webTestClient)
+				.build();
+
+		FormLoginTests.DefaultLoginPage loginPage = FormLoginTests.HomePage
+				.to(driver, FormLoginTests.DefaultLoginPage.class)
+				.assertAt()
+				.assertLoginFormNotPresent()
+				.oauth2Login()
+					.assertClientRegistrationByName(this.github.getClientName())
+					.and();
+	}
+
+	@EnableWebFluxSecurity
+	static class OAuth2LoginWithMulitpleClientRegistrations {
+		@Bean
+		InMemoryReactiveClientRegistrationRepository clientRegistrationRepository() {
+			ClientRegistration github = CommonOAuth2Provider.GITHUB
+					.getBuilder("github")
+					.clientId("client")
+					.clientSecret("secret")
+					.build();
+			ClientRegistration google = CommonOAuth2Provider.GOOGLE
+					.getBuilder("google")
+					.clientId("client")
+					.clientSecret("secret")
+					.build();
+			return new InMemoryReactiveClientRegistrationRepository(github, google);
+		}
+	}
+}

+ 42 - 34
web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java

@@ -50,6 +50,12 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
 
 	private Map<String, String> oauth2AuthenticationUrlToClientName = new HashMap<>();
 
+	private boolean formLoginEnabled;
+
+	public void setFormLoginEnabled(boolean enabled) {
+		this.formLoginEnabled = enabled;
+	}
+
 	public void setOauth2AuthenticationUrlToClientName(
 			Map<String, String> oauth2AuthenticationUrlToClientName) {
 		Assert.notNull(oauth2AuthenticationUrlToClientName, "oauth2AuthenticationUrlToClientName cannot be null");
@@ -87,45 +93,47 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
 	private byte[] createPage(ServerWebExchange exchange, String csrfTokenHtmlInput) {
 		MultiValueMap<String, String> queryParams = exchange.getRequest()
 				.getQueryParams();
-		boolean isError = queryParams.containsKey("error");
-		boolean isLogoutSuccess = queryParams.containsKey("logout");
 		String contextPath = exchange.getRequest().getPath().contextPath().value();
-		String page =  "<!DOCTYPE html>\n"
-			+ "<html lang=\"en\">\n"
-			+ "  <head>\n"
-			+ "    <meta charset=\"utf-8\">\n"
-			+ "    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n"
-			+ "    <meta name=\"description\" content=\"\">\n"
-			+ "    <meta name=\"author\" content=\"\">\n"
-			+ "    <title>Please sign in</title>\n"
-			+ "    <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n"
-			+ "    <link href=\"http://getbootstrap.com/docs/4.0/examples/signin/signin.css\" rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n"
-			+ "  </head>\n"
-			+ "  <body>\n"
-			+ "     <div class=\"container\">\n"
-			+ "      <form class=\"form-signin\" method=\"post\" action=\"/login\">\n"
-			+ "        <h2 class=\"form-signin-heading\">Please sign in</h2>\n"
-			+ createError(isError)
-			+ createLogoutSuccess(isLogoutSuccess)
-			+ "        <p>\n"
-			+ "          <label for=\"username\" class=\"sr-only\">Username</label>\n"
-			+ "          <input type=\"text\" id=\"username\" name=\"username\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n"
-			+ "        </p>\n"
-			+ "        <p>\n"
-			+ "          <label for=\"password\" class=\"sr-only\">Password</label>\n"
-			+ "          <input type=\"password\" id=\"password\" name=\"password\" class=\"form-control\" placeholder=\"Password\" required>\n"
-			+ "        </p>\n"
-			+ csrfTokenHtmlInput
-			+ "        <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
-			+ "      </form>\n"
-			+ oauth2LoginLinks(contextPath, this.oauth2AuthenticationUrlToClientName)
-			+ "    </div>\n"
-			+ "  </body>\n"
-			+ "</html>";
+		String page = "<!DOCTYPE html>\n" + "<html lang=\"en\">\n" + "  <head>\n"
+				+ "    <meta charset=\"utf-8\">\n"
+				+ "    <meta name=\"viewport\" content=\"width=device-width, initial-scale=1, shrink-to-fit=no\">\n"
+				+ "    <meta name=\"description\" content=\"\">\n"
+				+ "    <meta name=\"author\" content=\"\">\n"
+				+ "    <title>Please sign in</title>\n"
+				+ "    <link href=\"https://maxcdn.bootstrapcdn.com/bootstrap/4.0.0-beta/css/bootstrap.min.css\" rel=\"stylesheet\" integrity=\"sha384-/Y6pD6FV/Vv2HJnA6t+vslU6fwYXjCFtcEpHbNJ0lyAFsXTsjBbfaDjzALeQsN6M\" crossorigin=\"anonymous\">\n"
+				+ "    <link href=\"http://getbootstrap.com/docs/4.0/examples/signin/signin.css\" rel=\"stylesheet\" crossorigin=\"anonymous\"/>\n"
+				+ "  </head>\n"
+				+ "  <body>\n"
+				+ "     <div class=\"container\">\n"
+				+ formLogin(queryParams, csrfTokenHtmlInput)
+				+ oauth2LoginLinks(contextPath, this.oauth2AuthenticationUrlToClientName)
+				+ "    </div>\n"
+				+ "  </body>\n"
+				+ "</html>";
 
 		return page.getBytes(Charset.defaultCharset());
 	}
 
+	private String formLogin(MultiValueMap<String, String> queryParams, String csrfTokenHtmlInput) {
+		if (!this.formLoginEnabled) {
+			return "";
+		}
+		boolean isError = queryParams.containsKey("error");
+		boolean isLogoutSuccess = queryParams.containsKey("logout");
+		return "      <form class=\"form-signin\" method=\"post\" action=\"/login\">\n"
+				+ "        <h2 class=\"form-signin-heading\">Please sign in</h2>\n"
+				+ createError(isError) + createLogoutSuccess(isLogoutSuccess)
+				+ "        <p>\n"
+				+ "          <label for=\"username\" class=\"sr-only\">Username</label>\n"
+				+ "          <input type=\"text\" id=\"username\" name=\"username\" class=\"form-control\" placeholder=\"Username\" required autofocus>\n"
+				+ "        </p>\n" + "        <p>\n"
+				+ "          <label for=\"password\" class=\"sr-only\">Password</label>\n"
+				+ "          <input type=\"password\" id=\"password\" name=\"password\" class=\"form-control\" placeholder=\"Password\" required>\n"
+				+ "        </p>\n" + csrfTokenHtmlInput
+				+ "        <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
+				+ "      </form>\n";
+	}
+
 	private static String oauth2LoginLinks(String contextPath, Map<String, String> oauth2AuthenticationUrlToClientName) {
 		if (oauth2AuthenticationUrlToClientName.isEmpty()) {
 			return "";