Prechádzať zdrojové kódy

Add OpenSaml custom types to Saml2AuthenticatedPrincipal

OpenSaml custom types are added to Saml2AutehnticatedPrincipal as
attributes.

Closes gh-9696
pelesic 3 rokov pred
rodič
commit
f626d11c6e

+ 1 - 1
saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

@@ -648,7 +648,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		if (xmlObject instanceof XSDateTime) {
 			return ((XSDateTime) xmlObject).getValue();
 		}
-		return null;
+		return xmlObject;
 	}
 
 	private static Saml2AuthenticationException createAuthenticationException(String code, String message,

+ 26 - 0
saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java

@@ -250,6 +250,32 @@ public class OpenSaml4AuthenticationProviderTests {
 		assertThat(principal.getSessionIndexes()).contains("session-index");
 	}
 
+	@Test
+	public void authenticateWhenAssertionContainsCustomAttributesThenItSucceeds() {
+		XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller(
+				TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME,
+				new TestCustomOpenSamlObject.CustomSamlObjectMarshaller());
+		XMLObjectProviderRegistrySupport.getUnmarshallerFactory().registerUnmarshaller(
+				TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME,
+				new TestCustomOpenSamlObject.CustomSamlObjectUnmarshaller());
+		Response response = response();
+		Assertion assertion = assertion();
+		List<AttributeStatement> attributes = TestOpenSamlObjects.customAttributeStatements();
+		assertion.getAttributeStatements().addAll(attributes);
+		TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
+		response.getAssertions().add(assertion);
+		Saml2AuthenticationToken token = token(response, verifying(registration()));
+		Authentication authentication = this.provider.authenticate(token);
+		Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
+		TestCustomOpenSamlObject.CustomSamlObject customSamlObject;
+		customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) principal.getAttribute("Address").get(0);
+		assertThat(customSamlObject.getStreet()).isEqualTo("Test Street");
+		assertThat(customSamlObject.getStreetNumber()).isEqualTo("1");
+		assertThat(customSamlObject.getZIP()).isEqualTo("11111");
+		assertThat(customSamlObject.getCity()).isEqualTo("Test City");
+	}
+
 	@Test
 	public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() {
 		Response response = response();

+ 177 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestCustomOpenSamlObject.java

@@ -0,0 +1,177 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.authentication;
+
+import java.util.Collections;
+import java.util.List;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+import javax.xml.namespace.QName;
+
+import net.shibboleth.utilities.java.support.xml.ElementSupport;
+import org.opensaml.core.xml.AbstractXMLObject;
+import org.opensaml.core.xml.AbstractXMLObjectBuilder;
+import org.opensaml.core.xml.ElementExtensibleXMLObject;
+import org.opensaml.core.xml.Namespace;
+import org.opensaml.core.xml.XMLObject;
+import org.opensaml.core.xml.io.AbstractXMLObjectMarshaller;
+import org.opensaml.core.xml.io.AbstractXMLObjectUnmarshaller;
+import org.opensaml.core.xml.io.UnmarshallingException;
+import org.opensaml.core.xml.schema.XSAny;
+import org.opensaml.core.xml.util.IndexedXMLObjectChildrenList;
+import org.opensaml.saml.common.xml.SAMLConstants;
+import org.opensaml.saml.saml2.core.AttributeValue;
+import org.w3c.dom.Element;
+
+public class TestCustomOpenSamlObject {
+
+	public interface CustomSamlObject extends ElementExtensibleXMLObject {
+
+		String TYPE_LOCAL_NAME = "CustomType";
+
+		String TYPE_CUSTOM_PREFIX = "custom";
+
+		String CUSTOM_NS = "https://custom.com/schema/custom";
+
+		/** QName of the CustomType type. */
+		QName TYPE_NAME = new QName(CUSTOM_NS, TYPE_LOCAL_NAME, TYPE_CUSTOM_PREFIX);
+
+		String getStreet();
+
+		String getStreetNumber();
+
+		String getZIP();
+
+		String getCity();
+
+	}
+
+	public static class CustomSamlObjectImpl extends AbstractXMLObject
+			implements TestCustomOpenSamlObject.CustomSamlObject {
+
+		@Nonnull
+		private IndexedXMLObjectChildrenList<XMLObject> unknownXMLObjects;
+
+		/**
+		 * Constructor.
+		 * @param namespaceURI the namespace the element is in
+		 * @param elementLocalName the local name of the XML element this Object
+		 * represents
+		 * @param namespacePrefix the prefix for the given namespace
+		 */
+		protected CustomSamlObjectImpl(@Nullable String namespaceURI, @Nonnull String elementLocalName,
+				@Nullable String namespacePrefix) {
+			super(namespaceURI, elementLocalName, namespacePrefix);
+			super.getNamespaceManager().registerNamespaceDeclaration(new Namespace(CUSTOM_NS, TYPE_CUSTOM_PREFIX));
+			this.unknownXMLObjects = new IndexedXMLObjectChildrenList<>(this);
+		}
+
+		@Nonnull
+		@Override
+		public List<XMLObject> getUnknownXMLObjects() {
+			return this.unknownXMLObjects;
+		}
+
+		@Nonnull
+		@Override
+		public List<XMLObject> getUnknownXMLObjects(@Nonnull QName typeOrName) {
+			return (List<XMLObject>) this.unknownXMLObjects.subList(typeOrName);
+		}
+
+		@Nullable
+		@Override
+		public List<XMLObject> getOrderedChildren() {
+			return Collections.unmodifiableList(this.unknownXMLObjects);
+		}
+
+		@Override
+		public String getStreet() {
+			return ((XSAny) getOrderedChildren().get(0)).getTextContent();
+		}
+
+		@Override
+		public String getStreetNumber() {
+			return ((XSAny) getOrderedChildren().get(1)).getTextContent();
+		}
+
+		@Override
+		public String getZIP() {
+			return ((XSAny) getOrderedChildren().get(2)).getTextContent();
+		}
+
+		@Override
+		public String getCity() {
+			return ((XSAny) getOrderedChildren().get(3)).getTextContent();
+		}
+
+	}
+
+	public static class CustomSamlObjectBuilder
+			extends AbstractXMLObjectBuilder<TestCustomOpenSamlObject.CustomSamlObject> {
+
+		@Nonnull
+		@Override
+		public TestCustomOpenSamlObject.CustomSamlObject buildObject(@Nullable String namespaceURI,
+				@Nonnull String localName, @Nullable String namespacePrefix) {
+			return new TestCustomOpenSamlObject.CustomSamlObjectImpl(namespaceURI, localName, namespacePrefix);
+		}
+
+	}
+
+	public static class CustomSamlObjectMarshaller extends AbstractXMLObjectMarshaller {
+
+		public CustomSamlObjectMarshaller() {
+			super();
+		}
+
+		@Override
+		protected void marshallElementContent(@Nonnull XMLObject xmlObject, @Nonnull Element domElement) {
+			final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) xmlObject;
+
+			for (XMLObject object : customSamlObject.getOrderedChildren()) {
+				ElementSupport.appendChildElement(domElement, object.getDOM());
+			}
+		}
+
+	}
+
+	public static class CustomSamlObjectUnmarshaller extends AbstractXMLObjectUnmarshaller {
+
+		public CustomSamlObjectUnmarshaller() {
+			super();
+		}
+
+		@Override
+		protected void processChildElement(@Nonnull XMLObject parentXMLObject, @Nonnull XMLObject childXMLObject)
+				throws UnmarshallingException {
+			final TestCustomOpenSamlObject.CustomSamlObject customSamlObject = (TestCustomOpenSamlObject.CustomSamlObject) parentXMLObject;
+			super.processChildElement(customSamlObject, childXMLObject);
+			customSamlObject.getUnknownXMLObjects().add(childXMLObject);
+		}
+
+		@Nonnull
+		@Override
+		protected XMLObject buildXMLObject(@Nonnull Element domElement) {
+			return new TestCustomOpenSamlObject.CustomSamlObjectImpl(SAMLConstants.SAML20_NS,
+					AttributeValue.DEFAULT_ELEMENT_LOCAL_NAME,
+					TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
+		}
+
+	}
+
+}

+ 31 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

@@ -296,6 +296,37 @@ public final class TestOpenSamlObjects {
 		return attribute;
 	}
 
+	static List<AttributeStatement> customAttributeStatements() {
+		List<AttributeStatement> attributeStatements = new ArrayList<>();
+		AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();
+		AttributeBuilder attributeBuilder = new AttributeBuilder();
+		Attribute attribute = attributeBuilder.buildObject();
+		attribute.setName("Address");
+		TestCustomOpenSamlObject.CustomSamlObject samlObject = new TestCustomOpenSamlObject.CustomSamlObjectBuilder()
+				.buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, TestCustomOpenSamlObject.CustomSamlObject.TYPE_NAME);
+		XSAny street = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "Street",
+				TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
+		street.setTextContent("Test Street");
+		samlObject.getUnknownXMLObjects().add(street);
+		XSAny streetNumber = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS,
+				"Number", TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
+		streetNumber.setTextContent("1");
+		samlObject.getUnknownXMLObjects().add(streetNumber);
+		XSAny zip = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "ZIP",
+				TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
+		zip.setTextContent("11111");
+		samlObject.getUnknownXMLObjects().add(zip);
+		XSAny city = new XSAnyBuilder().buildObject(TestCustomOpenSamlObject.CustomSamlObject.CUSTOM_NS, "City",
+				TestCustomOpenSamlObject.CustomSamlObject.TYPE_CUSTOM_PREFIX);
+		city.setTextContent("Test City");
+		samlObject.getUnknownXMLObjects().add(city);
+		attribute.getAttributeValues().add(samlObject);
+		AttributeStatement attributeStatement = attributeStatementBuilder.buildObject();
+		attributeStatement.getAttributes().add(attribute);
+		attributeStatements.add(attributeStatement);
+		return attributeStatements;
+	}
+
 	static List<AttributeStatement> attributeStatements() {
 		List<AttributeStatement> attributeStatements = new ArrayList<>();
 		AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();