浏览代码

Add SAML Attribute Support

Closes gh-8661
Nikola Kostic 5 年之前
父节点
当前提交
eed33228f4

+ 73 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,11 +17,13 @@ package org.springframework.security.saml2.provider.service.authentication;
 
 import java.security.cert.X509Certificate;
 import java.time.Duration;
+import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
@@ -31,7 +33,19 @@ import javax.annotation.Nonnull;
 import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.joda.time.DateTime;
 import org.opensaml.core.criterion.EntityIdCriterion;
+import org.opensaml.core.xml.XMLObject;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.core.xml.io.Marshaller;
+
+import org.opensaml.core.xml.schema.XSAny;
+import org.opensaml.core.xml.schema.XSBoolean;
+import org.opensaml.core.xml.schema.XSBooleanValue;
+import org.opensaml.core.xml.schema.XSDateTime;
+import org.opensaml.core.xml.schema.XSInteger;
+import org.opensaml.core.xml.schema.XSString;
+import org.opensaml.core.xml.schema.XSURI;
 import org.opensaml.saml.common.assertion.ValidationContext;
 import org.opensaml.saml.common.assertion.ValidationResult;
 import org.opensaml.saml.common.xml.SAMLConstants;
@@ -45,6 +59,8 @@ import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
 import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
 import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator;
 import org.opensaml.saml.saml2.core.Assertion;
+import org.opensaml.saml.saml2.core.Attribute;
+import org.opensaml.saml.saml2.core.AttributeStatement;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedID;
 import org.opensaml.saml.saml2.core.NameID;
@@ -205,8 +221,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 			List<Assertion> validAssertions = validateResponse(token, response);
 			Assertion assertion = validAssertions.get(0);
 			String username = getUsername(token, assertion);
+			Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
 			return new Saml2Authentication(
-					new SimpleSaml2AuthenticatedPrincipal(username), token.getSaml2Response(),
+					new SimpleSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
 					this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
 		} catch (Saml2AuthenticationException e) {
 			throw e;
@@ -494,6 +511,60 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		throw last;
 	}
 
+	private Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
+		Map<String, List<Object>> attributeMap = new LinkedHashMap<>();
+		for (AttributeStatement attributeStatement : assertion.getAttributeStatements()) {
+			for (Attribute attribute : attributeStatement.getAttributes()) {
+
+				List<Object> attributeValues = new ArrayList<>();
+				for (XMLObject xmlObject : attribute.getAttributeValues()) {
+					Object attributeValue = getXmlObjectValue(xmlObject);
+					if (attributeValue != null) {
+						attributeValues.add(attributeValue);
+					}
+				}
+				attributeMap.put(attribute.getName(), attributeValues);
+
+			}
+		}
+		return attributeMap;
+	}
+
+	private Object getXmlObjectValue(XMLObject xmlObject) {
+		if (xmlObject == null) {
+			return null;
+		}
+		if (xmlObject instanceof XSAny) {
+			return getXSAnyObjectValue((XSAny) xmlObject);
+		}
+		if (xmlObject instanceof XSString) {
+			return ((XSString) xmlObject).getValue();
+		}
+		if (xmlObject instanceof XSInteger) {
+			return ((XSInteger) xmlObject).getValue();
+		}
+		if (xmlObject instanceof XSURI) {
+			return ((XSURI) xmlObject).getValue();
+		}
+		if (xmlObject instanceof XSBoolean) {
+			XSBooleanValue xsBooleanValue = ((XSBoolean) xmlObject).getValue();
+			return xsBooleanValue != null ? xsBooleanValue.getValue() : null;
+		}
+		if (xmlObject instanceof XSDateTime) {
+			DateTime dateTime = ((XSDateTime) xmlObject).getValue();
+			return dateTime != null ? Instant.ofEpochMilli(dateTime.getMillis()) : null;
+		}
+		return null;
+	}
+
+	private Object getXSAnyObjectValue(XSAny xsAny) {
+		Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(xsAny);
+		if (marshaller != null) {
+			return this.saml.serialize(xsAny);
+		}
+		return xsAny.getTextContent();
+	}
+
 	private Saml2Error validationError(String code, String description) {
 		return new Saml2Error(code, description);
 	}

