Browse Source

Add RelyingPartyRegistrationResolver

Closes gh-9486
Josh Cummings 4 years ago
parent
commit
6488295cad
14 changed files with 239 additions and 44 deletions
  1. 6 3
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java
  2. 14 11
      docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc
  3. 4 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java
  4. 4 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java
  5. 31 17
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java
  6. 13 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java
  7. 40 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationResolver.java
  8. 14 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java
  9. 23 7
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java
  10. 34 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java
  11. 31 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java
  12. 4 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java
  13. 2 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java
  14. 19 2
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java

+ 6 - 3
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -47,6 +47,7 @@ import org.springframework.security.saml2.provider.service.servlet.filter.Saml2W
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -264,7 +265,8 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 	private AuthenticationConverter getAuthenticationConverter(B http) {
 		if (this.authenticationConverter == null) {
 			return new Saml2AuthenticationTokenConverter(
-					new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository));
+					(RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver(
+							this.relyingPartyRegistrationRepository));
 		}
 		return this.authenticationConverter;
 	}
@@ -390,8 +392,9 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>>
 			Saml2AuthenticationRequestContextResolver resolver = getBeanOrNull(http,
 					Saml2AuthenticationRequestContextResolver.class);
 			if (resolver == null) {
-				return new DefaultSaml2AuthenticationRequestContextResolver(new DefaultRelyingPartyRegistrationResolver(
-						Saml2LoginConfigurer.this.relyingPartyRegistrationRepository));
+				RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
+						Saml2LoginConfigurer.this.relyingPartyRegistrationRepository);
+				return new DefaultSaml2AuthenticationRequestContextResolver(relyingPartyRegistrationResolver);
 			}
 			return resolver;
 		}

+ 14 - 11
docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc

@@ -727,7 +727,7 @@ There are a number of reasons you may want to customize. Among them:
 * You may know that you will never be a multi-tenant application and so want to have a simpler URL scheme
 * You may identify tenants in a way other than by the URI path
 
-To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `Converter<HttpServletRequest, RelyingPartyRegistration>`.
+To customize the way that a `RelyingPartyRegistration` is resolved, you can configure a custom `RelyingPartyRegistrationResolver`.
 The default looks up the registration id from the URI's last path element and looks it up in your `RelyingPartyRegistrationRepository`.
 
 You can provide a simpler resolver that, for example, always returns the same relying party:
@@ -736,12 +736,17 @@ You can provide a simpler resolver that, for example, always returns the same re
 .Java
 [source,java,role="primary"]
 ----
