2
0
Эх сурвалжийг харах

Complete SAML 2.0 SP Metadata Endpoint

Closes gh-8693
Josh Cummings 5 жил өмнө
parent
commit
b999faa5a0
15 өөрчлөгдсөн 371 нэмэгдсэн , 358 устгасан
  1. 0 3
      config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java
  2. 0 27
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java
  3. 1 1
      config/src/test/kotlin/org/springframework/security/config/web/servlet/Saml2DslTests.kt
  4. 151 0
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java
  5. 11 4
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResolver.java
  6. 0 28
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java
  7. 0 161
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlMetadataResolver.java
  8. 48 29
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java
  9. 80 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java
  10. 1 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java
  11. 13 1
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java
  12. 0 67
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/OpenSamlMetadataResolverTest.java
  13. 64 34
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTest.java
  14. 1 1
      samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/SecurityConfig.java
  15. 1 1
      samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java

+ 0 - 3
config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java

@@ -73,9 +73,6 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
 		filterToOrder.put(
 			"org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestRedirectFilter",
 				order.next());
-		filterToOrder.put(
-				"org.springframework.security.saml2.provider.service.web.Saml2MetadataFilter",
-				order.next());
 		filterToOrder.put(
 				"org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter",
 				order.next());

+ 0 - 27
config/src/main/java/org/springframework/security/config/annotation/web/configurers/saml2/Saml2LoginConfigurer.java

@@ -38,11 +38,8 @@ import org.springframework.security.saml2.provider.service.servlet.filter.Saml2W
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationRequestFilter;
 import org.springframework.security.saml2.provider.service.web.DefaultRelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.DefaultSaml2AuthenticationRequestContextResolver;
-import org.springframework.security.saml2.provider.service.web.OpenSamlMetadataResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationRequestContextResolver;
 import org.springframework.security.saml2.provider.service.web.Saml2AuthenticationTokenConverter;
-import org.springframework.security.saml2.provider.service.web.Saml2MetadataFilter;
-import org.springframework.security.saml2.provider.service.web.Saml2MetadataResolver;
 import org.springframework.security.web.authentication.AuthenticationConverter;
 import org.springframework.security.web.authentication.LoginUrlAuthenticationEntryPoint;
 import org.springframework.security.web.authentication.ui.DefaultLoginPageGeneratingFilter;
@@ -113,15 +110,10 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
 	private RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
 
 	private AuthenticationConverter authenticationConverter;
-
-	private Saml2MetadataResolver saml2MetadataResolver;
-
 	private AuthenticationManager authenticationManager;
 
 	private Saml2WebSsoAuthenticationFilter saml2WebSsoAuthenticationFilter;
 
-	private Saml2MetadataFilter saml2MetadataFilter;
-
 	/**
 	 * Use this {@link AuthenticationConverter} when converting incoming requests to an {@link Authentication}.
 	 * By default the {@link Saml2AuthenticationTokenConverter} is used.
@@ -162,16 +154,6 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
 		return this;
 	}
 
-	/**
-	 * Sets the {@code Saml2MetadataResolver}
-	 * @param saml2MetadataResolver the implementation of the metadata resolver
-	 * @return the {@link Saml2LoginConfigurer} for further configuration
-	 */
-	public Saml2LoginConfigurer saml2MetadataResolver(Saml2MetadataResolver saml2MetadataResolver) {
-		this.saml2MetadataResolver = saml2MetadataResolver;
-		return this;
-	}
-
 	/**
 	 * {@inheritDoc}
 	 */
@@ -229,14 +211,6 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
 		setAuthenticationFilter(saml2WebSsoAuthenticationFilter);
 		super.loginProcessingUrl(this.loginProcessingUrl);
 
