Selaa lähdekoodia

Add OpenSamlAssertingPartyMetadataRepository

Closes gh-12116
Closes gh-15395
Josh Cummings 1 vuosi sitten
vanhempi
commit
e6dfb63bdf

+ 121 - 2
docs/modules/ROOT/pages/servlet/saml2/metadata.adoc

@@ -27,11 +27,130 @@ Kotlin::
 [source,kotlin,role="secondary"]
 ----
 val details: OpenSamlAssertingPartyDetails =
-        registration.getAssertingPartyDetails() as OpenSamlAssertingPartyDetails;
-val openSamlEntityDescriptor: EntityDescriptor = details.getEntityDescriptor();
+        registration.getAssertingPartyDetails() as OpenSamlAssertingPartyDetails
+val openSamlEntityDescriptor: EntityDescriptor = details.getEntityDescriptor()
 ----
 ======
 
+=== Using `AssertingPartyMetadataRepository`
+
+You can also be more targeted than `RelyingPartyRegistrations` by using `AssertingPartyMetadataRepository`, an interface that allows for only retrieving the asserting party metadata.
+
+This allows three valuable features:
+
+* Implementations can refresh asserting party metadata in an expiry-aware fashion
+* Implementations of `RelyingPartyRegistrationRepository` can more easily articulate a relationship between a relying party and its one or many corresponding asserting parties
+* Implementations can verify metadata signatures
+
+For example, `OpenSamlAssertingPartyMetadataRepository` uses OpenSAML's `MetadataResolver`, and API whose implementations regularly refresh the underlying metadata in an expiry-aware fashion.
+
+This means that you can now create a refreshable `RelyingPartyRegistrationRepository` in just a few lines of code:
+
+[tabs]
+======
+Java::
++
+[source,java,role="primary"]
+----
+@Component
+public class RefreshableRelyingPartyRegistrationRepository
+        implements IterableRelyingPartyRegistrationRepository {
+
+	private final AssertingPartyMetadataRepository metadata =
+            OpenSamlAssertingPartyMetadataRepository
+                .fromTrustedMetadataLocation("https://idp.example.org/metadata").build();
+
+	@Override
+    public RelyingPartyRegistration findByRegistrationId(String registrationId) {
+		AssertingPartyMetadata metadata = this.metadata.findByEntityId(registrationId);
+        if (metadata == null) {
+            return null;
+        }
+		return applyRelyingParty(metadata);
+    }
+
+	@Override
+    public Iterator<RelyingPartyRegistration> iterator() {
+		return StreamSupport.stream(this.metadata.spliterator(), false)
+            .map(this::applyRelyingParty).iterator();
+    }
+
+	private RelyingPartyRegistration applyRelyingParty(AssertingPartyMetadata metadata) {
+		AssertingPartyDetails details = (AssertingPartyDetails) metadata;
+		return RelyingPartyRegistration.withAssertingPartyDetails(details)
+            // apply any relying party configuration
+            .build();
+	}
+
+}
+----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Component
+class RefreshableRelyingPartyRegistrationRepository : IterableRelyingPartyRegistrationRepository {
+
+    private val metadata: AssertingPartyMetadataRepository =
+        OpenSamlAssertingPartyMetadataRepository.fromTrustedMetadataLocation(
+            "https://idp.example.org/metadata").build()
+
+    fun findByRegistrationId(registrationId:String?): RelyingPartyRegistration {
+        val metadata = this.metadata.findByEntityId(registrationId)
+        if (metadata == null) {
+            return null
+        }
+        return applyRelyingParty(metadata)
+    }
+
+    fun iterator(): Iterator<RelyingPartyRegistration> {
+        return StreamSupport.stream(this.metadata.spliterator(), false)
+            .map(this::applyRelyingParty).iterator()
+    }
+
+    private fun applyRelyingParty(metadata: AssertingPartyMetadata): RelyingPartyRegistration {
+        val details: AssertingPartyDetails = metadata as AssertingPartyDetails
+        return RelyingPartyRegistration.withAssertingPartyDetails(details)
+            // apply any relying party configuration
+            .build()
+    }
+ }
+----
+======
+
+[TIP]
+`OpenSamlAssertingPartyMetadataRepository` also ships with a constructor so you can provide a custom `MetadataResolver`. Since the underlying `MetadataResolver` is doing the expirying and refreshing, if you use the constructor directly, you will only get these features by providing an implementation that does so.
+
+=== Verifying Metadata Signatures
+
+You can also verify metadata signatures using `OpenSamlAssertingPartyMetadataRepository` by providing the appropriate set of ``Saml2X509Credential``s as follows:
+
+[tabs]
+======
+Java::
++
+[source,java,role="primary"]
+----
+OpenSamlAssertingPartyMetadataRepository.withMetadataLocation("https://idp.example.org/metadata")
+    .verificationCredentials((c) -> c.add(myVerificationCredential))
+    .build();
+----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+OpenSamlAssertingPartyMetadataRepository.withMetadataLocation("https://idp.example.org/metadata")
+    .verificationCredentials({ c : Collection<Saml2X509Credential> ->
+        c.add(myVerificationCredential) })
+    .build()
+----
+======
+
+[NOTE]
+If no credentials are provided, the component will not perform signature validation.
+
 [[publishing-relying-party-metadata]]
 == Producing `<saml2:SPSSODescriptor>` Metadata
 

