فهرست منبع

Support Creating EntitiesDescriptor

Clsoes gh-12844
Josh Cummings 2 سال پیش
والد
کامیت
7678523b73

+ 40 - 5
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java

@@ -30,11 +30,13 @@ import org.opensaml.core.xml.XMLObjectBuilder;
 import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
+import org.opensaml.saml.saml2.metadata.EntitiesDescriptor;
 import org.opensaml.saml.saml2.metadata.EntityDescriptor;
 import org.opensaml.saml.saml2.metadata.KeyDescriptor;
 import org.opensaml.saml.saml2.metadata.NameIDFormat;
 import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
 import org.opensaml.saml.saml2.metadata.SingleLogoutService;
+import org.opensaml.saml.saml2.metadata.impl.EntitiesDescriptorMarshaller;
 import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller;
 import org.opensaml.security.credential.UsageType;
 import org.opensaml.xmlsec.signature.KeyInfo;
@@ -65,6 +67,8 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 
 	private final EntityDescriptorMarshaller entityDescriptorMarshaller;
 
+	private final EntitiesDescriptorMarshaller entitiesDescriptorMarshaller;
+
 	private Consumer<EntityDescriptorParameters> entityDescriptorCustomizer = (parameters) -> {
 	};
 
@@ -72,17 +76,38 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 		this.entityDescriptorMarshaller = (EntityDescriptorMarshaller) XMLObjectProviderRegistrySupport
 				.getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME);
 		Assert.notNull(this.entityDescriptorMarshaller, "entityDescriptorMarshaller cannot be null");
+		this.entitiesDescriptorMarshaller = (EntitiesDescriptorMarshaller) XMLObjectProviderRegistrySupport
+				.getMarshallerFactory().getMarshaller(EntitiesDescriptor.DEFAULT_ELEMENT_NAME);
+		Assert.notNull(this.entitiesDescriptorMarshaller, "entitiesDescriptorMarshaller cannot be null");
 	}
 
 	@Override
 	public String resolve(RelyingPartyRegistration relyingPartyRegistration) {
+		EntityDescriptor entityDescriptor = entityDescriptor(relyingPartyRegistration);
+		return serialize(entityDescriptor);
+	}
+
+	public String resolve(Iterable<RelyingPartyRegistration> relyingPartyRegistrations) {
+		Collection<EntityDescriptor> entityDescriptors = new ArrayList<>();
+		for (RelyingPartyRegistration registration : relyingPartyRegistrations) {
+			EntityDescriptor entityDescriptor = entityDescriptor(registration);
+			entityDescriptors.add(entityDescriptor);
+		}
+		if (entityDescriptors.size() == 1) {
+			return serialize(entityDescriptors.iterator().next());
+		}
+		EntitiesDescriptor entities = build(EntitiesDescriptor.DEFAULT_ELEMENT_NAME);
+		entities.getEntityDescriptors().addAll(entityDescriptors);
+		return serialize(entities);
+	}
+
+	private EntityDescriptor entityDescriptor(RelyingPartyRegistration registration) {
 		EntityDescriptor entityDescriptor = build(EntityDescriptor.DEFAULT_ELEMENT_NAME);
-		entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId());
-		SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration);
+		entityDescriptor.setEntityID(registration.getEntityId());
+		SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(registration);
 		entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor);
-		this.entityDescriptorCustomizer
-				.accept(new EntityDescriptorParameters(entityDescriptor, relyingPartyRegistration));
-		return serialize(entityDescriptor);
+		this.entityDescriptorCustomizer.accept(new EntityDescriptorParameters(entityDescriptor, registration));
+		return entityDescriptor;
 	}
 
 	/**
@@ -184,6 +209,16 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 		}
 	}
 
+	private String serialize(EntitiesDescriptor entities) {
+		try {
+			Element element = this.entitiesDescriptorMarshaller.marshall(entities);
+			return SerializeSupport.prettyPrintXML(element);
+		}
+		catch (Exception ex) {
+			throw new Saml2Exception(ex);
+		}
+	}
+
 	/**
 	 * A tuple containing an OpenSAML {@link EntityDescriptor} and its associated
 	 * {@link RelyingPartyRegistration}

+ 4 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/Saml2MetadataResolver.java

@@ -35,4 +35,8 @@ public interface Saml2MetadataResolver {
 	 */
 	String resolve(RelyingPartyRegistration relyingPartyRegistration);
 
+	default String resolve(Iterable<RelyingPartyRegistration> relyingPartyRegistrations) {
+		return resolve(relyingPartyRegistrations.iterator().next());
+	}
+
 }

+ 19 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java

@@ -16,6 +16,8 @@
 
 package org.springframework.security.saml2.provider.service.metadata;
 
+import java.util.List;
+
 import org.junit.jupiter.api.Test;
 
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
@@ -89,4 +91,21 @@ public class OpenSamlMetadataResolverTests {
 		assertThat(metadata).contains("<md:EntityDescriptor").contains("entityID=\"overriddenEntityId\"");
 	}
 
+	@Test
+	public void resolveIterableWhenRelyingPartiesThenMetadataMatches() {
+		RelyingPartyRegistration one = TestRelyingPartyRegistrations.full()
+				.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build();
+		RelyingPartyRegistration two = TestRelyingPartyRegistrations.full().entityId("two")
+				.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build();
+		OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
+		String metadata = openSamlMetadataResolver.resolve(List.of(one, two));
+		assertThat(metadata).contains("<md:EntitiesDescriptor").contains("<md:EntityDescriptor")
+				.contains("entityID=\"rp-entity-id\"").contains("two").contains("<md:KeyDescriptor use=\"signing\">")
+				.contains("<md:KeyDescriptor use=\"encryption\">")
+				.contains("<ds:X509Certificate>MIICgTCCAeoCCQCuVzyqFgMSyDANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UEBh")
+				.contains("Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect\"")
+				.contains("Location=\"https://rp.example.org/acs\" index=\"1\"")
+				.contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\"");
+	}
+
 }