|
@@ -19,11 +19,15 @@ package org.springframework.security.saml2.provider.service.authentication;
|
|
|
import java.io.ByteArrayOutputStream;
|
|
|
import java.io.IOException;
|
|
|
import java.io.ObjectOutputStream;
|
|
|
+import java.io.StringReader;
|
|
|
import java.time.Instant;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.Collections;
|
|
|
import java.util.LinkedHashMap;
|
|
|
+import java.util.List;
|
|
|
import java.util.Map;
|
|
|
+import javax.xml.parsers.DocumentBuilder;
|
|
|
+import javax.xml.parsers.DocumentBuilderFactory;
|
|
|
|
|
|
import org.hamcrest.BaseMatcher;
|
|
|
import org.hamcrest.Description;
|
|
@@ -33,27 +37,40 @@ import org.joda.time.Duration;
|
|
|
import org.junit.Rule;
|
|
|
import org.junit.Test;
|
|
|
import org.junit.rules.ExpectedException;
|
|
|
+import org.opensaml.core.xml.XMLObject;
|
|
|
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
|
|
|
+import org.opensaml.core.xml.io.Marshaller;
|
|
|
import org.opensaml.saml.saml2.core.Assertion;
|
|
|
+import org.opensaml.saml.saml2.core.AttributeStatement;
|
|
|
+import org.opensaml.saml.saml2.core.AttributeValue;
|
|
|
import org.opensaml.saml.saml2.core.EncryptedAssertion;
|
|
|
import org.opensaml.saml.saml2.core.EncryptedID;
|
|
|
import org.opensaml.saml.saml2.core.NameID;
|
|
|
import org.opensaml.saml.saml2.core.Response;
|
|
|
+import org.w3c.dom.Document;
|
|
|
+import org.w3c.dom.Element;
|
|
|
+import org.xml.sax.InputSource;
|
|
|
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.saml2.credentials.Saml2X509Credential;
|
|
|
|
|
|
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
|
|
|
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
|
|
|
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
|
|
|
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
|
|
|
-import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
|
|
|
+import static org.junit.Assert.assertEquals;
|
|
|
+import static org.junit.Assert.assertTrue;
|
|
|
+import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.Mockito.atLeastOnce;
|
|
|
+import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.when;
|
|
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
|
|
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
|
|
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
|
|
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
|
|
|
import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
|
|
|
-import static org.springframework.test.util.AssertionErrors.assertEquals;
|
|
|
-import static org.springframework.test.util.AssertionErrors.assertTrue;
|
|
|
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
|
|
|
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
|
|
|
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
|
|
|
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
|
|
|
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
|
|
|
import static org.springframework.util.StringUtils.hasText;
|
|
|
|
|
|
/**
|
|
@@ -203,24 +220,48 @@ public class OpenSamlAuthenticationProviderTests {
|
|
|
public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
|
|
|
Response response = response();
|
|
|
Assertion assertion = assertion();
|
|
|
- attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as));
|
|
|
+ List<AttributeStatement> attributes = attributeStatements();
|
|
|
+ assertion.getAttributeStatements().addAll(attributes);
|
|
|
signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
|
|
|
response.getAssertions().add(assertion);
|
|
|
Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
|
|
|
Authentication authentication = this.provider.authenticate(token);
|
|
|
Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
|
|
|
|
|
|
- Map<String, Object> attributes = new LinkedHashMap<>();
|
|
|
- attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
|
|
|
- attributes.put("name", Collections.singletonList("John Doe"));
|
|
|
- attributes.put("age", Collections.singletonList(21));
|
|
|
- attributes.put("website", Collections.singletonList("https://johndoe.com/"));
|
|
|
- attributes.put("registered", Collections.singletonList(true));
|
|
|
+ Map<String, Object> expected = new LinkedHashMap<>();
|
|
|
+ expected.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
|
|
|
+ expected.put("name", Collections.singletonList("John Doe"));
|
|
|
+ expected.put("age", Collections.singletonList(21));
|
|
|
+ expected.put("website", Collections.singletonList("https://johndoe.com/"));
|
|
|
+ expected.put("registered", Collections.singletonList(true));
|
|
|
Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
|
|
|
- attributes.put("registeredDate", Collections.singletonList(registeredDate));
|
|
|
+ expected.put("registeredDate", Collections.singletonList(registeredDate));
|
|
|
|
|
|
- assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name"));
|
|
|
- assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes()));
|
|
|
+ assertEquals("John Doe", principal.getFirstAttribute("name"));
|
|
|
+ assertEquals(expected, principal.getAttributes());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void authenticateWhenAttributeValueMarshallerConfiguredThenUses() throws Exception {
|
|
|
+ Response response = response();
|
|
|
+ Assertion assertion = assertion();
|
|
|
+ List<AttributeStatement> attributes = attributeStatements();
|
|
|
+ assertion.getAttributeStatements().addAll(attributes);
|
|
|
+ signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
|
|
|
+ response.getAssertions().add(assertion);
|
|
|
+ Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
|
|
|
+
|
|
|
+ Element attributeElement = element("<element>value</element>");
|
|
|
+ Marshaller marshaller = mock(Marshaller.class);
|
|
|
+ when(marshaller.marshall(any(XMLObject.class))).thenReturn(attributeElement);
|
|
|
+
|
|
|
+ try {
|
|
|
+ XMLObjectProviderRegistrySupport.getMarshallerFactory().registerMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME, marshaller);
|
|
|
+ this.provider.authenticate(token);
|
|
|
+ verify(marshaller, atLeastOnce()).marshall(any(XMLObject.class));
|
|
|
+ } finally {
|
|
|
+ XMLObjectProviderRegistrySupport.getMarshallerFactory().deregisterMarshaller(AttributeValue.DEFAULT_ELEMENT_NAME);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
@Test
|
|
@@ -352,4 +393,11 @@ public class OpenSamlAuthenticationProviderTests {
|
|
|
return new Saml2AuthenticationToken(payload,
|
|
|
DESTINATION, ASSERTING_PARTY_ENTITY_ID, RELYING_PARTY_ENTITY_ID, Arrays.asList(credentials));
|
|
|
}
|
|
|
+
|
|
|
+ private static Element element(String xml) throws Exception {
|
|
|
+ DocumentBuilderFactory factory = DocumentBuilderFactory.newInstance();
|
|
|
+ DocumentBuilder builder = factory.newDocumentBuilder();
|
|
|
+ Document doc = builder.parse(new InputSource(new StringReader(xml)));
|
|
|
+ return doc.getDocumentElement();
|
|
|
+ }
|
|
|
}
|