+ 12 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlSigningUtils.java

@@ -80,7 +80,13 @@ final class OpenSamlSigningUtils {
 	}
 
 	static <O extends SignableXMLObject> O sign(O object, RelyingPartyRegistration relyingPartyRegistration) {
-		SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
+		List<String> algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms();
+		List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
+		return sign(object, algorithms, credentials);
+	}
+
+	static <O extends SignableXMLObject> O sign(O object, List<String> algorithms, List<Credential> credentials) {
+		SignatureSigningParameters parameters = resolveSigningParameters(algorithms, credentials);
 		try {
 			SignatureSupport.signObject(object, parameters);
 			return object;
@@ -98,6 +104,11 @@ final class OpenSamlSigningUtils {
 			RelyingPartyRegistration relyingPartyRegistration) {
 		List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
 		List<String> algorithms = relyingPartyRegistration.getAssertingPartyDetails().getSigningAlgorithms();
+		return resolveSigningParameters(algorithms, credentials);
+	}
+
+	private static SignatureSigningParameters resolveSigningParameters(List<String> algorithms,
+			List<Credential> credentials) {
 		List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
 		String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
 		SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();

+ 383 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/OpenSamlAssertingPartyMetadataRepository.java

@@ -0,0 +1,383 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.registration;
+
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URI;
+import java.net.URL;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Supplier;
+
+import javax.annotation.Nonnull;
+
+import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
+import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
+import net.shibboleth.utilities.java.support.resolver.ResolverException;
+import org.opensaml.core.criterion.EntityIdCriterion;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.saml.criterion.EntityRoleCriterion;
+import org.opensaml.saml.metadata.IterableMetadataSource;
+import org.opensaml.saml.metadata.resolver.MetadataResolver;
+import org.opensaml.saml.metadata.resolver.filter.impl.SignatureValidationFilter;
+import org.opensaml.saml.metadata.resolver.impl.AbstractBatchMetadataResolver;
+import org.opensaml.saml.metadata.resolver.impl.ResourceBackedMetadataResolver;
+import org.opensaml.saml.metadata.resolver.index.MetadataIndex;
+import org.opensaml.saml.metadata.resolver.index.impl.RoleMetadataIndex;
+import org.opensaml.saml.saml2.metadata.EntityDescriptor;
+import org.opensaml.saml.saml2.metadata.IDPSSODescriptor;
+import org.opensaml.security.credential.Credential;
+import org.opensaml.security.credential.impl.CollectionCredentialResolver;
+import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
+import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
+import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
+
+import org.springframework.core.io.DefaultResourceLoader;
+import org.springframework.core.io.Resource;
+import org.springframework.core.io.ResourceLoader;
+import org.springframework.lang.NonNull;
+import org.springframework.lang.Nullable;
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.OpenSamlInitializationService;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+import org.springframework.util.Assert;
+
+/**
+ * An implementation of {@link AssertingPartyMetadataRepository} that uses a
+ * {@link MetadataResolver} to retrieve {@link AssertingPartyMetadata} instances.
+ *
+ * <p>
+ * The {@link MetadataResolver} constructed in {@link #withTrustedMetadataLocation}
+ * provides expiry-aware refreshing.
+ *
+ * @author Josh Cummings
+ * @since 6.4
+ * @see AssertingPartyMetadataRepository
+ * @see RelyingPartyRegistrations
+ */
+public final class OpenSamlAssertingPartyMetadataRepository implements AssertingPartyMetadataRepository {
+
+	static {
+		OpenSamlInitializationService.initialize();
+	}
+
+	private final MetadataResolver metadataResolver;
+
+	private final Supplier<Iterator<EntityDescriptor>> descriptors;
+
+	/**
+	 * Construct an {@link OpenSamlAssertingPartyMetadataRepository} using the provided
+	 * {@link MetadataResolver}.
+	 *
+	 * <p>
+	 * The {@link MetadataResolver} should either be of type
+	 * {@link IterableMetadataSource} or it should have a {@link RoleMetadataIndex}
+	 * configured.
+	 * @param metadataResolver the {@link MetadataResolver} to use
+	 */
+	public OpenSamlAssertingPartyMetadataRepository(MetadataResolver metadataResolver) {
+		Assert.notNull(metadataResolver, "metadataResolver cannot be null");
+		if (isRoleIndexed(metadataResolver)) {
+			this.descriptors = this::allIndexedEntities;
+		}
+		else if (metadataResolver instanceof IterableMetadataSource source) {
+			this.descriptors = source::iterator;
+		}
+		else {
+			throw new IllegalArgumentException(
+					"metadataResolver must be an IterableMetadataSource or have a RoleMetadataIndex");
+		}
+		this.metadataResolver = metadataResolver;
+	}
+
+	private static boolean isRoleIndexed(MetadataResolver resolver) {
+		if (!(resolver instanceof AbstractBatchMetadataResolver batch)) {
+			return false;
+		}
+		for (MetadataIndex index : batch.getIndexes()) {
+			if (index instanceof RoleMetadataIndex) {
+				return true;
+			}
+		}
+		return false;
+	}
+
+	private Iterator<EntityDescriptor> allIndexedEntities() {
+		CriteriaSet all = new CriteriaSet(new EntityRoleCriterion(IDPSSODescriptor.DEFAULT_ELEMENT_NAME));
+		try {
+			return this.metadataResolver.resolve(all).iterator();
+		}
+		catch (ResolverException ex) {
+			throw new Saml2Exception(ex);
+		}
+	}
+
+	@Override
+	@NonNull
+	public Iterator<AssertingPartyMetadata> iterator() {
+		Iterator<EntityDescriptor> descriptors = this.descriptors.get();
+		return new Iterator<>() {
+			@Override
+			public boolean hasNext() {
+				return descriptors.hasNext();
+			}
+
+			@Override
+			public AssertingPartyMetadata next() {
+				return OpenSamlAssertingPartyDetails.withEntityDescriptor(descriptors.next()).build();
+			}
+		};
+	}
+
+	@Nullable
+	@Override
+	public AssertingPartyMetadata findByEntityId(String entityId) {
+		CriteriaSet byEntityId = new CriteriaSet(new EntityIdCriterion(entityId));
+		EntityDescriptor descriptor = resolveSingle(byEntityId);
+		if (descriptor == null) {
+			return null;
+		}
+		return OpenSamlAssertingPartyDetails.withEntityDescriptor(descriptor).build();
+	}
+
+	private EntityDescriptor resolveSingle(CriteriaSet criteria) {
+		try {
+			return this.metadataResolver.resolveSingle(criteria);
+		}
+		catch (ResolverException ex) {
+			throw new Saml2Exception(ex);
+		}
+	}
+
+	/**
+	 * Use this trusted {@code metadataLocation} to retrieve refreshable, expiry-aware
+	 * SAML 2.0 Asserting Party (IDP) metadata.
+	 *
+	 * <p>
+	 * Valid locations can be classpath- or file-based or they can be HTTPS endpoints.
+	 * Some valid endpoints might include:
+	 *
+	 * <pre>
+	 *   metadataLocation = "classpath:asserting-party-metadata.xml";
+	 *   metadataLocation = "file:asserting-party-metadata.xml";
+	 *   metadataLocation = "https://ap.example.org/metadata";
+	 * </pre>
+	 *
+	 * <p>
+	 * Resolution of location is attempted immediately. To defer, wrap in
+	 * {@link CachingRelyingPartyRegistrationRepository}.
+	 * @param metadataLocation the classpath- or file-based locations or HTTPS endpoints
+	 * of the asserting party metadata file
+	 * @return the {@link MetadataLocationRepositoryBuilder} for further configuration
+	 */
+	public static MetadataLocationRepositoryBuilder withTrustedMetadataLocation(String metadataLocation) {
+		return new MetadataLocationRepositoryBuilder(metadataLocation, true);
+	}
+
+	/**
+	 * Use this {@code metadataLocation} to retrieve refreshable, expiry-aware SAML 2.0
+	 * Asserting Party (IDP) metadata. Verification credentials are required.
+	 *
+	 * <p>
+	 * Valid locations can be classpath- or file-based or they can be remote endpoints.
+	 * Some valid endpoints might include:
+	 *
+	 * <pre>
+	 *   metadataLocation = "classpath:asserting-party-metadata.xml";
+	 *   metadataLocation = "file:asserting-party-metadata.xml";
+	 *   metadataLocation = "https://ap.example.org/metadata";
+	 * </pre>
+	 *
+	 * <p>
+	 * Resolution of location is attempted immediately. To defer, wrap in
+	 * {@link CachingRelyingPartyRegistrationRepository}.
+	 * @param metadataLocation the classpath- or file-based locations or remote endpoints
+	 * of the asserting party metadata file
+	 * @return the {@link MetadataLocationRepositoryBuilder} for further configuration
+	 */
+	public static MetadataLocationRepositoryBuilder withMetadataLocation(String metadataLocation) {
+		return new MetadataLocationRepositoryBuilder(metadataLocation, false);
+	}
+
+	/**
+	 * A builder class for configuring {@link OpenSamlAssertingPartyMetadataRepository}
+	 * for a specific metadata location.
+	 *
+	 * @author Josh Cummings
+	 */
+	public static final class MetadataLocationRepositoryBuilder {
+
+		private final String metadataLocation;
+
+		private final boolean requireVerificationCredentials;
+
+		private final Collection<Credential> verificationCredentials = new ArrayList<>();
+
+		private ResourceLoader resourceLoader = new DefaultResourceLoader();
+
+		private MetadataLocationRepositoryBuilder(String metadataLocation, boolean trusted) {
+			this.metadataLocation = metadataLocation;
+			this.requireVerificationCredentials = !trusted;
+		}
+
+		/**
+		 * Apply this {@link Consumer} to the list of {@link Saml2X509Credential}s to use
+		 * for verifying metadata signatures.
+		 *
+		 * <p>
+		 * If no credentials are supplied, no signature verification is performed.
+		 * @param credentials a {@link Consumer} of the {@link Collection} of
+		 * {@link Saml2X509Credential}s
+		 * @return the {@link MetadataLocationRepositoryBuilder} for further configuration
+		 */
+		public MetadataLocationRepositoryBuilder verificationCredentials(Consumer<Collection<Credential>> credentials) {
+			credentials.accept(this.verificationCredentials);
+			return this;
+		}
+
+		/**
+		 * Use this {@link ResourceLoader} for resolving the {@code metadataLocation}
+		 * @param resourceLoader the {@link ResourceLoader} to use
+		 * @return the {@link MetadataLocationRepositoryBuilder} for further configuration
+		 */
+		public MetadataLocationRepositoryBuilder resourceLoader(ResourceLoader resourceLoader) {
+			this.resourceLoader = resourceLoader;
+			return this;
+		}
+
+		/**
+		 * Build the {@link OpenSamlAssertingPartyMetadataRepository}
+		 * @return the {@link OpenSamlAssertingPartyMetadataRepository}
+		 */
+		public OpenSamlAssertingPartyMetadataRepository build() {
+			ResourceBackedMetadataResolver metadataResolver = metadataResolver();
+			if (!this.verificationCredentials.isEmpty()) {
+				SignatureTrustEngine engine = new ExplicitKeySignatureTrustEngine(
+						new CollectionCredentialResolver(this.verificationCredentials),
+						DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver());
+				SignatureValidationFilter filter = new SignatureValidationFilter(engine);
+				filter.setRequireSignedRoot(true);
+				metadataResolver.setMetadataFilter(filter);
+				return new OpenSamlAssertingPartyMetadataRepository(initialize(metadataResolver));
+			}
+			Assert.isTrue(!this.requireVerificationCredentials, "Verification credentials are required");
+			return new OpenSamlAssertingPartyMetadataRepository(initialize(metadataResolver));
+		}
+
+		private ResourceBackedMetadataResolver metadataResolver() {
+			Resource resource = this.resourceLoader.getResource(this.metadataLocation);
+			try {
+				return new ResourceBackedMetadataResolver(new SpringResource(resource));
+			}
+			catch (IOException ex) {
+				throw new Saml2Exception(ex);
+			}
+		}
+
+		private MetadataResolver initialize(ResourceBackedMetadataResolver metadataResolver) {
+			try {
+				metadataResolver.setId(this.getClass().getName() + ".metadataResolver");
+				metadataResolver.setParserPool(XMLObjectProviderRegistrySupport.getParserPool());
+				metadataResolver.setIndexes(Set.of(new RoleMetadataIndex()));
+				metadataResolver.initialize();
+				return metadataResolver;
+			}
+			catch (ComponentInitializationException ex) {
+				throw new Saml2Exception(ex);
+			}
+		}
+
+		private static final class SpringResource implements net.shibboleth.utilities.java.support.resource.Resource {
+
+			private final Resource resource;
+
+			SpringResource(Resource resource) {
+				this.resource = resource;
+			}
+
+			@Override
+			public boolean exists() {
+				return this.resource.exists();
+			}
+
+			@Override
+			public boolean isReadable() {
+				return this.resource.isReadable();
+			}
+
+			@Override
+			public boolean isOpen() {
+				return this.resource.isOpen();
+			}
+
+			@Override
+			public URL getURL() throws IOException {
+				return this.resource.getURL();
+			}
+
+			@Override
+			public URI getURI() throws IOException {
+				return this.resource.getURI();
+			}
+
+			@Override
+			public File getFile() throws IOException {
+				return this.resource.getFile();
+			}
+
+			@Nonnull
+			@Override
+			public InputStream getInputStream() throws IOException {
+				return this.resource.getInputStream();
+			}
+
+			@Override
+			public long contentLength() throws IOException {
+				return this.resource.contentLength();
+			}
+
+			@Override
+			public long lastModified() throws IOException {
+				return this.resource.lastModified();
+			}
+
+			@Override
+			public net.shibboleth.utilities.java.support.resource.Resource createRelativeResource(String relativePath)
+					throws IOException {
+				return new SpringResource(this.resource.createRelative(relativePath));
+			}
+
+			@Override
+			public String getFilename() {
+				return this.resource.getFilename();
+			}
+
+			@Override
+			public String getDescription() {
+				return this.resource.getDescription();
+			}
+
+		}
+
+	}
+
+}

+ 47 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

@@ -45,6 +45,7 @@ import org.opensaml.core.xml.schema.impl.XSStringBuilder;
 import org.opensaml.core.xml.schema.impl.XSURIBuilder;
 import org.opensaml.saml.common.SAMLVersion;
 import org.opensaml.saml.common.SignableSAMLObject;
+import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.Attribute;
 import org.opensaml.saml.saml2.core.AttributeStatement;
@@ -74,6 +75,14 @@ import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
 import org.opensaml.saml.saml2.core.impl.StatusBuilder;
 import org.opensaml.saml.saml2.core.impl.StatusCodeBuilder;
 import org.opensaml.saml.saml2.encryption.Encrypter;
+import org.opensaml.saml.saml2.metadata.EntityDescriptor;
+import org.opensaml.saml.saml2.metadata.IDPSSODescriptor;
+import org.opensaml.saml.saml2.metadata.KeyDescriptor;
+import org.opensaml.saml.saml2.metadata.SingleSignOnService;
+import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorBuilder;
+import org.opensaml.saml.saml2.metadata.impl.IDPSSODescriptorBuilder;
+import org.opensaml.saml.saml2.metadata.impl.KeyDescriptorBuilder;
+import org.opensaml.saml.saml2.metadata.impl.SingleSignOnServiceBuilder;
 import org.opensaml.security.SecurityException;
 import org.opensaml.security.credential.BasicCredential;
 import org.opensaml.security.credential.Credential;
@@ -83,6 +92,9 @@ import org.opensaml.xmlsec.SignatureSigningParameters;
 import org.opensaml.xmlsec.encryption.support.DataEncryptionParameters;
 import org.opensaml.xmlsec.encryption.support.EncryptionException;
 import org.opensaml.xmlsec.encryption.support.KeyEncryptionParameters;
+import org.opensaml.xmlsec.keyinfo.KeyInfoSupport;
+import org.opensaml.xmlsec.signature.KeyInfo;
+import org.opensaml.xmlsec.signature.impl.KeyInfoBuilder;
 import org.opensaml.xmlsec.signature.support.SignatureConstants;
 import org.opensaml.xmlsec.signature.support.SignatureException;
 import org.opensaml.xmlsec.signature.support.SignatureSupport;
@@ -92,6 +104,7 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 
 public final class TestOpenSamlObjects {
 
@@ -221,7 +234,7 @@ public final class TestOpenSamlObjects {
 		return logoutRequest;
 	}
 
-	static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
+	public static Credential getSigningCredential(Saml2X509Credential credential, String entityId) {
 		BasicCredential cred = getBasicCredential(credential);
 		cred.setEntityId(entityId);
 		cred.setUsageType(UsageType.SIGNING);
@@ -466,6 +479,39 @@ public final class TestOpenSamlObjects {
 		return logoutRequest;
 	}
 
+	public static EntityDescriptor entityDescriptor(RelyingPartyRegistration registration) {
+		EntityDescriptorBuilder entityDescriptorBuilder = new EntityDescriptorBuilder();
+		EntityDescriptor entityDescriptor = entityDescriptorBuilder.buildObject();
+		entityDescriptor.setEntityID(registration.getAssertingPartyDetails().getEntityId());
+		IDPSSODescriptorBuilder idpssoDescriptorBuilder = new IDPSSODescriptorBuilder();
+		IDPSSODescriptor idpssoDescriptor = idpssoDescriptorBuilder.buildObject();
+		idpssoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS);
+		SingleSignOnServiceBuilder singleSignOnServiceBuilder = new SingleSignOnServiceBuilder();
+		SingleSignOnService singleSignOnService = singleSignOnServiceBuilder.buildObject();
+		singleSignOnService.setBinding(Saml2MessageBinding.POST.getUrn());
+		singleSignOnService.setLocation(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
+		idpssoDescriptor.getSingleSignOnServices().add(singleSignOnService);
+		KeyDescriptorBuilder keyDescriptorBuilder = new KeyDescriptorBuilder();
+		KeyDescriptor keyDescriptor = keyDescriptorBuilder.buildObject();
+		keyDescriptor.setUse(UsageType.SIGNING);
+		KeyInfoBuilder keyInfoBuilder = new KeyInfoBuilder();
+		KeyInfo keyInfo = keyInfoBuilder.buildObject();
+		addCertificate(keyInfo, registration.getSigningX509Credentials().iterator().next().getCertificate());
+		keyDescriptor.setKeyInfo(keyInfo);
+		idpssoDescriptor.getKeyDescriptors().add(keyDescriptor);
+		entityDescriptor.getRoleDescriptors(IDPSSODescriptor.DEFAULT_ELEMENT_NAME).add(idpssoDescriptor);
+		return entityDescriptor;
+	}
+
+	static void addCertificate(KeyInfo info, X509Certificate certificate) {
+		try {
+			KeyInfoSupport.addCertificate(info, certificate);
+		}
+		catch (Exception ex) {
+			throw new Saml2Exception(ex);
+		}
+	}
+
 	static <T extends XMLObject> T build(QName qName) {
 		return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
 	}

+ 377 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/OpenSamlAssertingPartyMetadataRepositoryTests.java

@@ -0,0 +1,377 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.registration;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
+import okhttp3.mockwebserver.Dispatcher;
+import okhttp3.mockwebserver.MockResponse;
+import okhttp3.mockwebserver.MockWebServer;
+import okhttp3.mockwebserver.RecordedRequest;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.opensaml.core.xml.XMLObject;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.core.xml.io.Marshaller;
+import org.opensaml.core.xml.io.MarshallingException;
+import org.opensaml.saml.metadata.IterableMetadataSource;
+import org.opensaml.saml.metadata.resolver.MetadataResolver;
+import org.opensaml.saml.metadata.resolver.impl.FilesystemMetadataResolver;
+import org.opensaml.saml.metadata.resolver.index.impl.RoleMetadataIndex;
+import org.opensaml.saml.saml2.metadata.EntityDescriptor;
+import org.opensaml.security.credential.Credential;
+import org.w3c.dom.Element;
+
+import org.springframework.core.io.ClassPathResource;
+import org.springframework.core.io.ResourceLoader;
+import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.OpenSamlInitializationService;
+import org.springframework.security.saml2.core.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.withSettings;
+
+/**
+ * Tests for {@link OpenSamlAssertingPartyMetadataRepository}
+ */
+public class OpenSamlAssertingPartyMetadataRepositoryTests {
+
+	static {
+		OpenSamlInitializationService.initialize();
+	}
+
+	private String metadata;
+
+	private String entitiesDescriptor;
+
+	@BeforeEach
+	public void setup() throws Exception {
+		ClassPathResource resource = new ClassPathResource("test-metadata.xml");
+		try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) {
+			this.metadata = reader.lines().collect(Collectors.joining());
+		}
+		resource = new ClassPathResource("test-entitiesdescriptor.xml");
+		try (BufferedReader reader = new BufferedReader(new InputStreamReader(resource.getInputStream()))) {
+			this.entitiesDescriptor = reader.lines().collect(Collectors.joining());
+		}
+	}
+
+	@Test
+	public void withMetadataUrlLocationWhenResolvableThenFindByEntityIdReturns() throws Exception {
+		try (MockWebServer server = new MockWebServer()) {
+			server.setDispatcher(new AlwaysDispatch(new MockResponse().setBody(this.metadata).setResponseCode(200)));
+			AssertingPartyMetadataRepository parties = OpenSamlAssertingPartyMetadataRepository
+				.withTrustedMetadataLocation(server.url("/").toString())
+				.build();
+			AssertingPartyMetadata party = parties.findByEntityId("https://idp.example.com/idp/shibboleth");
+			assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth");
+			assertThat(party.getSingleSignOnServiceLocation())
+				.isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO");
+			assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
+			assertThat(party.getVerificationX509Credentials()).hasSize(1);
+			assertThat(party.getEncryptionX509Credentials()).hasSize(1);
+		}
+	}
+
+	@Test
+	public void withMetadataUrlLocationnWhenResolvableThenIteratorReturns() throws Exception {
+		try (MockWebServer server = new MockWebServer()) {
+			server.setDispatcher(
+					new AlwaysDispatch(new MockResponse().setBody(this.entitiesDescriptor).setResponseCode(200)));
+			List<AssertingPartyMetadata> parties = new ArrayList<>();
+			OpenSamlAssertingPartyMetadataRepository.withTrustedMetadataLocation(server.url("/").toString())
+				.build()
+				.iterator()
+				.forEachRemaining(parties::add);
+			assertThat(parties).hasSize(2);
+			assertThat(parties).extracting(AssertingPartyMetadata::getEntityId)
+				.contains("https://ap.example.org/idp/shibboleth", "https://idp.example.com/idp/shibboleth");
+		}
+	}
+
+	@Test
+	public void withMetadataUrlLocationWhenUnresolvableThenThrowsSaml2Exception() throws Exception {
+		try (MockWebServer server = new MockWebServer()) {
+			server.enqueue(new MockResponse().setBody(this.metadata).setResponseCode(200));
+			String url = server.url("/").toString();
+			server.shutdown();
+			assertThatExceptionOfType(Saml2Exception.class)
+				.isThrownBy(() -> OpenSamlAssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build());
+		}
+	}
+
+	@Test
+	public void withMetadataUrlLocationWhenMalformedResponseThenSaml2Exception() throws Exception {
+		try (MockWebServer server = new MockWebServer()) {
+			server.setDispatcher(new AlwaysDispatch("malformed"));
+			String url = server.url("/").toString();
+			assertThatExceptionOfType(Saml2Exception.class)
+				.isThrownBy(() -> OpenSamlAssertingPartyMetadataRepository.withTrustedMetadataLocation(url).build());
+		}
+	}
+
+	@Test
+	public void fromMetadataFileLocationWhenResolvableThenFindByEntityIdReturns() {
+		File file = new File("src/test/resources/test-metadata.xml");
+		AssertingPartyMetadata party = OpenSamlAssertingPartyMetadataRepository
+			.withTrustedMetadataLocation("file:" + file.getAbsolutePath())
+			.build()
+			.findByEntityId("https://idp.example.com/idp/shibboleth");
+		assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth");
+		assertThat(party.getSingleSignOnServiceLocation())
+			.isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO");
+		assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
+		assertThat(party.getVerificationX509Credentials()).hasSize(1);
+		assertThat(party.getEncryptionX509Credentials()).hasSize(1);
+	}
+
+	@Test
+	public void fromMetadataFileLocationWhenResolvableThenIteratorReturns() {
+		File file = new File("src/test/resources/test-entitiesdescriptor.xml");
+		Collection<AssertingPartyMetadata> parties = new ArrayList<>();
+		OpenSamlAssertingPartyMetadataRepository.withTrustedMetadataLocation("file:" + file.getAbsolutePath())
+			.build()
+			.iterator()
+			.forEachRemaining(parties::add);
+		assertThat(parties).hasSize(2);
+		assertThat(parties).extracting(AssertingPartyMetadata::getEntityId)
+			.contains("https://idp.example.com/idp/shibboleth", "https://ap.example.org/idp/shibboleth");
+	}
+
+	@Test
+	public void withMetadataFileLocationWhenNotFoundThenSaml2Exception() {
+		assertThatExceptionOfType(Saml2Exception.class).isThrownBy(
+				() -> OpenSamlAssertingPartyMetadataRepository.withTrustedMetadataLocation("file:path").build());
+	}
+
+	@Test
+	public void fromMetadataClasspathLocationWhenResolvableThenFindByEntityIdReturns() {
+		AssertingPartyMetadata party = OpenSamlAssertingPartyMetadataRepository
+			.withTrustedMetadataLocation("classpath:test-entitiesdescriptor.xml")
+			.build()
+			.findByEntityId("https://ap.example.org/idp/shibboleth");
+		assertThat(party.getEntityId()).isEqualTo("https://ap.example.org/idp/shibboleth");
+		assertThat(party.getSingleSignOnServiceLocation())
+			.isEqualTo("https://ap.example.org/idp/profile/SAML2/POST/SSO");
+		assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
+		assertThat(party.getVerificationX509Credentials()).hasSize(1);
+		assertThat(party.getEncryptionX509Credentials()).hasSize(1);
+	}
+
+	@Test
+	public void fromMetadataClasspathLocationWhenResolvableThenIteratorReturns() {
+		Collection<AssertingPartyMetadata> parties = new ArrayList<>();
+		OpenSamlAssertingPartyMetadataRepository.withTrustedMetadataLocation("classpath:test-entitiesdescriptor.xml")
+			.build()
+			.iterator()
+			.forEachRemaining(parties::add);
+		assertThat(parties).hasSize(2);
+		assertThat(parties).extracting(AssertingPartyMetadata::getEntityId)
+			.contains("https://idp.example.com/idp/shibboleth", "https://ap.example.org/idp/shibboleth");
+	}
+
+	@Test
+	public void withMetadataClasspathLocationWhenNotFoundThenSaml2Exception() {
+		assertThatExceptionOfType(Saml2Exception.class).isThrownBy(
+				() -> OpenSamlAssertingPartyMetadataRepository.withTrustedMetadataLocation("classpath:path").build());
+	}
+
+	@Test
+	public void withTrustedMetadataLocationWhenMatchingCredentialsThenVerifiesSignature() throws IOException {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration);
+		TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				descriptor.getEntityID());
+		String serialized = serialize(descriptor);
+		Credential credential = TestOpenSamlObjects
+			.getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID());
+		try (MockWebServer server = new MockWebServer()) {
+			server.start();
+			server.setDispatcher(new AlwaysDispatch(serialized));
+			AssertingPartyMetadataRepository parties = OpenSamlAssertingPartyMetadataRepository
+				.withTrustedMetadataLocation(server.url("/").toString())
+				.verificationCredentials((c) -> c.add(credential))
+				.build();
+			assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull();
+		}
+	}
+
+	@Test
+	public void withTrustedMetadataLocationWhenMismatchingCredentialsThenSaml2Exception() throws IOException {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration);
+		TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.relyingPartySigningCredential(),
+				descriptor.getEntityID());
+		String serialized = serialize(descriptor);
+		Credential credential = TestOpenSamlObjects
+			.getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID());
+		try (MockWebServer server = new MockWebServer()) {
+			server.start();
+			server.setDispatcher(new AlwaysDispatch(serialized));
+			assertThatExceptionOfType(Saml2Exception.class).isThrownBy(() -> OpenSamlAssertingPartyMetadataRepository
+				.withTrustedMetadataLocation(server.url("/").toString())
+				.verificationCredentials((c) -> c.add(credential))
+				.build());
+		}
+	}
+
+	@Test
+	public void withTrustedMetadataLocationWhenNoCredentialsThenSkipsVerifySignature() throws IOException {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration);
+		TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				descriptor.getEntityID());
+		String serialized = serialize(descriptor);
+		try (MockWebServer server = new MockWebServer()) {
+			server.start();
+			server.setDispatcher(new AlwaysDispatch(serialized));
+			AssertingPartyMetadataRepository parties = OpenSamlAssertingPartyMetadataRepository
+				.withTrustedMetadataLocation(server.url("/").toString())
+				.build();
+			assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull();
+		}
+	}
+
+	@Test
+	public void withTrustedMetadataLocationWhenCustomResourceLoaderThenUses() {
+		ResourceLoader resourceLoader = mock(ResourceLoader.class);
+		given(resourceLoader.getResource(any())).willReturn(new ClassPathResource("test-metadata.xml"));
+		AssertingPartyMetadata party = OpenSamlAssertingPartyMetadataRepository
+			.withTrustedMetadataLocation("classpath:wrong")
+			.resourceLoader(resourceLoader)
+			.build()
+			.iterator()
+			.next();
+		assertThat(party.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth");
+		assertThat(party.getSingleSignOnServiceLocation())
+			.isEqualTo("https://idp.example.com/idp/profile/SAML2/POST/SSO");
+		assertThat(party.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.POST);
+		assertThat(party.getVerificationX509Credentials()).hasSize(1);
+		assertThat(party.getEncryptionX509Credentials()).hasSize(1);
+		verify(resourceLoader).getResource(any());
+	}
+
+	@Test
+	public void constructorWhenNoIndexAndNoIteratorThenException() {
+		MetadataResolver resolver = mock(MetadataResolver.class);
+		assertThatExceptionOfType(IllegalArgumentException.class)
+			.isThrownBy(() -> new OpenSamlAssertingPartyMetadataRepository(resolver));
+	}
+
+	@Test
+	public void constructorWhenIterableResolverThenUses() {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration);
+		MetadataResolver resolver = mock(MetadataResolver.class,
+				withSettings().extraInterfaces(IterableMetadataSource.class));
+		given(((IterableMetadataSource) resolver).iterator()).willReturn(List.of(descriptor).iterator());
+		AssertingPartyMetadataRepository parties = new OpenSamlAssertingPartyMetadataRepository(resolver);
+		parties.iterator()
+			.forEachRemaining((p) -> assertThat(p.getEntityId())
+				.isEqualTo(registration.getAssertingPartyDetails().getEntityId()));
+		verify(((IterableMetadataSource) resolver)).iterator();
+	}
+
+	@Test
+	public void constructorWhenIndexedResolverThenUses() throws Exception {
+		FilesystemMetadataResolver resolver = new FilesystemMetadataResolver(
+				new ClassPathResource("test-metadata.xml").getFile());
+		resolver.setIndexes(Set.of(new RoleMetadataIndex()));
+		resolver.setId("id");
+		resolver.setParserPool(XMLObjectProviderRegistrySupport.getParserPool());
+		resolver.initialize();
+		MetadataResolver spied = spy(resolver);
+		AssertingPartyMetadataRepository parties = new OpenSamlAssertingPartyMetadataRepository(spied);
+		parties.iterator()
+			.forEachRemaining((p) -> assertThat(p.getEntityId()).isEqualTo("https://idp.example.com/idp/shibboleth"));
+		verify(spied).resolve(any());
+	}
+
+	@Test
+	public void withMetadataLocationWhenNoCredentialsThenException() {
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(
+				() -> OpenSamlAssertingPartyMetadataRepository.withMetadataLocation("classpath:test-metadata.xml")
+					.build());
+	}
+
+	@Test
+	public void withMetadataLocationWhenMatchingCredentialsThenVerifiesSignature() throws IOException {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		EntityDescriptor descriptor = TestOpenSamlObjects.entityDescriptor(registration);
+		TestOpenSamlObjects.signed(descriptor, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				descriptor.getEntityID());
+		String serialized = serialize(descriptor);
+		Credential credential = TestOpenSamlObjects
+			.getSigningCredential(TestSaml2X509Credentials.relyingPartyVerifyingCredential(), descriptor.getEntityID());
+		try (MockWebServer server = new MockWebServer()) {
+			server.start();
+			server.setDispatcher(new AlwaysDispatch(serialized));
+			AssertingPartyMetadataRepository parties = OpenSamlAssertingPartyMetadataRepository
+				.withMetadataLocation(server.url("/").toString())
+				.verificationCredentials((c) -> c.add(credential))
+				.build();
+			assertThat(parties.findByEntityId(registration.getAssertingPartyDetails().getEntityId())).isNotNull();
+		}
+	}
+
+	private static String serialize(XMLObject object) {
+		try {
+			Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(object);
+			Element element = marshaller.marshall(object);
+			return SerializeSupport.nodeToString(element);
+		}
+		catch (MarshallingException ex) {
+			throw new Saml2Exception(ex);
+		}
+	}
+
+	private static final class AlwaysDispatch extends Dispatcher {
+
+		private final MockResponse response;
+
+		private AlwaysDispatch(String body) {
+			this.response = new MockResponse().setBody(body).setResponseCode(200);
+		}
+
+		private AlwaysDispatch(MockResponse response) {
+			this.response = response;
+		}
+
+		@Override
+		public MockResponse dispatch(RecordedRequest recordedRequest) throws InterruptedException {
+			return this.response;
+		}
+
+	}
+
+}