-		if (this.saml2MetadataResolver == null) {
-			this.saml2MetadataResolver = new OpenSamlMetadataResolver();
-		}
-
-		saml2MetadataFilter = new Saml2MetadataFilter(
-				this.relyingPartyRegistrationRepository, this.saml2MetadataResolver
-		);
-
 		if (hasText(this.loginPage)) {
 			// Set custom login page
 			super.loginPage(this.loginPage);
@@ -276,7 +250,6 @@ public final class Saml2LoginConfigurer<B extends HttpSecurityBuilder<B>> extend
 	@Override
 	public void configure(B http) throws Exception {
 		http.addFilter(this.authenticationRequestEndpoint.build(http));
-		http.addFilter(saml2MetadataFilter);
 		super.configure(http);
 		if (this.authenticationManager == null) {
 			registerDefaultAuthenticationProvider(http);

+ 1 - 1
config/src/test/kotlin/org/springframework/security/config/web/servlet/Saml2DslTests.kt

@@ -30,7 +30,7 @@ import org.springframework.security.saml2.credentials.Saml2X509Credential
 import org.springframework.security.saml2.credentials.Saml2X509Credential.Saml2X509CredentialType.VERIFICATION
 import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration
-import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter
+import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter
 import org.springframework.test.web.servlet.MockMvc
 import org.springframework.test.web.servlet.get
 import java.security.cert.Certificate

+ 151 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java

@@ -0,0 +1,151 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.metadata;
+
+import java.security.cert.CertificateEncodingException;
+import java.util.ArrayList;
+import java.util.Base64;
+import java.util.Collection;
+import java.util.List;
+import javax.xml.namespace.QName;
+
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
+import org.opensaml.core.xml.XMLObjectBuilder;
+import org.opensaml.saml.common.xml.SAMLConstants;
+import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
+import org.opensaml.saml.saml2.metadata.EntityDescriptor;
+import org.opensaml.saml.saml2.metadata.KeyDescriptor;
+import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
+import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller;
+import org.opensaml.security.credential.UsageType;
+import org.opensaml.xmlsec.signature.KeyInfo;
+import org.opensaml.xmlsec.signature.X509Certificate;
+import org.opensaml.xmlsec.signature.X509Data;
+import org.w3c.dom.Element;
+
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.OpenSamlInitializationService;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.util.Assert;
+
+import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
+import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
+
+/**
+ * Resolves the SAML 2.0 Relying Party Metadata for a given {@link RelyingPartyRegistration}
+ * using the OpenSAML API.
+ *
+ * @author Jakub Kubrynski
+ * @author Josh Cummings
+ * @since 5.4
+ */
+public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
+	static {
+		OpenSamlInitializationService.initialize();
+	}
+
+	private final EntityDescriptorMarshaller entityDescriptorMarshaller;
+
+	public OpenSamlMetadataResolver() {
+		this.entityDescriptorMarshaller = (EntityDescriptorMarshaller)
+				getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME);
+		Assert.notNull(this.entityDescriptorMarshaller, "entityDescriptorMarshaller cannot be null");
+	}
+
+	/**
+	 * {@inheritDoc}
+	 */
+	@Override
+	public String resolve(RelyingPartyRegistration relyingPartyRegistration) {
+		EntityDescriptor entityDescriptor = build(EntityDescriptor.ELEMENT_QNAME);
+		entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId());
+
+		SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration);
+		entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor);
+
+		return serialize(entityDescriptor);
+	}
+
+	private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registration) {
+		SPSSODescriptor spSsoDescriptor = build(SPSSODescriptor.DEFAULT_ELEMENT_NAME);
+		spSsoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS);
+		spSsoDescriptor.setWantAssertionsSigned(true);
+		spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(
+				registration.getSigningX509Credentials(), UsageType.SIGNING));
+		spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(
+				registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION));
+		spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration));
+		return spSsoDescriptor;
+	}
+
+	private List<KeyDescriptor> buildKeys(Collection<Saml2X509Credential> credentials, UsageType usageType) {
+		List<KeyDescriptor> list = new ArrayList<>();
+		for (Saml2X509Credential credential : credentials) {
+			KeyDescriptor keyDescriptor = buildKeyDescriptor(usageType, credential.getCertificate());
+			list.add(keyDescriptor);
+		}
+		return list;
+	}
+
+	private KeyDescriptor buildKeyDescriptor(UsageType usageType, java.security.cert.X509Certificate certificate) {
+		KeyDescriptor keyDescriptor = build(KeyDescriptor.DEFAULT_ELEMENT_NAME);
+		KeyInfo keyInfo = build(KeyInfo.DEFAULT_ELEMENT_NAME);
+		X509Certificate x509Certificate = build(X509Certificate.DEFAULT_ELEMENT_NAME);
+		X509Data x509Data = build(X509Data.DEFAULT_ELEMENT_NAME);
+
+		try {
+			x509Certificate.setValue(new String(Base64.getEncoder().encode(certificate.getEncoded())));
+		} catch (CertificateEncodingException e) {
+			throw new Saml2Exception("Cannot encode certificate " + certificate.toString());
+		}
+
+		x509Data.getX509Certificates().add(x509Certificate);
+		keyInfo.getX509Datas().add(x509Data);
+
+		keyDescriptor.setUse(usageType);
+		keyDescriptor.setKeyInfo(keyInfo);
+		return keyDescriptor;
+	}
+
+	private AssertionConsumerService buildAssertionConsumerService(RelyingPartyRegistration registration) {
+		AssertionConsumerService assertionConsumerService = build(AssertionConsumerService.DEFAULT_ELEMENT_NAME);
+		assertionConsumerService.setLocation(registration.getAssertionConsumerServiceLocation());
+		assertionConsumerService.setBinding(registration.getAssertionConsumerServiceBinding().getUrn());
+		assertionConsumerService.setIndex(1);
+		return assertionConsumerService;
+	}
+
+	@SuppressWarnings("unchecked")
+	private <T> T build(QName elementName) {
+		XMLObjectBuilder<?> builder = getBuilderFactory().getBuilder(elementName);
+		if (builder == null) {
+			throw new Saml2Exception("Unable to resolve Builder for " + elementName);
+		}
+		return (T) builder.buildObject(elementName);
+	}
+
+
+	private String serialize(EntityDescriptor entityDescriptor) {
+		try {
+			Element element = this.entityDescriptorMarshaller.marshall(entityDescriptor);
+			return SerializeSupport.prettyPrintXML(element);
+		} catch (Exception e) {
+			throw new Saml2Exception(e);
+		}
+	}
+}