-public class SingleRelyingPartyRegistrationResolver
-        implements Converter<HttpServletRequest, RelyingPartyRegistration> {
+public class SingleRelyingPartyRegistrationResolver implements RelyingPartyRegistrationResolver {
+
+    private final RelyingPartyRegistrationResolver delegate;
+
+    public SingleRelyingPartyRegistrationResolver(RelyingPartyRegistrationRepository registrations) {
+        this.delegate = new DefaultRelyingPartyRegistrationResolver(registrations);
+    }
 
     @Override
-    public RelyingPartyRegistration convert(HttpServletRequest request) {
-        return this.relyingParty;
+    public RelyingPartyRegistration resolve(HttpServletRequest request, String registrationId) {
+        return this.delegate.resolve(request, "single");
     }
 }
 ----
@@ -749,9 +754,9 @@ public class SingleRelyingPartyRegistrationResolver
 .Kotlin
 [source,kotlin,role="secondary"]
 ----
-class SingleRelyingPartyRegistrationResolver : Converter<HttpServletRequest?, RelyingPartyRegistration?> {
-    override fun convert(request: HttpServletRequest?): RelyingPartyRegistration? {
-        return this.relyingParty
+class SingleRelyingPartyRegistrationResolver(delegate: RelyingPartyRegistrationResolver) : RelyingPartyRegistrationResolver {
+    override fun resolve(request: HttpServletRequest?, registrationId: String?): RelyingPartyRegistration? {
+        return this.delegate.resolve(request, "single")
     }
 }
 ----
@@ -1544,7 +1549,7 @@ You can publish a metadata endpoint by adding the `Saml2MetadataFilter` to the f
 .Java
 [source,java,role="primary"]
 ----
-Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver =
+DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver =
         new DefaultRelyingPartyRegistrationResolver(this.relyingPartyRegistrationRepository);
 Saml2MetadataFilter filter = new Saml2MetadataFilter(
         relyingPartyRegistrationResolver,
@@ -1594,8 +1599,6 @@ filter.setRequestMatcher(AntPathRequestMatcher("/saml2/metadata/{registrationId}
 ----
 ====
 
-ensuring that the `registrationId` hint is at the end of the path.
-
 Or, if you have registered a custom relying party registration resolver in the constructor, then you can specify a path without a `registrationId` hint, like so:
 
 ====

+ 4 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilter.java

@@ -29,6 +29,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
 import org.springframework.security.saml2.provider.service.servlet.HttpSessionSaml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
 import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
@@ -67,7 +68,9 @@ public class Saml2WebSsoAuthenticationFilter extends AbstractAuthenticationProce
 	public Saml2WebSsoAuthenticationFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
 			String filterProcessesUrl) {
 		this(new Saml2AuthenticationTokenConverter(
-				new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), filterProcessesUrl);
+				(RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver(
+						relyingPartyRegistrationRepository)),
+				filterProcessesUrl);
 	}
 
 	/**

+ 4 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilter.java

@@ -39,6 +39,7 @@ import org.springframework.security.saml2.provider.service.servlet.HttpSessionSa
 import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -96,7 +97,9 @@ public class Saml2WebSsoAuthenticationRequestFilter extends OncePerRequestFilter
 	public Saml2WebSsoAuthenticationRequestFilter(
 			RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
 		this(new DefaultSaml2AuthenticationRequestContextResolver(
-				new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository)), requestFactory());
+				(RelyingPartyRegistrationResolver) new DefaultRelyingPartyRegistrationResolver(
+						relyingPartyRegistrationRepository)),
+				requestFactory());
 	}
 
 	private static Saml2AuthenticationRequestFactory requestFactory() {

+ 31 - 17
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java

@@ -22,6 +22,9 @@ import java.util.function.Function;
 
 import javax.servlet.http.HttpServletRequest;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
@@ -42,13 +45,15 @@ import org.springframework.web.util.UriComponentsBuilder;
  * @since 5.4
  */
 public final class DefaultRelyingPartyRegistrationResolver
-		implements Converter<HttpServletRequest, RelyingPartyRegistration> {
+		implements RelyingPartyRegistrationResolver, Converter<HttpServletRequest, RelyingPartyRegistration> {
+
+	private Log logger = LogFactory.getLog(getClass());
 
 	private static final char PATH_DELIMITER = '/';
 
 	private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
 
-	private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver();
+	private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
 
 	public DefaultRelyingPartyRegistrationResolver(
 			RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
@@ -56,14 +61,35 @@ public final class DefaultRelyingPartyRegistrationResolver
 		this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
 	}
 
+	/**
+	 * {@inheritDoc}
+	 */
 	@Override
 	public RelyingPartyRegistration convert(HttpServletRequest request) {
-		String registrationId = this.registrationIdResolver.convert(request);
-		if (registrationId == null) {
+		return resolve(request, null);
+	}
+
+	/**
+	 * {@inheritDoc}
+	 */
+	@Override
+	public RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId) {
+		if (relyingPartyRegistrationId == null) {
+			if (this.logger.isTraceEnabled()) {
+				this.logger.trace("Attempting to resolve from " + this.registrationRequestMatcher
+						+ " since registrationId is null");
+			}
+			relyingPartyRegistrationId = this.registrationRequestMatcher.matcher(request).getVariables()
+					.get("registrationId");
+		}
+		if (relyingPartyRegistrationId == null) {
+			if (this.logger.isTraceEnabled()) {
+				this.logger.trace("Returning null registration since registrationId is null");
+			}
 			return null;
 		}
 		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
-				.findByRegistrationId(registrationId);
+				.findByRegistrationId(relyingPartyRegistrationId);
 		if (relyingPartyRegistration == null) {
 			return null;
 		}
@@ -111,16 +137,4 @@ public final class DefaultRelyingPartyRegistrationResolver
 		return uriComponents.toUriString();
 	}
 
-	private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> {
-
-		private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
-
-		@Override
-		public String convert(HttpServletRequest request) {
-			RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
-			return result.getVariables().get("registrationId");
-		}
-
-	}
-
 }

+ 13 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolver.java

@@ -42,11 +42,24 @@ public final class DefaultSaml2AuthenticationRequestContextResolver
 
 	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
 
+	/**
+	 * Construct a {@link DefaultSaml2AuthenticationRequestContextResolver}
+	 * @param relyingPartyRegistrationResolver
+	 * @deprecated Use
+	 * {@link DefaultSaml2AuthenticationRequestContextResolver#DefaultSaml2AuthenticationRequestContextResolver(RelyingPartyRegistrationResolver)}
+	 * instead
+	 */
+	@Deprecated
 	public DefaultSaml2AuthenticationRequestContextResolver(
 			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
 		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
 	}
 
+	public DefaultSaml2AuthenticationRequestContextResolver(
+			RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+		this.relyingPartyRegistrationResolver = (request) -> relyingPartyRegistrationResolver.resolve(request, null);
+	}
+
 	@Override
 	public Saml2AuthenticationRequestContext resolve(HttpServletRequest request) {
 		Assert.notNull(request, "request cannot be null");

+ 40 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationResolver.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import javax.servlet.http.HttpServletRequest;
+
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+
+/**
+ * A contract for resolving a {@link RelyingPartyRegistration} from the HTTP request
+ *
+ * @author Josh Cummings
+ * @since 5.6
+ */
+public interface RelyingPartyRegistrationResolver {
+
+	/**
+	 * Resolve a {@link RelyingPartyRegistration} from the HTTP request, using the
+	 * {@code relyingPartyRegistrationId}, if it is provided
+	 * @param request the HTTP request
+	 * @param relyingPartyRegistrationId the {@link RelyingPartyRegistration} identifier
+	 * @return the resolved {@link RelyingPartyRegistration}
+	 */
+	RelyingPartyRegistration resolve(HttpServletRequest request, String relyingPartyRegistrationId);
+
+}

+ 14 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

@@ -61,7 +61,11 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 	 * resolving {@link RelyingPartyRegistration}s
 	 * @param relyingPartyRegistrationResolver the strategy for resolving
 	 * {@link RelyingPartyRegistration}s
+	 * @deprecated Use
+	 * {@link Saml2AuthenticationTokenConverter#Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver)}
+	 * instead
 	 */
+	@Deprecated
 	public Saml2AuthenticationTokenConverter(
 			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
 		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
@@ -69,6 +73,16 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
 	}
 
+	public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+		this(adaptToConverter(relyingPartyRegistrationResolver));
+	}
+
+	private static Converter<HttpServletRequest, RelyingPartyRegistration> adaptToConverter(
+			RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
+		return (request) -> relyingPartyRegistrationResolver.resolve(request, null);
+	}
+
 	@Override
 	public Saml2AuthenticationToken convert(HttpServletRequest request) {
 		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);

+ 23 - 7
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java

@@ -46,7 +46,7 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
 
 	public static final String DEFAULT_METADATA_FILE_NAME = "saml-{registrationId}-metadata.xml";
 
-	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter;
+	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
 
 	private final Saml2MetadataResolver saml2MetadataResolver;
 
@@ -55,11 +55,26 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
 	private RequestMatcher requestMatcher = new AntPathRequestMatcher(
 			"/saml2/service-provider-metadata/{registrationId}");
 
-	public Saml2MetadataFilter(
-			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter,
+	/**
+	 * Construct a {@link Saml2MetadataFilter}
+	 * @param relyingPartyRegistrationResolver
+	 * @param saml2MetadataResolver
+	 * @deprecated Use
+	 * {@link Saml2MetadataFilter#Saml2MetadataFilter(RelyingPartyRegistrationResolver)}
+	 * instead
+	 */
+	@Deprecated
+	public Saml2MetadataFilter(Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver,
 			Saml2MetadataResolver saml2MetadataResolver) {
+		this.relyingPartyRegistrationResolver = (request, id) -> relyingPartyRegistrationResolver.convert(request);
+		this.saml2MetadataResolver = saml2MetadataResolver;
+	}
 
-		this.relyingPartyRegistrationConverter = relyingPartyRegistrationConverter;
+	public Saml2MetadataFilter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver,
+			Saml2MetadataResolver saml2MetadataResolver) {
+		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
+		Assert.notNull(saml2MetadataResolver, "saml2MetadataResolver cannot be null");
+		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
 		this.saml2MetadataResolver = saml2MetadataResolver;
 	}
 
@@ -71,14 +86,15 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
 			chain.doFilter(request, response);
 			return;
 		}
-		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationConverter.convert(request);
+		String registrationId = matcher.getVariables().get("registrationId");
+		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
+				registrationId);
 		if (relyingPartyRegistration == null) {
 			response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
 			return;
 		}
 		String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration);
-		String registrationId = relyingPartyRegistration.getRegistrationId();
-		writeMetadataToResponse(response, registrationId, metadata);
+		writeMetadataToResponse(response, relyingPartyRegistration.getRegistrationId(), metadata);
 	}
 
 	private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata)

+ 34 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationFilterTests.java

@@ -22,15 +22,25 @@ import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
+import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
 import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationException;
 import org.springframework.security.saml2.provider.service.authentication.TestSaml2AuthenticationTokens;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
 import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
 import org.springframework.security.web.authentication.AuthenticationConverter;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
 
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -49,6 +59,8 @@ public class Saml2WebSsoAuthenticationFilterTests {
 
 	private HttpServletResponse response = new MockHttpServletResponse();
 
+	private AuthenticationManager authenticationManager = mock(AuthenticationManager.class);
+
 	@BeforeEach
 	public void setup() {
 		this.filter = new Saml2WebSsoAuthenticationFilter(this.repository);
@@ -132,4 +144,26 @@ public class Saml2WebSsoAuthenticationFilterTests {
 		verifyNoInteractions(authenticationConverter);
 	}
 
+	@Test
+	public void doFilterWhenPathStartsWithRegistrationIdThenAuthenticates() throws Exception {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		Authentication authentication = new TestingAuthenticationToken("user", "password");
+		given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
+		given(this.authenticationManager.authenticate(authentication)).willReturn(authentication);
+		String loginProcessingUrl = "/{registrationId}/login/saml2/sso";
+		RequestMatcher matcher = new AntPathRequestMatcher(loginProcessingUrl);
+		DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository);
+		RelyingPartyRegistrationResolver resolver = (request, id) -> {
+			String registrationId = matcher.matcher(request).getVariables().get("registrationId");
+			return delegate.resolve(request, registrationId);
+		};
+		Saml2AuthenticationTokenConverter authenticationConverter = new Saml2AuthenticationTokenConverter(resolver);
+		this.filter = new Saml2WebSsoAuthenticationFilter(authenticationConverter, loginProcessingUrl);
+		this.filter.setAuthenticationManager(this.authenticationManager);
+		this.request.setPathInfo("/registration-id/login/saml2/sso");
+		this.request.setParameter("SAMLResponse", "response");
+		this.filter.doFilter(this.request, this.response, new MockFilterChain());
+		verify(this.repository).findByRegistrationId("registration-id");
+	}
+
 }

+ 31 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java

@@ -37,8 +37,14 @@ import org.springframework.security.saml2.provider.service.authentication.TestSa
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
 import org.springframework.security.saml2.provider.service.servlet.Saml2AuthenticationRequestRepository;
+import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
+import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.web.util.HtmlUtils;
 import org.springframework.web.util.UriUtils;
 
@@ -256,4 +262,29 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 				any(Saml2PostAuthenticationRequest.class), eq(this.request), eq(this.response));
 	}
 
