Ver Fonte

OpenSamlAuthenticationProvider Uses OpenSAML Directly

Closes gh-8773
Josh Cummings há 5 anos atrás
pai
commit
2e2da06bdb

+ 37 - 14
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -15,6 +15,8 @@
  */
 package org.springframework.security.saml2.provider.service.authentication;
 
+import java.io.ByteArrayInputStream;
+import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayList;
@@ -32,13 +34,17 @@ import java.util.function.Function;
 import javax.annotation.Nonnull;
 
 import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
+import net.shibboleth.utilities.java.support.xml.ParserPool;
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.joda.time.DateTime;
+import org.opensaml.core.config.ConfigurationService;
 import org.opensaml.core.criterion.EntityIdCriterion;
 import org.opensaml.core.xml.XMLObject;
-import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
 import org.opensaml.core.xml.io.Marshaller;
+import org.opensaml.core.xml.io.MarshallingException;
 import org.opensaml.core.xml.schema.XSAny;
 import org.opensaml.core.xml.schema.XSBoolean;
 import org.opensaml.core.xml.schema.XSBooleanValue;
@@ -65,6 +71,7 @@ import org.opensaml.saml.saml2.core.EncryptedID;
 import org.opensaml.saml.saml2.core.NameID;
 import org.opensaml.saml.saml2.core.Response;
 import org.opensaml.saml.saml2.core.SubjectConfirmation;
+import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
 import org.opensaml.saml.saml2.encryption.Decrypter;
 import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
 import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
@@ -88,6 +95,8 @@ import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver;
 import org.opensaml.xmlsec.signature.support.SignaturePrevalidator;
 import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
 import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.authentication.AbstractAuthenticationToken;