+ 11 - 4
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataResolver.java → saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResolver.java

@@ -14,16 +14,23 @@
  * limitations under the License.
  */
 
-package org.springframework.security.saml2.provider.service.web;
+package org.springframework.security.saml2.provider.service.metadata;
 
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 
-import javax.servlet.http.HttpServletRequest;
-
 /**
+ * Resolves the SAML 2.0 Relying Party Metadata for a given {@link RelyingPartyRegistration}
+ *
  * @author Jakub Kubrynski
+ * @author Josh Cummings
  * @since 5.4
  */
 public interface Saml2MetadataResolver {
-	String resolveMetadata(HttpServletRequest request, RelyingPartyRegistration registration);
+	/**
+	 * Resolve the given relying party's metadata
+	 *
+	 * @param relyingPartyRegistration the relying party
+	 * @return the relying party's metadata
+	 */
+	String resolve(RelyingPartyRegistration relyingPartyRegistration);
 }

+ 0 - 28
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java

@@ -29,7 +29,6 @@ import java.util.function.Function;
 
 import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
-import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
 import org.springframework.util.Assert;
 
 /**
@@ -361,7 +360,6 @@ public class RelyingPartyRegistration {
 					.encryptionX509Credentials(c -> c.addAll(registration.getAssertingPartyDetails().getEncryptionX509Credentials()))
 					.singleSignOnServiceLocation(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation())
 					.singleSignOnServiceBinding(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding())
-					.nameIdFormat(registration.getAssertingPartyDetails().getNameIdFormat())
 				);
 	}
 
@@ -377,7 +375,6 @@ public class RelyingPartyRegistration {
 		private final Collection<Saml2X509Credential> verificationX509Credentials;
 		private final Collection<Saml2X509Credential> encryptionX509Credentials;
 		private final String singleSignOnServiceLocation;
-		private final String nameIdFormat;
 		private final Saml2MessageBinding singleSignOnServiceBinding;
 
 		private AssertingPartyDetails(
@@ -386,7 +383,6 @@ public class RelyingPartyRegistration {
 				Collection<Saml2X509Credential> verificationX509Credentials,
 				Collection<Saml2X509Credential> encryptionX509Credentials,
 				String singleSignOnServiceLocation,
-				String nameIdFormat,
 				Saml2MessageBinding singleSignOnServiceBinding) {
 
 			Assert.hasText(entityId, "entityId cannot be null or empty");
@@ -409,7 +405,6 @@ public class RelyingPartyRegistration {
 			this.verificationX509Credentials = verificationX509Credentials;
 			this.encryptionX509Credentials = encryptionX509Credentials;
 			this.singleSignOnServiceLocation = singleSignOnServiceLocation;
-			this.nameIdFormat = nameIdFormat;
 			this.singleSignOnServiceBinding = singleSignOnServiceBinding;
 		}
 
@@ -477,15 +472,6 @@ public class RelyingPartyRegistration {
 			return this.singleSignOnServiceLocation;
 		}
 
-		/**
-		 * Get the NameIDFormat setting, indicating which user property should be used as a NameID Format attribute
-		 *
-		 * @return the NameIdFormat value
-		 */
-		public String getNameIdFormat() {
-			return nameIdFormat;
-		}
-
 		/**
 		 * Get the
 		 * <a href="https://wiki.shibboleth.net/confluence/display/CONCEPT/MetadataForIdP#MetadataForIdP-SingleSign-OnServices">SingleSignOnService</a>
@@ -507,7 +493,6 @@ public class RelyingPartyRegistration {
 			private Collection<Saml2X509Credential> verificationX509Credentials = new HashSet<>();
 			private Collection<Saml2X509Credential> encryptionX509Credentials = new HashSet<>();
 			private String singleSignOnServiceLocation;
-			private String nameIdFormat = "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified";
 			private Saml2MessageBinding singleSignOnServiceBinding = Saml2MessageBinding.REDIRECT;
 
 			/**
@@ -577,18 +562,6 @@ public class RelyingPartyRegistration {
 				return this;
 			}
 
-			/**
-			 * Set the preference for name identifier returned by IdP.
-			 * See <a href="https://wiki.shibboleth.net/confluence/display/SHIB/NameIdentifierFormat">for possible values</a>
-			 *
-			 * @param nameIdFormat the name identifier
-			 * @return the {@link ProviderDetails.Builder} for further configuration
-			 */
-			public Builder nameIdFormat(String nameIdFormat) {
-				this.nameIdFormat = nameIdFormat;
-				return this;
-			}
-
 			/**
 			 * Set the
 			 * <a href="https://wiki.shibboleth.net/confluence/display/CONCEPT/MetadataForIdP#MetadataForIdP-SingleSign-OnServices">SingleSignOnService</a>
@@ -617,7 +590,6 @@ public class RelyingPartyRegistration {
 						this.verificationX509Credentials,
 						this.encryptionX509Credentials,
 						this.singleSignOnServiceLocation,
-						this.nameIdFormat,
 						this.singleSignOnServiceBinding
 				);
 			}

+ 0 - 161
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/OpenSamlMetadataResolver.java

@@ -1,161 +0,0 @@
-/*
- * Copyright 2002-2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.saml2.provider.service.web;
-
-import net.shibboleth.utilities.java.support.xml.SerializeSupport;
-import org.opensaml.core.xml.XMLObjectBuilder;
-import org.opensaml.core.xml.XMLObjectBuilderFactory;
-import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
-import org.opensaml.core.xml.io.Marshaller;
-import org.opensaml.saml.common.xml.SAMLConstants;
-import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
-import org.opensaml.saml.saml2.metadata.EntityDescriptor;
-import org.opensaml.saml.saml2.metadata.KeyDescriptor;
-import org.opensaml.saml.saml2.metadata.NameIDFormat;
-import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
-import org.opensaml.security.credential.UsageType;
-import org.opensaml.xmlsec.signature.KeyInfo;
-import org.opensaml.xmlsec.signature.X509Certificate;
-import org.opensaml.xmlsec.signature.X509Data;
-import org.springframework.security.saml2.Saml2Exception;
-import org.springframework.security.saml2.credentials.Saml2X509Credential;
-import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
-import org.springframework.security.saml2.provider.service.servlet.filter.Saml2ServletUtils;
-import org.w3c.dom.Element;
-
-import javax.servlet.http.HttpServletRequest;
-import javax.xml.namespace.QName;
-import java.security.cert.CertificateEncodingException;
-import java.util.ArrayList;
-import java.util.Base64;
-import java.util.List;
-
-/**
- * @author Jakub Kubrynski
- * @since 5.4
- */
-public class OpenSamlMetadataResolver implements Saml2MetadataResolver {
-
-	@Override
-	public String resolveMetadata(HttpServletRequest request, RelyingPartyRegistration registration) {
-
-		XMLObjectBuilderFactory builderFactory = XMLObjectProviderRegistrySupport.getBuilderFactory();
-
-		EntityDescriptor entityDescriptor = buildObject(builderFactory, EntityDescriptor.ELEMENT_QNAME);
-
-		entityDescriptor.setEntityID(
-				resolveTemplate(registration.getEntityId(), registration, request));
-
-		SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(registration, builderFactory, request);
-		entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor);
-
-		return serializeToXmlString(entityDescriptor);
-	}
-
-	private String serializeToXmlString(EntityDescriptor entityDescriptor) {
-		Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(entityDescriptor);
-		if (marshaller == null) {
-			throw new Saml2Exception("Unable to resolve Marshaller");
-		}
-		Element element;
-		try {
-			element = marshaller.marshall(entityDescriptor);
-		} catch (Exception e) {
-			throw new Saml2Exception(e);
-		}
-		return SerializeSupport.prettyPrintXML(element);
-	}
-
-	private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registration,
-			XMLObjectBuilderFactory builderFactory, HttpServletRequest request) {
-
-		SPSSODescriptor spSsoDescriptor = buildObject(builderFactory, SPSSODescriptor.DEFAULT_ELEMENT_NAME);
-		spSsoDescriptor.setAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned());
-		spSsoDescriptor.setWantAssertionsSigned(true);
-		spSsoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS);
-
-		NameIDFormat nameIdFormat = buildObject(builderFactory, NameIDFormat.DEFAULT_ELEMENT_NAME);
-		nameIdFormat.setFormat(registration.getAssertingPartyDetails().getNameIdFormat());
-		spSsoDescriptor.getNameIDFormats().add(nameIdFormat);
-
-		spSsoDescriptor.getAssertionConsumerServices().add(
-				buildAssertionConsumerService(registration, builderFactory, request));
-
-		spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(builderFactory,
-				registration.getSigningCredentials(), UsageType.SIGNING));
-		spSsoDescriptor.getKeyDescriptors().addAll(buildKeys(builderFactory,
-				registration.getEncryptionCredentials(), UsageType.ENCRYPTION));
-
-		return spSsoDescriptor;
-	}
-
-	private List<KeyDescriptor> buildKeys(XMLObjectBuilderFactory builderFactory,
-			List<Saml2X509Credential> credentials, UsageType usageType) {
-		List<KeyDescriptor> list = new ArrayList<>();
-		for (Saml2X509Credential credential : credentials) {
-			KeyDescriptor keyDescriptor = buildKeyDescriptor(builderFactory, usageType, credential.getCertificate());
-			list.add(keyDescriptor);
-		}
-		return list;
-	}
-
-	private KeyDescriptor buildKeyDescriptor(XMLObjectBuilderFactory builderFactory, UsageType usageType,
-			java.security.cert.X509Certificate certificate) {
-		KeyDescriptor keyDescriptor = buildObject(builderFactory, KeyDescriptor.DEFAULT_ELEMENT_NAME);
-		KeyInfo keyInfo = buildObject(builderFactory, KeyInfo.DEFAULT_ELEMENT_NAME);
-		X509Certificate x509Certificate = buildObject(builderFactory, X509Certificate.DEFAULT_ELEMENT_NAME);
-		X509Data x509Data = buildObject(builderFactory, X509Data.DEFAULT_ELEMENT_NAME);
-
-		try {
-			x509Certificate.setValue(new String(Base64.getEncoder().encode(certificate.getEncoded())));
-		} catch (CertificateEncodingException e) {
-			throw new Saml2Exception("Cannot encode certificate " + certificate.toString());
-		}
-
-		x509Data.getX509Certificates().add(x509Certificate);
-		keyInfo.getX509Datas().add(x509Data);
-
-		keyDescriptor.setUse(usageType);
-		keyDescriptor.setKeyInfo(keyInfo);
-		return keyDescriptor;
-	}
-
-	private AssertionConsumerService buildAssertionConsumerService(RelyingPartyRegistration registration,
-			XMLObjectBuilderFactory builderFactory, HttpServletRequest request) {
-		AssertionConsumerService assertionConsumerService = buildObject(builderFactory, AssertionConsumerService.DEFAULT_ELEMENT_NAME);
-
-		assertionConsumerService.setLocation(
-				resolveTemplate(registration.getAssertionConsumerServiceLocation(), registration, request));
-		assertionConsumerService.setBinding(registration.getAssertingPartyDetails().getSingleSignOnServiceBinding().getUrn());
-		assertionConsumerService.setIndex(1);
-		return assertionConsumerService;
-	}
-
-	@SuppressWarnings("unchecked")
-	private <T> T buildObject(XMLObjectBuilderFactory builderFactory, QName elementName) {
-		XMLObjectBuilder<?> builder = builderFactory.getBuilder(elementName);
-		if (builder == null) {
-			throw new Saml2Exception("Cannot build object - builder not defined for element " + elementName);
-		}
-		return (T) builder.buildObject(elementName);
-	}
-
-	private String resolveTemplate(String template, RelyingPartyRegistration registration, HttpServletRequest request) {
-		return Saml2ServletUtils.resolveUrlTemplate(template, Saml2ServletUtils.getApplicationUri(request), registration);
-	}
-
-}