+	@Test
+	public void doFilterWhenPathStartsWithRegistrationIdThenPosts() throws Exception {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full()
+				.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build();
+		RequestMatcher matcher = new AntPathRequestMatcher("/{registrationId}/saml2/authenticate");
+		DefaultRelyingPartyRegistrationResolver delegate = new DefaultRelyingPartyRegistrationResolver(this.repository);
+		RelyingPartyRegistrationResolver resolver = (request, id) -> {
+			String registrationId = matcher.matcher(request).getVariables().get("registrationId");
+			return delegate.resolve(request, registrationId);
+		};
+		Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(
+				resolver);
+		Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
+		given(authenticationRequest.getAuthenticationRequestUri()).willReturn("uri");
+		given(authenticationRequest.getRelayState()).willReturn("relay");
+		given(authenticationRequest.getSamlRequest()).willReturn("saml");
+		given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
+		given(this.factory.createPostAuthenticationRequest(any())).willReturn(authenticationRequest);
+		this.filter = new Saml2WebSsoAuthenticationRequestFilter(authenticationRequestContextResolver, this.factory);
+		this.filter.setRedirectMatcher(matcher);
+		this.request.setPathInfo("/registration-id/saml2/authenticate");
+		this.filter.doFilter(this.request, this.response, new MockFilterChain());
+		verify(this.repository).findByRegistrationId("registration-id");
+	}
+
 }