@@ -120,7 +129,6 @@ import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_IS
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA;
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND;
-import static org.springframework.security.saml2.core.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS;
 import static org.springframework.util.Assert.notNull;
 
 /**
@@ -167,7 +175,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 	private static Log logger = LogFactory.getLog(OpenSamlAuthenticationProvider.class);
 
-	private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
+	private final XMLObjectProviderRegistry registry;
+	private final ResponseUnmarshaller responseUnmarshaller;
+	private final ParserPool parserPool;
 
 	private Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor =
 			(a -> singletonList(new SimpleGrantedAuthority("ROLE_USER")));
@@ -192,6 +202,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 						this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
 			};
 
+	/**
+	 * Creates an {@link OpenSamlAuthenticationProvider}
+	 */
+	public OpenSamlAuthenticationProvider() {
+		this.registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
+		this.responseUnmarshaller = (ResponseUnmarshaller) this.registry.getUnmarshallerFactory()
+				.getUnmarshaller(Response.DEFAULT_ELEMENT_NAME);
+		this.parserPool = this.registry.getParserPool();
+	}
+
 	/**
 	 * Sets the {@link Converter} used for extracting assertion attributes that
 	 * can be mapped to authorities.
@@ -265,15 +285,13 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 	private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException {
 		try {
-			Object result = this.saml.resolve(response);
-			if (result instanceof Response) {
-				return (Response) result;
-			}
-			else {
-				throw authException(UNKNOWN_RESPONSE_CLASS, "Invalid response class:" + result.getClass().getName());
-			}
-		} catch (Saml2Exception x) {
-			throw authException(MALFORMED_RESPONSE_DATA, x.getMessage(), x);
+			Document document = this.parserPool.parse(new ByteArrayInputStream(
+					response.getBytes(StandardCharsets.UTF_8)));
+			Element element = document.getDocumentElement();
+			return (Response) this.responseUnmarshaller.unmarshall(element);
+		}
+		catch (Exception e) {
+			throw authException(MALFORMED_RESPONSE_DATA, e.getMessage(), e);
 		}
 	}
 
@@ -427,9 +445,14 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	}
 
 	private Object getXSAnyObjectValue(XSAny xsAny) {
-		Marshaller marshaller = XMLObjectProviderRegistrySupport.getMarshallerFactory().getMarshaller(xsAny);
+		Marshaller marshaller = this.registry.getMarshallerFactory().getMarshaller(xsAny);
 		if (marshaller != null) {
-			return this.saml.serialize(xsAny);
+			try {
+				Element element = marshaller.marshall(xsAny);
+				return SerializeSupport.nodeToString(element);
+			} catch (MarshallingException e) {
+				throw new Saml2Exception(e);
+			}
 		}
 		return xsAny.getTextContent();
 	}

+ 22 - 8
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -29,6 +29,7 @@ import java.util.Map;
 import javax.xml.parsers.DocumentBuilder;
 import javax.xml.parsers.DocumentBuilderFactory;
 
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.hamcrest.Matcher;
@@ -40,6 +41,7 @@ import org.junit.rules.ExpectedException;
 import org.opensaml.core.xml.XMLObject;
 import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.core.xml.io.Marshaller;
+import org.opensaml.core.xml.io.MarshallingException;
 import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.AttributeStatement;
 import org.opensaml.saml.saml2.core.AttributeValue;
@@ -52,6 +54,7 @@ import org.w3c.dom.Element;
 import org.xml.sax.InputSource;
 
 import org.springframework.security.core.Authentication;
+import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -60,6 +63,8 @@ import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
+import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
@@ -85,8 +90,6 @@ public class OpenSamlAuthenticationProviderTests {
 	private static String ASSERTING_PARTY_ENTITY_ID = "https://some.idp.test/saml2/idp";
 
 	private OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-	private OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
-
 
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
@@ -108,10 +111,11 @@ public class OpenSamlAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenUnknownDataClassThenThrowAuthenticationException() {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS));
+		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA));
 
-		Assertion assertion = this.saml.buildSamlObject(Assertion.DEFAULT_ELEMENT_NAME);
-		this.provider.authenticate(token(this.saml.serialize(assertion), relyingPartyVerifyingCredential()));
+		Assertion assertion = (Assertion) getBuilderFactory().getBuilder(Assertion.DEFAULT_ELEMENT_NAME)
+				.buildObject(Assertion.DEFAULT_ELEMENT_NAME);
+		this.provider.authenticate(token(serialize(assertion), relyingPartyVerifyingCredential()));
 	}
 
 	@Test
@@ -316,7 +320,7 @@ public class OpenSamlAuthenticationProviderTests {
 		Response response = response();
 		EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
-		Saml2AuthenticationToken token = token(this.saml.serialize(response), relyingPartyVerifyingCredential());
+		Saml2AuthenticationToken token = token(serialize(response), relyingPartyVerifyingCredential());
 		this.provider.authenticate(token);
 	}
 
@@ -329,7 +333,7 @@ public class OpenSamlAuthenticationProviderTests {
 		Response response = response();
 		EncryptedAssertion encryptedAssertion = encrypted(assertion(), assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
-		Saml2AuthenticationToken token = token(this.saml.serialize(response), assertingPartyPrivateCredential());
+		Saml2AuthenticationToken token = token(serialize(response), assertingPartyPrivateCredential());
 		this.provider.authenticate(token);
 	}
 
@@ -349,6 +353,16 @@ public class OpenSamlAuthenticationProviderTests {
 		objectOutputStream.flush();
 	}
 
+	private String serialize(XMLObject object) {
+		try {
+			Marshaller marshaller = getMarshallerFactory().getMarshaller(object);
+			Element element = marshaller.marshall(object);
+			return SerializeSupport.nodeToString(element);
+		} catch (MarshallingException e) {
+			throw new Saml2Exception(e);
+		}
+	}
+
 	private Matcher<Saml2AuthenticationException> authenticationMatcher(String code) {
 		return authenticationMatcher(code, null);
 	}
@@ -382,7 +396,7 @@ public class OpenSamlAuthenticationProviderTests {
 	}
 
 	private Saml2AuthenticationToken token(Response response, Saml2X509Credential... credentials) {
-		String payload = this.saml.serialize(response);
+		String payload = serialize(response);
 		return token(payload, credentials);
 	}