+ 48 - 29
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java

@@ -16,66 +16,85 @@
 
 package org.springframework.security.saml2.provider.service.web;
 
+import java.io.IOException;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
+import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
-import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
 
-import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.io.IOException;
-
 /**
- * This {@code Servlet} returns a generated Service Provider Metadata XML
+ * A {@link javax.servlet.Filter} that returns the metadata for a Relying Party
  *
- * @since 5.4
  * @author Jakub Kubrynski
+ * @author Josh Cummings
+ * @since 5.4
  */
-public class Saml2MetadataFilter extends OncePerRequestFilter {
+public final class Saml2MetadataFilter extends OncePerRequestFilter {
 
-	private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
+	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter;
 	private final Saml2MetadataResolver saml2MetadataResolver;
 
-	private RequestMatcher redirectMatcher = new AntPathRequestMatcher("/saml2/service-provider-metadata/{registrationId}");
+	private RequestMatcher requestMatcher = new AntPathRequestMatcher(
+			"/saml2/service-provider-metadata/{registrationId}");
+
+	public Saml2MetadataFilter(
+			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationConverter,
+			Saml2MetadataResolver saml2MetadataResolver) {
 
-	public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository, Saml2MetadataResolver saml2MetadataResolver) {
-		this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
+		this.relyingPartyRegistrationConverter = relyingPartyRegistrationConverter;
 		this.saml2MetadataResolver = saml2MetadataResolver;
 	}
 
 	@Override
-	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
+	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
+			throws ServletException, IOException {
 
-		RequestMatcher.MatchResult matcher = this.redirectMatcher.matcher(request);
+		RequestMatcher.MatchResult matcher = this.requestMatcher.matcher(request);
 		if (!matcher.isMatch()) {
-			filterChain.doFilter(request, response);
+			chain.doFilter(request, response);
 			return;
 		}
 
-		String registrationId = matcher.getVariables().get("registrationId");
-
-		RelyingPartyRegistration registration = relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
-
-		if (registration == null) {
+		RelyingPartyRegistration relyingPartyRegistration =
+				this.relyingPartyRegistrationConverter.convert(request);
+		if (relyingPartyRegistration == null) {
 			response.setStatus(HttpServletResponse.SC_UNAUTHORIZED);
 			return;
 		}
 
-		String xml = saml2MetadataResolver.resolveMetadata(request, registration);
-
-		writeMetadataToResponse(response, registrationId, xml);
+		String metadata = this.saml2MetadataResolver.resolve(relyingPartyRegistration);
+		String registrationId = relyingPartyRegistration.getRegistrationId();
+		writeMetadataToResponse(response, registrationId, metadata);
 	}
 
-	private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String xml) throws IOException {
+	private void writeMetadataToResponse(HttpServletResponse response, String registrationId, String metadata)
+			throws IOException {
+
 		response.setContentType(MediaType.APPLICATION_XML_VALUE);
-		response.setHeader(HttpHeaders.CONTENT_DISPOSITION, "attachment; filename=\"saml-" + registrationId + "-metadata.xml\"");
-		response.setContentLength(xml.length());
-		response.getWriter().write(xml);
+		response.setHeader(HttpHeaders.CONTENT_DISPOSITION,
+				"attachment; filename=\"saml-" + registrationId + "-metadata.xml\"");
+		response.setContentLength(metadata.length());
+		response.getWriter().write(metadata);
 	}
 
+	/**
+	 * Set the {@link RequestMatcher} that determines whether this filter should
+	 * handle the incoming {@link HttpServletRequest}
+	 *
+	 * @param requestMatcher
+	 */
+	public void setRequestMatcher(RequestMatcher requestMatcher) {
+		Assert.notNull(requestMatcher, "requestMatcher cannot be null");
+		this.requestMatcher = requestMatcher;
+	}
 }

+ 80 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java

@@ -0,0 +1,80 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.metadata;
+
+import org.junit.Test;
+
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
+import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
+import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.full;
+import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
+
+/**
+ * Tests for {@link OpenSamlMetadataResolver}
+ */
+public class OpenSamlMetadataResolverTests {
+
+	@Test
+	public void resolveWhenRelyingPartyThenMetadataMatches() {
+		// given
+		RelyingPartyRegistration relyingPartyRegistration = full()
+				.assertionConsumerServiceBinding(REDIRECT)
+				.build();
+		OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
+
+		// when
+		String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
+
+		// then
+		assertThat(metadata)
+				.contains("<EntityDescriptor")
+				.contains("entityID=\"rp-entity-id\"")
+				.contains("WantAssertionsSigned=\"true\"")
+				.contains("<md:KeyDescriptor use=\"signing\">")
+				.contains("<md:KeyDescriptor use=\"encryption\">")
+				.contains("<ds:X509Certificate>MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh")
+				.contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"")
+				.contains("Location=\"https://rp.example.org/acs\" index=\"1\"");
+	}
+
+	@Test
+	public void resolveWhenRelyingPartyNoCredentialsThenMetadataMatches() {
+		// given
+		RelyingPartyRegistration relyingPartyRegistration = noCredentials()
+				.assertingPartyDetails(party -> party
+					.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential()))
+				)
+				.build();
+		OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
+
+		// when
+		String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
+
+		// then
+		assertThat(metadata)
+				.contains("<EntityDescriptor")
+				.contains("entityID=\"rp-entity-id\"")
+				.contains("WantAssertionsSigned=\"true\"")
+				.doesNotContain("<md:KeyDescriptor use=\"signing\">")
+				.doesNotContain("<md:KeyDescriptor use=\"encryption\">")
+				.contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\"")
+				.contains("Location=\"https://rp.example.org/acs\" index=\"1\"");
+	}
+}

