Bläddra i källkod

Add method to customize EntityDescriptor

Closes gh-10839
Ulrich Grave 3 år sedan
förälder
incheckning
d225205bf2

+ 39 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,6 +21,7 @@ import java.util.ArrayList;
 import java.util.Base64;
 import java.util.Collection;
 import java.util.List;
+import java.util.function.Consumer;
 
 import javax.xml.namespace.QName;
 
@@ -63,6 +64,9 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 
 	private final EntityDescriptorMarshaller entityDescriptorMarshaller;
 
+	private Consumer<EntityDescriptorParameters> entityDescriptorCustomizer = (parameters) -> {
+	};
+
 	public OpenSamlMetadataResolver() {
 		this.entityDescriptorMarshaller = (EntityDescriptorMarshaller) XMLObjectProviderRegistrySupport
 				.getMarshallerFactory().getMarshaller(EntityDescriptor.DEFAULT_ELEMENT_NAME);
@@ -75,9 +79,22 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 		entityDescriptor.setEntityID(relyingPartyRegistration.getEntityId());
 		SPSSODescriptor spSsoDescriptor = buildSpSsoDescriptor(relyingPartyRegistration);
 		entityDescriptor.getRoleDescriptors(SPSSODescriptor.DEFAULT_ELEMENT_NAME).add(spSsoDescriptor);
+		this.entityDescriptorCustomizer
+				.accept(new EntityDescriptorParameters(entityDescriptor, relyingPartyRegistration));
 		return serialize(entityDescriptor);
 	}
 
+	/**
+	 * Set a {@link Consumer} for modifying the OpenSAML {@link EntityDescriptor}
+	 * @param entityDescriptorCustomizer a consumer that accepts an
+	 * {@link EntityDescriptorParameters}
+	 * @since 5.7
+	 */
+	public void setEntityDescriptorCustomizer(Consumer<EntityDescriptorParameters> entityDescriptorCustomizer) {
+		Assert.notNull(entityDescriptorCustomizer, "entityDescriptorCustomizer cannot be null");
+		this.entityDescriptorCustomizer = entityDescriptorCustomizer;
+	}
+
 	private SPSSODescriptor buildSpSsoDescriptor(RelyingPartyRegistration registration) {
 		SPSSODescriptor spSsoDescriptor = build(SPSSODescriptor.DEFAULT_ELEMENT_NAME);
 		spSsoDescriptor.addSupportedProtocol(SAMLConstants.SAML20P_NS);
@@ -163,4 +180,25 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 		}
 	}
 
+	public static final class EntityDescriptorParameters {
+
+		private final EntityDescriptor entityDescriptor;
+
+		private final RelyingPartyRegistration registration;
+
+		public EntityDescriptorParameters(EntityDescriptor entityDescriptor, RelyingPartyRegistration registration) {
+			this.entityDescriptor = entityDescriptor;
+			this.registration = registration;
+		}
+
+		public EntityDescriptor getEntityDescriptor() {
+			return this.entityDescriptor;
+		}
+
+		public RelyingPartyRegistration getRegistration() {
+			return this.registration;
+		}
+
+	}
+
 }

+ 12 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -78,4 +78,15 @@ public class OpenSamlMetadataResolverTests {
 		assertThat(metadata).doesNotContain("ResponseLocation");
 	}
 
+	@Test
+	public void resolveWhenEntityDescriptorCustomizerThenUses() {
+		RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full()
+				.entityId("originalEntityId").build();
+		OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
+		openSamlMetadataResolver.setEntityDescriptorCustomizer(
+				(parameters) -> parameters.getEntityDescriptor().setEntityID("overriddenEntityId"));
+		String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
+		assertThat(metadata).contains("<EntityDescriptor").contains("entityID=\"overriddenEntityId\"");
+	}
+
 }