Browse Source

Polish Default Login Page

Issue gh-17901
Josh Cummings 2 months ago
parent
commit
50ebd467c3

+ 8 - 6
config/src/test/java/org/springframework/security/config/annotation/web/configurers/FormLoginConfigurerTests.java

@@ -402,7 +402,7 @@ public class FormLoginConfigurerTests {
 		UserDetails user = PasswordEncodedUser.user();
 		this.mockMvc.perform(get("/profile").with(user(user)))
 			.andExpect(status().is3xxRedirection())
-			.andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_PASSWORD"));
+			.andExpect(redirectedUrl("http://localhost/login?factor=password"));
 		this.mockMvc
 			.perform(post("/ott/generate").param("username", "rod")
 				.with(user(user))
@@ -418,11 +418,11 @@ public class FormLoginConfigurerTests {
 		user = PasswordEncodedUser.withUserDetails(user).authorities("profile:read", "FACTOR_OTT").build();
 		this.mockMvc.perform(get("/profile").with(user(user)))
 			.andExpect(status().is3xxRedirection())
-			.andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_PASSWORD"));
+			.andExpect(redirectedUrl("http://localhost/login?factor=password"));
 		user = PasswordEncodedUser.withUserDetails(user).authorities("profile:read", "FACTOR_PASSWORD").build();
 		this.mockMvc.perform(get("/profile").with(user(user)))
 			.andExpect(status().is3xxRedirection())
-			.andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_OTT"));
+			.andExpect(redirectedUrl("http://localhost/login?factor=ott"));
 		user = PasswordEncodedUser.withUserDetails(user)
 			.authorities("profile:read", "FACTOR_PASSWORD", "FACTOR_OTT")
 			.build();
@@ -438,7 +438,7 @@ public class FormLoginConfigurerTests {
 		this.mockMvc.perform(get("/login")).andExpect(status().isOk());
 		this.mockMvc.perform(get("/profile").with(SecurityMockMvcRequestPostProcessors.x509("rod.cer")))
 			.andExpect(status().is3xxRedirection())
-			.andExpect(redirectedUrl("http://localhost/login?authority=FACTOR_PASSWORD"));
+			.andExpect(redirectedUrl("http://localhost/login?factor=password"));
 		this.mockMvc
 			.perform(post("/login").param("username", "rod")
 				.param("password", "password")
@@ -793,7 +793,8 @@ public class FormLoginConfigurerTests {
 	static class MfaDslConfig {
 
 		@Bean
-		SecurityFilterChain filterChain(HttpSecurity http, AuthorizationManagerFactory<RequestAuthorizationContext> authz) throws Exception {
+		SecurityFilterChain filterChain(HttpSecurity http,
+				AuthorizationManagerFactory<RequestAuthorizationContext> authz) throws Exception {
 			// @formatter:off
 			http
 				.formLogin(Customizer.withDefaults())
@@ -824,7 +825,8 @@ public class FormLoginConfigurerTests {
 	static class MfaDslX509Config {
 
 		@Bean
-		SecurityFilterChain filterChain(HttpSecurity http, AuthorizationManagerFactory<RequestAuthorizationContext> authz) throws Exception {
+		SecurityFilterChain filterChain(HttpSecurity http,
+				AuthorizationManagerFactory<RequestAuthorizationContext> authz) throws Exception {
 			// @formatter:off
 			http
 				.x509(Customizer.withDefaults())

+ 23 - 5
web/src/main/java/org/springframework/security/web/authentication/LoginUrlAuthenticationEntryPoint.java

@@ -18,6 +18,7 @@ package org.springframework.security.web.authentication;
 
 import java.io.IOException;
 import java.util.Collection;
+import java.util.Locale;
 
 import jakarta.servlet.RequestDispatcher;
 import jakarta.servlet.ServletException;
@@ -41,6 +42,7 @@ import org.springframework.security.web.access.ExceptionTranslationFilter;
 import org.springframework.security.web.util.RedirectUrlBuilder;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.util.UriComponentsBuilder;
 
@@ -71,6 +73,8 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
 
 	private static final Log logger = LogFactory.getLog(LoginUrlAuthenticationEntryPoint.class);
 
+	private static final String FACTOR_PREFIX = "FACTOR_";
+
 	private PortMapper portMapper = new PortMapperImpl();
 
 	private String loginFormUrl;
@@ -110,15 +114,29 @@ public class LoginUrlAuthenticationEntryPoint implements AuthenticationEntryPoin
 	 * @param exception the exception
 	 * @return the URL (cannot be null or empty; defaults to {@link #getLoginFormUrl()})
 	 */
+	@SuppressWarnings("unchecked")
 	protected String determineUrlToUseForThisRequest(HttpServletRequest request, HttpServletResponse response,
 			AuthenticationException exception) {
+		Collection<GrantedAuthority> authorities = getAttribute(request, GrantedAuthority.MISSING_AUTHORITIES_ATTRIBUTE,
+				Collection.class);
+		if (CollectionUtils.isEmpty(authorities)) {
+			return getLoginFormUrl();
+		}
+		Collection<String> factors = authorities.stream()
+			.filter((a) -> a.getAuthority().startsWith(FACTOR_PREFIX))
+			.map((a) -> a.getAuthority().substring(FACTOR_PREFIX.length()).toLowerCase(Locale.ROOT))
+			.toList();
+		return UriComponentsBuilder.fromUriString(getLoginFormUrl()).queryParam("factor", factors).toUriString();
+	}
+
+	private static <T> @Nullable T getAttribute(HttpServletRequest request, String name, Class<T> clazz) {
 		Object value = request.getAttribute(GrantedAuthority.MISSING_AUTHORITIES_ATTRIBUTE);
-		if (value instanceof Collection<?> authorities) {
-			return UriComponentsBuilder.fromUriString(getLoginFormUrl())
-				.queryParam("authority", authorities)
-				.toUriString();
+		if (value == null) {
+			return null;
 		}
-		return getLoginFormUrl();
+		String message = String.format("Found %s in %s, but expecting a %s", value.getClass(), name, clazz);
+		Assert.isInstanceOf(clazz, value, message);
+		return (T) value;
 	}
 
 	/**

+ 9 - 7
web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java

@@ -88,7 +88,9 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
 
 	private @Nullable String rememberMeParameter;
 
-	private final Collection<String> allowedParameters = List.of("authority");
+	private final String factorParameter = "factor";
+
+	private final Collection<String> allowedParameters = List.of(this.factorParameter);
 
 	@SuppressWarnings("NullAway.Init")
 	private Map<String, String> oauth2AuthenticationUrlToClientName;
@@ -257,29 +259,29 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean {
 			.withRawHtml("passkeyLogin", "");
 
 		Predicate<String> wantsAuthority = wantsAuthority(request);
-		if (wantsAuthority.test("FACTOR_WEBAUTHN")) {
+		if (wantsAuthority.test("webauthn")) {
 			builder.withRawHtml("javaScript", renderJavaScript(request, contextPath))
 				.withRawHtml("passkeyLogin", renderPasskeyLogin());
 		}
-		if (wantsAuthority.test("FACTOR_PASSWORD")) {
+		if (wantsAuthority.test("password")) {
 			builder.withRawHtml("formLogin",
 					renderFormLogin(request, loginError, logoutSuccess, contextPath, errorMsg));
 		}
-		if (wantsAuthority.test("FACTOR_OTT")) {
+		if (wantsAuthority.test("ott")) {
 			builder.withRawHtml("oneTimeTokenLogin",
 					renderOneTimeTokenLogin(request, loginError, logoutSuccess, contextPath, errorMsg));
 		}
-		if (wantsAuthority.test("FACTOR_AUTHORIZATION_CODE")) {
+		if (wantsAuthority.test("authorization_code")) {
 			builder.withRawHtml("oauth2Login", renderOAuth2Login(loginError, logoutSuccess, errorMsg, contextPath));
 		}
-		if (wantsAuthority.test("FACTOR_SAML_RESPONSE")) {
+		if (wantsAuthority.test("saml_response")) {
 			builder.withRawHtml("saml2Login", renderSaml2Login(loginError, logoutSuccess, errorMsg, contextPath));
 		}
 		return builder.render();
 	}
 
 	private Predicate<String> wantsAuthority(HttpServletRequest request) {
-		String[] authorities = request.getParameterValues("authority");
+		String[] authorities = request.getParameterValues(this.factorParameter);
 		if (authorities == null) {
 			return (authority) -> true;
 		}

+ 3 - 4
web/src/test/java/org/springframework/security/web/authentication/DefaultLoginPageGeneratingFilterTests.java

@@ -204,7 +204,7 @@ public class DefaultLoginPageGeneratingFilterTests {
 		filter.setOneTimeTokenEnabled(true);
 		filter.setOneTimeTokenGenerationUrl("/ott/authenticate");
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		filter.doFilter(TestMockHttpServletRequests.get("/login?authority=FACTOR_OTT").build(), response, this.chain);
+		filter.doFilter(TestMockHttpServletRequests.get("/login?factor=ott").build(), response, this.chain);
 		assertThat(response.getContentAsString()).contains("Request a One-Time Token");
 		assertThat(response.getContentAsString()).contains("""
 				      <form id="ott-form" class="login-form" method="post" action="/ott/authenticate">
@@ -231,9 +231,8 @@ public class DefaultLoginPageGeneratingFilterTests {
 		filter.setOneTimeTokenEnabled(true);
 		filter.setOneTimeTokenGenerationUrl("/ott/authenticate");
 		MockHttpServletResponse response = new MockHttpServletResponse();
-		filter.doFilter(
-				TestMockHttpServletRequests.get("/login?authority=FACTOR_OTT&authority=FACTOR_PASSWORD").build(),
-				response, this.chain);
+		filter.doFilter(TestMockHttpServletRequests.get("/login?factor=ott&factor=password").build(), response,
+				this.chain);
 		assertThat(response.getContentAsString()).contains("Request a One-Time Token");
 		assertThat(response.getContentAsString()).contains("""
 				      <form id="ott-form" class="login-form" method="post" action="/ott/authenticate">