+ 1 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java

@@ -18,7 +18,7 @@ package org.springframework.security.saml2.provider.service.registration;
 
 import org.junit.Test;
 
-import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
+import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;

+ 13 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java

@@ -16,8 +16,9 @@
 
 package org.springframework.security.saml2.provider.service.registration;
 
+import org.springframework.security.saml2.core.TestSaml2X509Credentials;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
-import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
+import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
@@ -57,4 +58,15 @@ public class TestRelyingPartyRegistrations {
 						.singleSignOnServiceLocation("https://ap.example.org/sso")
 				);
 	}
+
+	public static RelyingPartyRegistration.Builder full() {
+		return noCredentials()
+				.signingX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartySigningCredential()))
+				.decryptionX509Credentials(c -> c.add(TestSaml2X509Credentials.relyingPartyDecryptingCredential()))
+				.assertingPartyDetails(party -> party
+					.verificationX509Credentials(c -> c.add(
+							TestSaml2X509Credentials.relyingPartyVerifyingCredential())
+					)
+				);
+	}
 }

+ 0 - 67
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/OpenSamlMetadataResolverTest.java

@@ -1,67 +0,0 @@
-/*
- * Copyright 2002-2020 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.saml2.provider.service.web;
-
-import org.junit.Before;
-import org.junit.Test;
-import org.opensaml.core.config.InitializationException;
-import org.opensaml.core.config.InitializationService;
-import org.opensaml.saml.saml2.core.NameIDType;
-import org.springframework.mock.web.MockHttpServletRequest;
-import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
-import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
-
-import javax.servlet.http.HttpServletRequest;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
-
-public class OpenSamlMetadataResolverTest {
-
-	@Before
-	public void setUp() throws InitializationException {
-		InitializationService.initialize();
-	}
-
-	@Test
-	public void shouldGenerateMetadata() {
-		// given
-		OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
-		RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration()
-				.assertingPartyDetails(p -> p.singleSignOnServiceBinding(REDIRECT))
-				.assertingPartyDetails(p -> p.wantAuthnRequestsSigned(true))
-				.assertingPartyDetails(p -> p.nameIdFormat(NameIDType.EMAIL))
-				.build();
-		HttpServletRequest servletRequestMock = new MockHttpServletRequest();
-
-		// when
-		String metadataXml = openSamlMetadataResolver.resolveMetadata(servletRequestMock, relyingPartyRegistration);
-
-		// then
-		assertThat(metadataXml)
-				.contains("<EntityDescriptor")
-				.contains("entityID=\"http://localhost/saml2/service-provider-metadata/simplesamlphp\"")
-				.contains("AuthnRequestsSigned=\"true\"")
-				.contains("WantAssertionsSigned=\"true\"")
-				.contains("<md:KeyDescriptor use=\"signing\">")
-				.contains("<ds:X509Certificate>MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh")
-				.contains("<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>")
-				.contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"")
-				.contains("Location=\"http://localhost/login/saml2/sso/simplesamlphp\" index=\"1\"");
-	}
-
-}

+ 64 - 34
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTest.java

@@ -20,95 +20,125 @@ import org.junit.Before;
 import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
-import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 
 import javax.servlet.FilterChain;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
+import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
+import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
 
+/**
+ * Tests for {@link Saml2MetadataFilter}
+ */
 public class Saml2MetadataFilterTest {
 
 	RelyingPartyRegistrationRepository repository;
-	Saml2MetadataResolver saml2MetadataResolver;
+	Saml2MetadataResolver resolver;
 	Saml2MetadataFilter filter;
 	MockHttpServletRequest request;
 	MockHttpServletResponse response;
-	FilterChain filterChain;
+	FilterChain chain;
 
 	@Before
 	public void setup() {
-		repository = mock(RelyingPartyRegistrationRepository.class);
-		saml2MetadataResolver = mock(Saml2MetadataResolver.class);
-		filter = new Saml2MetadataFilter(repository, saml2MetadataResolver);
-		request = new MockHttpServletRequest();
-		response = new MockHttpServletResponse();
-		filterChain = mock(FilterChain.class);
+		this.repository = mock(RelyingPartyRegistrationRepository.class);
+		this.resolver = mock(Saml2MetadataResolver.class);
+		this.filter = new Saml2MetadataFilter(
+				new DefaultRelyingPartyRegistrationResolver(this.repository), this.resolver);
+		this.request = new MockHttpServletRequest();
+		this.response = new MockHttpServletResponse();
+		this.chain = mock(FilterChain.class);
 	}
 
 	@Test
-	public void shouldReturnValueWhenMatcherSucceed() throws Exception {
+	public void doFilterWhenMatcherSucceedsThenResolverInvoked() throws Exception {
 		// given
-		request.setPathInfo("/saml2/service-provider-metadata/registration-id");
+		this.request.setPathInfo("/saml2/service-provider-metadata/registration-id");
 
 		// when
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(this.request, this.response, this.chain);
 
 		// then
-		verifyNoInteractions(filterChain);
+		verifyNoInteractions(this.chain);
+		verify(this.repository).findByRegistrationId("registration-id");
 	}
 
 	@Test
-	public void shouldProcessFilterChainIfMatcherFails() throws Exception {
+	public void doFilterWhenMatcherFailsThenProcessesFilterChain() throws Exception {
 		// given
-		request.setPathInfo("/saml2/authenticate/registration-id");
+		this.request.setPathInfo("/saml2/authenticate/registration-id");
 
 		// when
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(this.request, this.response, this.chain);
 
 		// then
-		verify(filterChain).doFilter(request, response);
+		verify(this.chain).doFilter(this.request, this.response);
 	}
 
 	@Test
-	public void shouldReturn401IfNoRegistrationIsFound() throws Exception {
+	public void doFilterWhenNoRelyingPartyRegistrationThenUnauthorized() throws Exception {
 		// given
-		request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration");
-		when(repository.findByRegistrationId("invalidRegistration")).thenReturn(null);
+		this.request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration");
+		when(this.repository.findByRegistrationId("invalidRegistration")).thenReturn(null);
 
 		// when
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(this.request, this.response, this.chain);
 
 		// then
-		verifyNoInteractions(filterChain);
-		assertThat(response.getStatus()).isEqualTo(401);
+		verifyNoInteractions(this.chain);
+		assertThat(this.response.getStatus()).isEqualTo(401);
 	}
 
 	@Test
-	public void shouldInvokeMetadataGenerationIfRegistrationIsFound() throws Exception {
+	public void doFilterWhenRelyingPartyRegistrationFoundThenInvokesMetadataResolver() throws Exception {
 		// given
-		request.setPathInfo("/saml2/service-provider-metadata/validRegistration");
-		RelyingPartyRegistration validRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
-		when(repository.findByRegistrationId("validRegistration")).thenReturn(validRegistration);
+		this.request.setPathInfo("/saml2/service-provider-metadata/validRegistration");
+		RelyingPartyRegistration validRegistration = noCredentials()
+				.assertingPartyDetails(party -> party
+						.verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())))
+				.build();
 
 		String generatedMetadata = "<xml>test</xml>";
-		when(saml2MetadataResolver.resolveMetadata(request, validRegistration)).thenReturn(generatedMetadata);
+		when(this.resolver.resolve(validRegistration)).thenReturn(generatedMetadata);
 
-		filter = new Saml2MetadataFilter(repository, saml2MetadataResolver);
+		this.filter = new Saml2MetadataFilter(request -> validRegistration, this.resolver);
 
 		// when
-		filter.doFilter(request, response, filterChain);
+		this.filter.doFilter(this.request, this.response, this.chain);
 
 		// then
-		verifyNoInteractions(filterChain);
-		assertThat(response.getStatus()).isEqualTo(200);
-		assertThat(response.getContentAsString()).isEqualTo(generatedMetadata);
-		verify(saml2MetadataResolver).resolveMetadata(request, validRegistration);
+		verifyNoInteractions(this.chain);
+		assertThat(this.response.getStatus()).isEqualTo(200);
+		assertThat(this.response.getContentAsString()).isEqualTo(generatedMetadata);
+		verify(this.resolver).resolve(validRegistration);
 	}
 
+	@Test
+	public void doFilterWhenCustomRequestMatcherThenUses() throws Exception {
+		// given
+		this.request.setPathInfo("/path");
+		this.filter.setRequestMatcher(new AntPathRequestMatcher("/path"));
+
+		// when
+		this.filter.doFilter(this.request, this.response, this.chain);
+
+		// then
+		verifyNoInteractions(this.chain);
+		verify(this.repository).findByRegistrationId("path");
+	}
+
+	@Test
+	public void setRequestMatcherWhenNullThenIllegalArgument() {
+		assertThatCode(() -> this.filter.setRequestMatcher(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
 }

+ 1 - 1
samples/javaconfig/saml2login/src/main/java/org/springframework/security/samples/config/SecurityConfig.java

@@ -29,7 +29,7 @@ import org.springframework.security.converter.RsaKeyConverters;
 import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
-import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
+import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 
 import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.DECRYPTION;
 import static org.springframework.security.saml2.core.Saml2X509Credential.Saml2X509CredentialType.SIGNING;

+ 1 - 1
samples/javaconfig/saml2login/src/test/java/org/springframework/security/samples/config/SecurityConfigTests.java

@@ -17,7 +17,7 @@ package org.springframework.security.samples.config;
 
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationContext;
-import org.springframework.security.saml2.provider.service.web.Saml2WebSsoAuthenticationFilter;
+import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;