+ 43 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2AuthenticatedPrincipal.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,7 +16,13 @@
 
 package org.springframework.security.saml2.provider.service.authentication;
 
+import org.springframework.lang.Nullable;
 import org.springframework.security.core.AuthenticatedPrincipal;
+import org.springframework.util.CollectionUtils;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Saml2 representation of an {@link AuthenticatedPrincipal}.
@@ -25,4 +31,40 @@ import org.springframework.security.core.AuthenticatedPrincipal;
  * @since 5.2.2
  */
 public interface Saml2AuthenticatedPrincipal extends AuthenticatedPrincipal {
+	/**
+	 * Get the first value of Saml2 token attribute by name
+	 *
+	 * @param name the name of the attribute
+	 * @param <A> the type of the attribute
+	 * @return the first attribute value or {@code null} otherwise
+	 * @since 5.4
+	 */
+	@Nullable
+	default <A> A getFirstAttribute(String name) {
+		List<A> values = getAttribute(name);
+		return CollectionUtils.firstElement(values);
+	}
+
+	/**
+	 * Get the Saml2 token attribute by name
+	 *
+	 * @param name the name of the attribute
+	 * @param <A> the type of the attribute
+	 * @return the attribute or {@code null} otherwise
+	 * @since 5.4
+	 */
+	@Nullable
+	default <A> List<A> getAttribute(String name) {
+		return (List<A>) getAttributes().get(name);
+	}
+
+	/**
+	 * Get the Saml2 token attributes
+	 *
+	 * @return the Saml2 token attributes
+	 * @since 5.4
+	 */
+	default Map<String, List<Object>> getAttributes() {
+		return Collections.emptyMap();
+	}
 }

+ 11 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipal.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
 package org.springframework.security.saml2.provider.service.authentication;
 
 import java.io.Serializable;