+ 4 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultSaml2AuthenticationRequestContextResolverTests.java

@@ -49,8 +49,11 @@ public class DefaultSaml2AuthenticationRequestContextResolverTests {
 
 	private RelyingPartyRegistration.Builder relyingPartyBuilder;
 
+	private RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
+			(id) -> this.relyingPartyBuilder.build());
+
 	private Saml2AuthenticationRequestContextResolver authenticationRequestContextResolver = new DefaultSaml2AuthenticationRequestContextResolver(
-			new DefaultRelyingPartyRegistrationResolver((id) -> this.relyingPartyBuilder.build()));
+			this.relyingPartyRegistrationResolver);
 
 	@BeforeEach
 	public void setup() {

+ 2 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

@@ -176,7 +176,8 @@ public class Saml2AuthenticationTokenConverterTests {
 
 	@Test
 	public void constructorWhenResolverIsNullThenIllegalArgument() {
-		assertThatIllegalArgumentException().isThrownBy(() -> new Saml2AuthenticationTokenConverter(null));
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new Saml2AuthenticationTokenConverter((RelyingPartyRegistrationResolver) null));
 	}
 
 	@Test

+ 19 - 2
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java

@@ -25,6 +25,7 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
 import org.springframework.http.HttpHeaders;
+import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
@@ -37,6 +38,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
@@ -63,8 +65,9 @@ public class Saml2MetadataFilterTests {
 	public void setup() {
 		this.repository = mock(RelyingPartyRegistrationRepository.class);
 		this.resolver = mock(Saml2MetadataResolver.class);
-		this.filter = new Saml2MetadataFilter(new DefaultRelyingPartyRegistrationResolver(this.repository),
-				this.resolver);
+		RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
+				this.repository);
+		this.filter = new Saml2MetadataFilter(relyingPartyRegistrationResolver, this.resolver);
 		this.request = new MockHttpServletRequest();
 		this.response = new MockHttpServletResponse();
 		this.chain = mock(FilterChain.class);
@@ -136,6 +139,20 @@ public class Saml2MetadataFilterTests {
 				.isEqualTo("attachment; filename=\"%s\"; filename*=UTF-8''%s", fileName, encodedFileName);
 	}
 
+	@Test
+	public void doFilterWhenPathStartsWithRegistrationIdThenServesMetadata() throws Exception {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
+		given(this.resolver.resolve(any())).willReturn("metadata");
+		RelyingPartyRegistrationResolver resolver = new DefaultRelyingPartyRegistrationResolver(
+				(id) -> this.repository.findByRegistrationId("registration-id"));
+		this.filter = new Saml2MetadataFilter(resolver, this.resolver);
+		this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
+		this.request.setPathInfo("/metadata");
+		this.filter.doFilter(this.request, this.response, new MockFilterChain());
+		verify(this.repository).findByRegistrationId("registration-id");
+	}
+
 	@Test
 	public void setRequestMatcherWhenNullThenIllegalArgument() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequestMatcher(null));