Răsfoiți Sursa

Polish SAML Attribute Support

Issue gh-8661
Josh Cummings 5 ani în urmă
părinte
comite
360db53dd2

+ 0 - 3
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -531,9 +531,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	}
 
 	private Object getXmlObjectValue(XMLObject xmlObject) {
-		if (xmlObject == null) {
-			return null;
-		}
 		if (xmlObject instanceof XSAny) {
 			return getXSAnyObjectValue((XSAny) xmlObject);
 		}

+ 65 - 17
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -19,11 +19,15 @@ package org.springframework.security.saml2.provider.service.authentication;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectOutputStream;
+import java.io.StringReader;
 import java.time.Instant;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
+import javax.xml.parsers.DocumentBuilder;
+import javax.xml.parsers.DocumentBuilderFactory;
 
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
@@ -33,27 +37,40 @@ import org.joda.time.Duration;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
+import org.opensaml.core.xml.XMLObject;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.core.xml.io.Marshaller;
 import org.opensaml.saml.saml2.core.Assertion;
+import org.opensaml.saml.saml2.core.AttributeStatement;
+import org.opensaml.saml.saml2.core.AttributeValue;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedID;
 import org.opensaml.saml.saml2.core.NameID;
 import org.opensaml.saml.saml2.core.Response;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
+import org.xml.sax.InputSource;
 
 import org.springframework.security.core.Authentication;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
-import static org.springframework.test.util.AssertionErrors.assertEquals;
-import static org.springframework.test.util.AssertionErrors.assertTrue;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
 import static org.springframework.util.StringUtils.hasText;
 
 /**
@@ -203,24 +220,48 @@ public class OpenSamlAuthenticationProviderTests {
 	public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
 		Response response = response();
 		Assertion assertion = assertion();
-		attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as));
+		List<AttributeStatement> attributes = attributeStatements();
+		assertion.getAttributeStatements().addAll(attributes);
 		signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
 		response.getAssertions().add(assertion);
 		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
 		Authentication authentication = this.provider.authenticate(token);
 		Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
 
-		Map<String, Object> attributes = new LinkedHashMap<>();
-		attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
-		attributes.put("name", Collections.singletonList("John Doe"));
-		attributes.put("age", Collections.singletonList(21));
-		attributes.put("website", Collections.singletonList("https://johndoe.com/"));
-		attributes.put("registered", Collections.singletonList(true));
+		Map<String, Object> expected = new LinkedHashMap<>();
+		expected.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
+		expected.put("name", Collections.singletonList("John Doe"));
+		expected.put("age", Collections.singletonList(21));
+		expected.put("website", Collections.singletonList("https://johndoe.com/"));
+		expected.put("registered", Collections.singletonList(true));
 		Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
-		attributes.put("registeredDate", Collections.singletonList(registeredDate));
+		expected.put("registeredDate", Collections.singletonList(registeredDate));
 
-		assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name"));
-		assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes()));
+		assertEquals("John Doe", principal.getFirstAttribute("name"));
+		assertEquals(expected, principal.getAttributes());
+	}
+
+	@Test
+	public void authenticateWhenAttributeValueMarshallerConfiguredThenUses() throws Exception {
+		Response response = response();
+		Assertion assertion = assertion();
+		List<AttributeStatement> attributes = attributeStatements();
+		assertion.getAttributeStatements().addAll(attributes);
+		signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
+		response.getAssertions().add(assertion);
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+
+		Element attributeElement = element("<element>value</element>");
+		Marshaller marshaller = mock(Marshaller.class);
+		when(marshaller.marshall(any(XMLObject.class))).thenReturn(attributeElement);
+
+		try {
+			XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME, marshaller);
+			this.provider.authenticate(token);
+			verify(marshaller, atLeastOnce()).marshall(any(XMLObject.class));
+		} finally {
+			XMLObjectProviderRegistrySupport.getMarshallerFactory().deregisterMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME);
+		}
 	}
 
 	@Test
@@ -352,4 +393,11 @@ public class OpenSamlAuthenticationProviderTests {
 		return new Saml2AuthenticationToken(payload,
 				DESTINATION, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, Arrays.asList(credentials));
 	}
+
+	private static Element element(String xml) throws Exception {
+		DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
+		DocumentBuilder builder = factory.newDocumentBuilder();
+		Document doc = builder.parse(new InputSource(new StringReader(xml)));
+		return doc.getDocumentElement();
+	}
 }