+import java.util.List;
+import java.util.Map;
 
 /**
  * Default implementation of a {@link Saml2AuthenticatedPrincipal}.
@@ -27,13 +29,20 @@ import java.io.Serializable;
 class SimpleSaml2AuthenticatedPrincipal implements Saml2AuthenticatedPrincipal, Serializable {
 
 	private final String name;
+	private final Map<String, List<Object>> attributes;
 
-	SimpleSaml2AuthenticatedPrincipal(String name) {
+	SimpleSaml2AuthenticatedPrincipal(String name, Map<String, List<Object>> attributes) {
 		this.name = name;
+		this.attributes = attributes;
 	}
 
 	@Override
 	public String getName() {
 		return this.name;
 	}
+
+	@Override
+	public Map<String, List<Object>> getAttributes() {
+		return this.attributes;
+	}
 }

+ 30 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -19,7 +19,11 @@ package org.springframework.security.saml2.provider.service.authentication;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectOutputStream;
+import java.time.Instant;
 import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedHashMap;
+import java.util.Map;
 
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
@@ -39,6 +43,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
+import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.response;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.signed;
@@ -47,6 +52,7 @@ import static org.springframework.security.saml2.credentials.TestSaml2X509Creden
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
+import static org.springframework.test.util.AssertionErrors.assertEquals;
 import static org.springframework.test.util.AssertionErrors.assertTrue;
 import static org.springframework.util.StringUtils.hasText;
 
@@ -193,6 +199,30 @@ public class OpenSamlAuthenticationProviderTests {
 		this.provider.authenticate(token);
 	}
 
+	@Test
+	public void authenticateWhenAssertionContainsAttributesThenItSucceeds() {
+		Response response = response();
+		Assertion assertion = assertion();
+		attributeStatements().forEach(as -> assertion.getAttributeStatements().add(as));
+		signed(assertion, assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID);
+		response.getAssertions().add(assertion);
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+		Authentication authentication = this.provider.authenticate(token);
+		Saml2AuthenticatedPrincipal principal = (Saml2AuthenticatedPrincipal) authentication.getPrincipal();
+
+		Map<String, Object> attributes = new LinkedHashMap<>();
+		attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
+		attributes.put("name", Collections.singletonList("John Doe"));
+		attributes.put("age", Collections.singletonList(21));
+		attributes.put("website", Collections.singletonList("https://johndoe.com/"));
+		attributes.put("registered", Collections.singletonList(true));
+		Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
+		attributes.put("registeredDate", Collections.singletonList(registeredDate));
+
+		assertEquals("Values should be equal", "John Doe", principal.getFirstAttribute("name"));
+		assertTrue("Attributes should be equal", attributes.equals(principal.getAttributes()));
+	}
+
 	@Test
 	public void authenticateWhenEncryptedAssertionWithoutSignatureThenItFails() throws Exception {
 		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.INVALID_SIGNATURE));

+ 47 - 4
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/SimpleSaml2AuthenticatedPrincipalTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,15 +16,58 @@
 
 package org.springframework.security.saml2.provider.service.authentication;
 
-import org.junit.Assert;
+import org.joda.time.DateTime;
 import org.junit.Test;
 
+import java.time.Instant;
+import java.util.Arrays;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
 public class SimpleSaml2AuthenticatedPrincipalTests {
 
 	@Test
 	public void createSimpleSaml2AuthenticatedPrincipal() {
-		SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user");
+		Map<String, List<Object>> attributes = new LinkedHashMap<>();
+		attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
+		SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
+		assertThat(principal.getName()).isEqualTo("user");
+		assertThat(principal.getAttributes()).isEqualTo(attributes);
+	}
+
+	@Test
+	public void getFirstAttributeWhenStringValueThenReturnsValue() {
+		Map<String, List<Object>> attributes = new LinkedHashMap<>();
+		attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
+		SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
+		assertThat(principal.<String>getFirstAttribute("email")).isEqualTo(attributes.get("email").get(0));
+	}
+
+	@Test
+	public void getAttributeWhenStringValuesThenReturnsValues() {
+		Map<String, List<Object>> attributes = new LinkedHashMap<>();
+		attributes.put("email", Arrays.asList("john.doe@example.com", "doe.john@example.com"));
+		SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
+		assertThat(principal.<String>getAttribute("email")).isEqualTo(attributes.get("email"));
+	}
+
+	@Test
+	public void getAttributeWhenDistinctValuesThenReturnsValues() {
+		final Boolean registered = true;
+		final Instant registeredDate = Instant.ofEpochMilli(DateTime.parse("1970-01-01T00:00:00Z").getMillis());
+
+		Map<String, List<Object>> attributes = new LinkedHashMap<>();
+		attributes.put("registration", Arrays.asList(registered, registeredDate));
+
+		SimpleSaml2AuthenticatedPrincipal principal = new SimpleSaml2AuthenticatedPrincipal("user", attributes);
+
+		List<Object> registrationInfo = principal.getAttribute("registration");
 
-		Assert.assertEquals("user", principal.getName());
+		assertThat(registrationInfo).isNotNull();
+		assertThat((Boolean) registrationInfo.get(0)).isEqualTo(registered);
+		assertThat((Instant) registrationInfo.get(1)).isEqualTo(registeredDate);
 	}
 }

+ 83 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/TestOpenSamlObjects.java

@@ -19,6 +19,8 @@ package org.springframework.security.saml2.provider.service.authentication;
 import java.security.cert.X509Certificate;
 import java.util.Base64;
 import java.util.UUID;
+import java.util.List;
+import java.util.ArrayList;
 import javax.crypto.SecretKey;
 import javax.crypto.spec.SecretKeySpec;
 
@@ -26,9 +28,26 @@ import org.apache.xml.security.encryption.XMLCipherParameters;
 import org.joda.time.DateTime;
 import org.joda.time.Duration;
 import org.opensaml.core.xml.io.MarshallingException;
+
+import org.opensaml.core.xml.schema.XSAny;
+import org.opensaml.core.xml.schema.XSBoolean;
+import org.opensaml.core.xml.schema.XSBooleanValue;
+import org.opensaml.core.xml.schema.XSDateTime;
+import org.opensaml.core.xml.schema.XSInteger;
+import org.opensaml.core.xml.schema.XSString;
+import org.opensaml.core.xml.schema.XSURI;
+import org.opensaml.core.xml.schema.impl.XSAnyBuilder;
+import org.opensaml.core.xml.schema.impl.XSBooleanBuilder;
+import org.opensaml.core.xml.schema.impl.XSDateTimeBuilder;
+import org.opensaml.core.xml.schema.impl.XSIntegerBuilder;
+import org.opensaml.core.xml.schema.impl.XSStringBuilder;
+import org.opensaml.core.xml.schema.impl.XSURIBuilder;
 import org.opensaml.saml.common.SAMLVersion;
 import org.opensaml.saml.common.SignableSAMLObject;
 import org.opensaml.saml.saml2.core.Assertion;
+import org.opensaml.saml.saml2.core.Attribute;
+import org.opensaml.saml.saml2.core.AttributeStatement;
+import org.opensaml.saml.saml2.core.AttributeValue;
 import org.opensaml.saml.saml2.core.Conditions;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedID;
@@ -38,6 +57,8 @@ import org.opensaml.saml.saml2.core.Response;
 import org.opensaml.saml.saml2.core.Subject;
 import org.opensaml.saml.saml2.core.SubjectConfirmation;
 import org.opensaml.saml.saml2.core.SubjectConfirmationData;
+import org.opensaml.saml.saml2.core.impl.AttributeBuilder;
+import org.opensaml.saml.saml2.core.impl.AttributeStatementBuilder;
 import org.opensaml.saml.saml2.encryption.Encrypter;
 import org.opensaml.security.SecurityException;
 import org.opensaml.security.credential.BasicCredential;
@@ -222,4 +243,66 @@ final class TestOpenSamlObjects {
 
 		return encrypter;
 	}
+
+	static List<AttributeStatement> attributeStatements() {
+		List<AttributeStatement> attributeStatements = new ArrayList<>();
+
+		AttributeStatementBuilder attributeStatementBuilder = new AttributeStatementBuilder();
+		AttributeBuilder attributeBuilder = new AttributeBuilder();
+
+		AttributeStatement attrStmt1 = attributeStatementBuilder.buildObject();
+
+		Attribute emailAttr = attributeBuilder.buildObject();
+		emailAttr.setName("email");
+		XSAny email1 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME);
+		email1.setTextContent("john.doe@example.com");
+		emailAttr.getAttributeValues().add(email1);
+		XSAny email2 = new XSAnyBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME);
+		email2.setTextContent("doe.john@example.com");
+		emailAttr.getAttributeValues().add(email2);
+		attrStmt1.getAttributes().add(emailAttr);
+
+		Attribute nameAttr = attributeBuilder.buildObject();
+		nameAttr.setName("name");
+		XSString name = new XSStringBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSString.TYPE_NAME);
+		name.setValue("John Doe");
+		nameAttr.getAttributeValues().add(name);
+		attrStmt1.getAttributes().add(nameAttr);
+
+		Attribute ageAttr = attributeBuilder.buildObject();
+		ageAttr.setName("age");
+		XSInteger age = new XSIntegerBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSInteger.TYPE_NAME);
+		age.setValue(21);
+		ageAttr.getAttributeValues().add(age);
+		attrStmt1.getAttributes().add(ageAttr);
+
+		attributeStatements.add(attrStmt1);
+
+		AttributeStatement attrStmt2 = attributeStatementBuilder.buildObject();
+
+		Attribute websiteAttr = attributeBuilder.buildObject();
+		websiteAttr.setName("website");
+		XSURI uri = new XSURIBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSURI.TYPE_NAME);
+		uri.setValue("https://johndoe.com/");
+		websiteAttr.getAttributeValues().add(uri);
+		attrStmt2.getAttributes().add(websiteAttr);
+
+		Attribute registeredAttr = attributeBuilder.buildObject();
+		registeredAttr.setName("registered");
+		XSBoolean registered = new XSBooleanBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSBoolean.TYPE_NAME);
+		registered.setValue(new XSBooleanValue(true, false));
+		registeredAttr.getAttributeValues().add(registered);
+		attrStmt2.getAttributes().add(registeredAttr);
+
+		Attribute registeredDateAttr = attributeBuilder.buildObject();
+		registeredDateAttr.setName("registeredDate");
+		XSDateTime registeredDate = new XSDateTimeBuilder().buildObject(AttributeValue.DEFAULT_ELEMENT_NAME, XSDateTime.TYPE_NAME);
+		registeredDate.setValue(DateTime.parse("1970-01-01T00:00:00Z"));
+		registeredDateAttr.getAttributeValues().add(registeredDate);
+		attrStmt2.getAttributes().add(registeredDateAttr);
+
+		attributeStatements.add(attrStmt2);
+
+		return attributeStatements;
+	}
 }