Browse Source

Add query parameter support for authn requests

Closes gh-15017
Josh Cummings 1 năm trước cách đây
mục cha
commit
796e4d6b6c

+ 82 - 12
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -16,9 +16,13 @@
 
 package org.springframework.security.config.annotation.web.configurers.saml2;
 
+import java.util.ArrayList;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 
+import jakarta.servlet.http.HttpServletRequest;
+
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.context.ApplicationContext;
 import org.springframework.security.authentication.AuthenticationManager;
@@ -33,6 +37,7 @@ import org.springframework.security.saml2.provider.service.authentication.Abstra
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrations;
 import org.springframework.security.saml2.provider.service.web.HttpSessionSaml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.OpenSamlAuthenticationTokenConverter;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestRepository;
@@ -50,6 +55,7 @@ import org.springframework.security.web.util.matcher.AndRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
 import org.springframework.security.web.util.matcher.OrRequestMatcher;
+import org.springframework.security.web.util.matcher.ParameterRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestHeaderRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatchers;
@@ -111,7 +117,13 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 
 	private String loginPage;
 
-	private String authenticationRequestUri = Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI;
+	private String authenticationRequestUri = "/saml2/authenticate";
+
+	private String[] authenticationRequestParams = { "registrationId={registrationId}" };
+
+	private RequestMatcher authenticationRequestMatcher = RequestMatchers.anyOf(
+			new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI),
+			new AntPathQueryRequestMatcher(this.authenticationRequestUri, this.authenticationRequestParams));
 
 	private Saml2AuthenticationRequestResolver authenticationRequestResolver;
 
@@ -196,11 +208,31 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 	 * Request
 	 * @return the {@link Saml2LoginConfigurer} for further configuration
 	 * @since 6.0
+	 * @deprecated Use {@link #authenticationRequestUriQuery} instead
 	 */
 	public Saml2LoginConfigurer<B> authenticationRequestUri(String authenticationRequestUri) {
-		Assert.state(authenticationRequestUri.contains("{registrationId}"),
-				"authenticationRequestUri must contain {registrationId} path variable");
-		this.authenticationRequestUri = authenticationRequestUri;
+		return authenticationRequestUriQuery(authenticationRequestUri);
+	}
+
+	/**
+	 * Customize the URL that the SAML Authentication Request will be sent to. This method
+	 * also supports query parameters like so: <pre>
+	 * 	authenticationRequestUriQuery("/saml/authenticate?registrationId={registrationId}")
+	 * </pre> {@link RelyingPartyRegistrations}
+	 * @param authenticationRequestUriQuery the URI and query to use for the SAML 2.0
+	 * Authentication Request
+	 * @return the {@link Saml2LoginConfigurer} for further configuration
+	 * @since 6.0
+	 */
+	public Saml2LoginConfigurer<B> authenticationRequestUriQuery(String authenticationRequestUriQuery) {
+		Assert.state(authenticationRequestUriQuery.contains("{registrationId}"),
+				"authenticationRequestUri must contain {registrationId} path variable or query value");
+		String[] parts = authenticationRequestUriQuery.split("[?&]");
+		this.authenticationRequestUri = parts[0];
+		this.authenticationRequestParams = new String[parts.length - 1];
+		System.arraycopy(parts, 1, this.authenticationRequestParams, 0, parts.length - 1);
+		this.authenticationRequestMatcher = new AntPathQueryRequestMatcher(this.authenticationRequestUri,
+				this.authenticationRequestParams);
 		return this;
 	}
 
@@ -255,7 +287,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 		else {
 			Map<String, String> providerUrlMap = getIdentityProviderUrlMap(this.authenticationRequestUri,
-					this.relyingPartyRegistrationRepository);
+					this.authenticationRequestParams, this.relyingPartyRegistrationRepository);
 			boolean singleProvider = providerUrlMap.size() == 1;
 			if (singleProvider) {
 				// Setup auto-redirect to provider login page
@@ -336,8 +368,7 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 		OpenSaml4AuthenticationRequestResolver openSaml4AuthenticationRequestResolver = new OpenSaml4AuthenticationRequestResolver(
 				relyingPartyRegistrationRepository(http));
-		openSaml4AuthenticationRequestResolver
-			.setRequestMatcher(new AntPathRequestMatcher(this.authenticationRequestUri));
+		openSaml4AuthenticationRequestResolver.setRequestMatcher(this.authenticationRequestMatcher);
 		return openSaml4AuthenticationRequestResolver;
 	}
 
@@ -382,20 +413,28 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 			return;
 		}
 		loginPageGeneratingFilter.setSaml2LoginEnabled(true);
-		loginPageGeneratingFilter.setSaml2AuthenticationUrlToProviderName(
-				this.getIdentityProviderUrlMap(this.authenticationRequestUri, this.relyingPartyRegistrationRepository));
+		loginPageGeneratingFilter
+			.setSaml2AuthenticationUrlToProviderName(this.getIdentityProviderUrlMap(this.authenticationRequestUri,
+					this.authenticationRequestParams, this.relyingPartyRegistrationRepository));
 		loginPageGeneratingFilter.setLoginPageUrl(this.getLoginPage());
 		loginPageGeneratingFilter.setFailureUrl(this.getFailureUrl());
 	}
 
 	@SuppressWarnings("unchecked")
-	private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl,
+	private Map<String, String> getIdentityProviderUrlMap(String authRequestPrefixUrl, String[] authRequestQueryParams,
 			RelyingPartyRegistrationRepository idpRepo) {
 		Map<String, String> idps = new LinkedHashMap<>();
 		if (idpRepo instanceof Iterable) {
 			Iterable<RelyingPartyRegistration> repo = (Iterable<RelyingPartyRegistration>) idpRepo;
-			repo.forEach((p) -> idps.put(authRequestPrefixUrl.replace("{registrationId}", p.getRegistrationId()),
-					p.getRegistrationId()));
+			StringBuilder authRequestQuery = new StringBuilder("?");
+			for (String authRequestQueryParam : authRequestQueryParams) {
+				authRequestQuery.append(authRequestQueryParam + "&");
+			}
+			authRequestQuery.deleteCharAt(authRequestQuery.length() - 1);
+			String authenticationRequestUriQuery = authRequestPrefixUrl + authRequestQuery;
+			repo.forEach(
+					(p) -> idps.put(authenticationRequestUriQuery.replace("{registrationId}", p.getRegistrationId()),
+							p.getRegistrationId()));
 		}
 		return idps;
 	}
@@ -437,4 +476,35 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 		}
 	}
 
+	static class AntPathQueryRequestMatcher implements RequestMatcher {
+
+		private final RequestMatcher matcher;
+
+		AntPathQueryRequestMatcher(String path, String... params) {
+			List<RequestMatcher> matchers = new ArrayList<>();
+			matchers.add(new AntPathRequestMatcher(path));
+			for (String param : params) {
+				String[] parts = param.split("=");
+				if (parts.length == 1) {
+					matchers.add(new ParameterRequestMatcher(parts[0]));
+				}
+				else {
+					matchers.add(new ParameterRequestMatcher(parts[0], parts[1]));
+				}
+			}
+			this.matcher = new AndRequestMatcher(matchers);
+		}
+
+		@Override
+		public boolean matches(HttpServletRequest request) {
+			return matcher(request).isMatch();
+		}
+
+		@Override
+		public MatchResult matcher(HttpServletRequest request) {
+			return this.matcher.matcher(request);
+		}
+
+	}
+
 }

+ 4 - 0
config/src/main/kotlin/org/springframework/security/config/annotation/web/Saml2Dsl.kt

@@ -48,6 +48,7 @@ import org.springframework.security.web.authentication.AuthenticationSuccessHand
 class Saml2Dsl {
     var relyingPartyRegistrationRepository: RelyingPartyRegistrationRepository? = null
     var loginPage: String? = null
+    var authenticationRequestUriQuery: String? = null
     var authenticationSuccessHandler: AuthenticationSuccessHandler? = null
     var authenticationFailureHandler: AuthenticationFailureHandler? = null
     var failureUrl: String? = null
@@ -88,6 +89,9 @@ class Saml2Dsl {
             defaultSuccessUrlOption?.also {
                 saml2Login.defaultSuccessUrl(defaultSuccessUrlOption!!.first, defaultSuccessUrlOption!!.second)
             }
+            authenticationRequestUriQuery?.also {
+                saml2Login.authenticationRequestUriQuery(authenticationRequestUriQuery)
+            }
             authenticationSuccessHandler?.also { saml2Login.successHandler(authenticationSuccessHandler) }
             authenticationFailureHandler?.also { saml2Login.failureHandler(authenticationFailureHandler) }
             authenticationManager?.also { saml2Login.authenticationManager(authenticationManager) }

+ 33 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurerTests.java

@@ -101,6 +101,7 @@ import org.springframework.web.util.UriComponents;
 import org.springframework.web.util.UriComponentsBuilder;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.hamcrest.Matchers.startsWith;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.atLeastOnce;
@@ -113,6 +114,7 @@ import static org.springframework.security.config.annotation.SecurityContextChan
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.content;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.header;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.redirectedUrl;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
 
@@ -343,6 +345,19 @@ public class Saml2LoginConfigurerTests {
 				any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
+	@Test
+	public void authenticationRequestWhenCustomAuthenticationRequestPathRepositoryThenUses() throws Exception {
+		this.spring.register(CustomAuthenticationRequestUriQuery.class).autowire();
+		MockHttpServletRequestBuilder request = get("/custom/auth/sso");
+		this.mvc.perform(request)
+			.andExpect(status().isFound())
+			.andExpect(redirectedUrl("http://localhost/custom/auth/sso?entityId=registration-id"));
+		request.queryParam("entityId", registration.getRegistrationId());
+		MvcResult result = this.mvc.perform(request).andExpect(status().isFound()).andReturn();
+		String redirectedUrl = result.getResponse().getRedirectedUrl();
+		assertThat(redirectedUrl).startsWith(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
+	}
+
 	@Test
 	public void saml2LoginWhenLoginProcessingUrlWithoutRegistrationIdAndDefaultAuthenticationConverterThenAutowires()
 			throws Exception {
@@ -390,7 +405,7 @@ public class Saml2LoginConfigurerTests {
 			.andExpect(redirectedUrl("http://localhost/login"));
 		this.mvc.perform(get("/").accept(MediaType.TEXT_HTML))
 			.andExpect(status().isFound())
-			.andExpect(redirectedUrl("http://localhost/saml2/authenticate/registration-id"));
+			.andExpect(header().string("Location", startsWith("http://localhost/saml2/authenticate")));
 	}
 
 	@Test
@@ -669,6 +684,23 @@ public class Saml2LoginConfigurerTests {
 
 	}
 
+	@Configuration
+	@EnableWebSecurity
+	@Import(Saml2LoginConfigBeans.class)
+	static class CustomAuthenticationRequestUriQuery {
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.authorizeHttpRequests((authz) -> authz.anyRequest().authenticated())
+					.saml2Login((saml2) -> saml2.authenticationRequestUriQuery("/custom/auth/sso?entityId={registrationId}"));
+			// @formatter:on
+			return http.build();
+		}
+
+	}
+
 	@Configuration
 	@EnableWebSecurity
 	@Import(Saml2LoginConfigBeans.class)

+ 42 - 1
config/src/test/kotlin/org/springframework/security/config/annotation/web/Saml2DslTests.kt

@@ -43,11 +43,13 @@ import org.springframework.security.saml2.provider.service.registration.TestRely
 import org.springframework.security.saml2.provider.service.web.authentication.Saml2WebSsoAuthenticationFilter
 import org.springframework.security.web.SecurityFilterChain
 import org.springframework.test.web.servlet.MockMvc
+import org.springframework.test.web.servlet.MvcResult
 import org.springframework.test.web.servlet.get
 import org.springframework.test.web.servlet.request.MockMvcRequestBuilders
+import org.springframework.test.web.servlet.result.MockMvcResultMatchers
 import java.security.cert.Certificate
 import java.security.cert.CertificateFactory
-import java.util.Base64
+import java.util.*
 
 /**
  * Tests for [Saml2Dsl]
@@ -136,6 +138,23 @@ class Saml2DslTests {
         verify(exactly = 1) { Saml2LoginCustomAuthenticationManagerConfig.AUTHENTICATION_MANAGER.authenticate(any()) }
     }
 
+    @Test
+    @Throws(Exception::class)
+    fun authenticationRequestWhenCustomAuthenticationRequestPathRepositoryThenUses() {
+        this.spring.register(CustomAuthenticationRequestUriQuery::class.java).autowire()
+        val registration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
+        val request = MockMvcRequestBuilders.get("/custom/auth/sso")
+        this.mockMvc.perform(request)
+            .andExpect(MockMvcResultMatchers.status().isFound())
+            .andExpect(MockMvcResultMatchers.redirectedUrl("http://localhost/custom/auth/sso?entityId=simplesamlphp"))
+        request.queryParam("entityId", registration.registrationId)
+        val result: MvcResult =
+            this.mockMvc.perform(request).andExpect(MockMvcResultMatchers.status().isFound()).andReturn()
+        val redirectedUrl = result.response.redirectedUrl
+        Assertions.assertThat(redirectedUrl)
+            .startsWith(registration.assertingPartyDetails.singleSignOnServiceLocation)
+    }
+
     @Configuration
     @EnableWebSecurity
     open class Saml2LoginCustomAuthenticationManagerConfig {
@@ -162,4 +181,26 @@ class Saml2DslTests {
             return repository
         }
     }
+
+    @Configuration
+    @EnableWebSecurity
+    open class CustomAuthenticationRequestUriQuery {
+        @Bean
+        open fun securityFilterChain(http: HttpSecurity): SecurityFilterChain {
+            http {
+                authorizeHttpRequests {
+                    authorize(anyRequest, authenticated)
+                }
+                saml2Login {
+                    authenticationRequestUriQuery = "/custom/auth/sso?entityId={registrationId}"
+                }
+            }
+            return http.build()
+        }
+
+        @Bean
+        open fun relyingPartyRegistrationRepository(): RelyingPartyRegistrationRepository? {
+            return InMemoryRelyingPartyRegistrationRepository(TestRelyingPartyRegistrations.relyingPartyRegistration().build())
+        }
+    }
 }

+ 37 - 1
docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc

@@ -4,7 +4,7 @@
 As stated earlier, Spring Security's SAML 2.0 support produces a `<saml2:AuthnRequest>` to commence authentication with the asserting party.
 
 Spring Security achieves this in part by registering the `Saml2WebSsoAuthenticationRequestFilter` in the filter chain.
-This filter by default responds to endpoint `+/saml2/authenticate/{registrationId}+`.
+This filter by default responds to the endpoints `+/saml2/authenticate/{registrationId}+` and `+/saml2/authenticate?registrationId={registrationId}+`.
 
 For example, if you were deployed to `https://rp.example.com` and you gave your registration an ID of `okta`, you could navigate to:
 
@@ -12,6 +12,42 @@ For example, if you were deployed to `https://rp.example.com` and you gave your
 
 and the result would be a redirect that included a `SAMLRequest` parameter containing the signed, deflated, and encoded `<saml2:AuthnRequest>`.
 
+== Configuring the `<saml2:AuthnRequest>` Endpoint
+
+To configure the endpoint differently from the default, you can set the value in `saml2Login`:
+
+[tabs]
+======
+Java::
++
+[source,java,role="primary"]
+----
+@Bean
+SecurityFilterChain filterChain(HttpSecurity http) {
+	http
+        .saml2Login((saml2) -> saml2
+            .authenticationRequestUriQuery("/custom/auth/sso?peerEntityID={registrationId}")
+        );
+	return new CustomSaml2AuthenticationRequestRepository();
+}
+----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Bean
+fun filterChain(http: HttpSecurity): SecurityFilterChain {
+    http {
+        saml2Login {
+            authenticationRequestUriQuery = "/custom/auth/sso?peerEntityID={registrationId}"
+        }
+    }
+    return CustomSaml2AuthenticationRequestRepository()
+}
+----
+======
+
 [[servlet-saml2login-store-authn-request]]
 == Changing How the `<saml2:AuthnRequest>` Gets Stored
 

+ 39 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java

@@ -17,6 +17,8 @@
 package org.springframework.security.saml2.provider.service.web.authentication;
 
 import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 import java.util.UUID;
 import java.util.function.BiConsumer;
@@ -50,8 +52,11 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
 import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
 import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
 import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
+import org.springframework.security.web.util.matcher.AndRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.ParameterRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatchers;
 import org.springframework.util.Assert;
 
 /**
@@ -75,8 +80,9 @@ class OpenSamlAuthenticationRequestResolver {
 
 	private final NameIDPolicyBuilder nameIdPolicyBuilder;
 
-	private RequestMatcher requestMatcher = new AntPathRequestMatcher(
-			Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI);
+	private RequestMatcher requestMatcher = RequestMatchers.anyOf(
+			new AntPathRequestMatcher(Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI),
+			new AntPathQueryRequestMatcher("/saml2/authenticate", "registrationId={registrationId}"));
 
 	private Converter<HttpServletRequest, String> relayStateResolver = (request) -> UUID.randomUUID().toString();
 
@@ -199,4 +205,35 @@ class OpenSamlAuthenticationRequestResolver {
 		}
 	}
 
+	private static final class AntPathQueryRequestMatcher implements RequestMatcher {
+
+		private final RequestMatcher matcher;
+
+		AntPathQueryRequestMatcher(String path, String... params) {
+			List<RequestMatcher> matchers = new ArrayList<>();
+			matchers.add(new AntPathRequestMatcher(path));
+			for (String param : params) {
+				String[] parts = param.split("=");
+				if (parts.length == 1) {
+					matchers.add(new ParameterRequestMatcher(parts[0]));
+				}
+				else {
+					matchers.add(new ParameterRequestMatcher(parts[0], parts[1]));
+				}
+			}
+			this.matcher = new AndRequestMatcher(matchers);
+		}
+
+		@Override
+		public boolean matches(HttpServletRequest request) {
+			return matcher(request).isMatch();
+		}
+
+		@Override
+		public MatchResult matcher(HttpServletRequest request) {
+			return this.matcher.matcher(request);
+		}
+
+	}
+
 }