Jelajahi Sumber

OpenSamlAuthenticationRequestFactory Uses OpenSAML Directly

Closes gh-8774
Josh Cummings 5 tahun lalu
induk
melakukan
5779121da6

+ 110 - 22
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

@@ -16,30 +16,40 @@
 
 package org.springframework.security.saml2.provider.service.authentication;
 
+import java.nio.charset.StandardCharsets;
 import java.security.PrivateKey;
 import java.security.cert.X509Certificate;
 import java.time.Clock;
 import java.time.Instant;
 import java.util.Collection;
+import java.util.LinkedHashMap;
 import java.util.Map;
 import java.util.UUID;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
+import net.shibboleth.utilities.java.support.xml.SerializeSupport;
 import org.joda.time.DateTime;
+import org.opensaml.core.config.ConfigurationService;
+import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
 import org.opensaml.core.xml.io.MarshallingException;
 import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.Issuer;
+import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
+import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
+import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
 import org.opensaml.security.SecurityException;
 import org.opensaml.security.credential.BasicCredential;
 import org.opensaml.security.credential.Credential;
 import org.opensaml.security.credential.CredentialSupport;
 import org.opensaml.security.credential.UsageType;
 import org.opensaml.xmlsec.SignatureSigningParameters;
+import org.opensaml.xmlsec.crypto.XMLSigningUtil;
 import org.opensaml.xmlsec.signature.support.SignatureConstants;
 import org.opensaml.xmlsec.signature.support.SignatureException;
 import org.opensaml.xmlsec.signature.support.SignatureSupport;
+import org.w3c.dom.Element;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.Saml2Exception;
@@ -47,11 +57,14 @@ import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.util.Assert;
+import org.springframework.web.util.UriUtils;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDeflate;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlEncode;
+import static org.springframework.util.StringUtils.hasText;
 
 /**
  * @since 5.2
@@ -62,7 +75,10 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	}
 
 	private Clock clock = Clock.systemUTC();
-	private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
+
+	private AuthnRequestMarshaller marshaller;
+	private AuthnRequestBuilder authnRequestBuilder;
+	private IssuerBuilder issuerBuilder;
 
 	private Converter<Saml2AuthenticationRequestContext, String> protocolBindingResolver =
 			context -> {
@@ -75,6 +91,19 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
 			= context -> authnRequest -> {};
 
+	/**
+	 * Creates an {@link OpenSamlAuthenticationRequestFactory}
+	 */
+	public OpenSamlAuthenticationRequestFactory() {
+		XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
+		this.marshaller = (AuthnRequestMarshaller) registry.getMarshallerFactory()
+				.getMarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
+		this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory()
+				.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
+		this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory()
+				.getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
+	}
+
 	@Override
 	@Deprecated
 	public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
@@ -84,8 +113,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 		for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) {
 			if (credential.isSigningCredential()) {
 				Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), request.getIssuer());
-				signAuthnRequest(authnRequest, cred);
-				return this.saml.serialize(authnRequest);
+				return serialize(sign(authnRequest, cred));
 			}
 		}
 		throw new IllegalArgumentException("No signing credential provided");
@@ -98,8 +126,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	public Saml2PostAuthenticationRequest createPostAuthenticationRequest(Saml2AuthenticationRequestContext context) {
 		AuthnRequest authnRequest = createAuthnRequest(context);
 		String xml = context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned() ?
-			signThenSerialize(authnRequest, context.getRelyingPartyRegistration()) :
-			this.saml.serialize(authnRequest);
+			serialize(sign(authnRequest, context.getRelyingPartyRegistration())) :
+			serialize(authnRequest);
 
 		return Saml2PostAuthenticationRequest.withAuthenticationRequestContext(context)
 				.samlRequest(samlEncode(xml.getBytes(UTF_8)))
@@ -112,7 +140,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	@Override
 	public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(Saml2AuthenticationRequestContext context) {
 		AuthnRequest authnRequest = createAuthnRequest(context);
-		String xml = this.saml.serialize(authnRequest);
+		String xml = serialize(authnRequest);
 		Builder result = Saml2RedirectAuthenticationRequest.withAuthenticationRequestContext(context);
 		String deflatedAndEncoded = samlEncode(samlDeflate(xml));
 		result.samlRequest(deflatedAndEncoded)
@@ -120,15 +148,20 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 
 		if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
 			Collection<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration().getSigningX509Credentials();
-			Map<String, String> signedParams = this.saml.signQueryParameters(
-					signingCredentials,
-					deflatedAndEncoded,
-					context.getRelayState()
-			);
-			result.samlRequest(signedParams.get("SAMLRequest"))
-					.relayState(signedParams.get("RelayState"))
-					.sigAlg(signedParams.get("SigAlg"))
-					.signature(signedParams.get("Signature"));
+			for (Saml2X509Credential credential : signingCredentials) {
+				Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), "");
+				Map<String, String> signedParams = signQueryParameters(
+						cred,
+						deflatedAndEncoded,
+						context.getRelayState());
+				return result
+						.samlRequest(signedParams.get("SAMLRequest"))
+						.relayState(signedParams.get("RelayState"))
+						.sigAlg(signedParams.get("SigAlg"))
+						.signature(signedParams.get("Signature"))
+						.build();
+			}
+			throw new Saml2Exception("No signing credential provided");
 		}
 
 		return result.build();
@@ -144,13 +177,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 
 	private AuthnRequest createAuthnRequest
 			(String issuer, String destination, String assertionConsumerServiceUrl, String protocolBinding) {
-		AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
+		AuthnRequest auth = this.authnRequestBuilder.buildObject();
 		auth.setID("ARQ" + UUID.randomUUID().toString().substring(1));
 		auth.setIssueInstant(new DateTime(this.clock.millis()));
 		auth.setForceAuthn(Boolean.FALSE);
 		auth.setIsPassive(Boolean.FALSE);
 		auth.setProtocolBinding(protocolBinding);
-		Issuer iss = this.saml.buildSamlObject(Issuer.DEFAULT_ELEMENT_NAME);
+		Issuer iss = this.issuerBuilder.buildObject();
 		iss.setValue(issuer);
 		auth.setIssuer(iss);
 		auth.setDestination(destination);
@@ -192,7 +225,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	 * @param protocolBinding either {@link SAMLConstants#SAML2_POST_BINDING_URI} or
 	 * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI}
 	 * @throws IllegalArgumentException if the protocolBinding is not valid
-	 * @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding}
+	 * @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding(Saml2MessageBinding)}
 	 * instead
 	 */
 	@Deprecated
@@ -205,17 +238,16 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 		this.protocolBindingResolver = context -> protocolBinding;
 	}
 
-	private String signThenSerialize(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
+	private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
 		for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) {
 			Credential cred = getSigningCredential(
 					credential.getCertificate(), credential.getPrivateKey(), relyingPartyRegistration.getEntityId());
-			signAuthnRequest(authnRequest, cred);
-			return this.saml.serialize(authnRequest);
+			return sign(authnRequest, cred);
 		}
 		throw new IllegalArgumentException("No signing credential provided");
 	}
 
-	private void signAuthnRequest(AuthnRequest authnRequest, Credential credential) {
+	private AuthnRequest sign(AuthnRequest authnRequest, Credential credential) {
 		SignatureSigningParameters parameters = new SignatureSigningParameters();
 		parameters.setSigningCredential(credential);
 		parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
@@ -223,6 +255,7 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 		parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
 		try {
 			SignatureSupport.signObject(authnRequest, parameters);
+			return authnRequest;
 		} catch (MarshallingException | SignatureException | SecurityException e) {
 			throw new Saml2Exception(e);
 		}
@@ -234,4 +267,59 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 		cred.setUsageType(UsageType.SIGNING);
 		return cred;
 	}
+
+	private Map<String, String> signQueryParameters(
+			Credential credential,
+			String samlRequest,
+			String relayState) {
+		Assert.notNull(samlRequest, "samlRequest cannot be null");
+		String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256;
+		StringBuilder queryString = new StringBuilder();
+		queryString
+				.append("SAMLRequest")
+				.append("=")
+				.append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1))
+				.append("&");
+		if (hasText(relayState)) {
+			queryString
+					.append("RelayState")
+					.append("=")
+					.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1))
+					.append("&");
+		}
+		queryString
+				.append("SigAlg")
+				.append("=")
+				.append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1));
+
+		try {
+			byte[] rawSignature = XMLSigningUtil.signWithURI(
+					credential,
+					algorithmUri,
+					queryString.toString().getBytes(StandardCharsets.UTF_8)
+			);
+			String b64Signature = Saml2Utils.samlEncode(rawSignature);
+
+			Map<String, String> result = new LinkedHashMap<>();
+			result.put("SAMLRequest", samlRequest);
+			if (hasText(relayState)) {
+				result.put("RelayState", relayState);
+			}
+			result.put("SigAlg", algorithmUri);
+			result.put("Signature", b64Signature);
+			return result;
+		}
+		catch (SecurityException e) {
+			throw new Saml2Exception(e);
+		}
+	}
+
+	private String serialize(AuthnRequest authnRequest) {
+		try {
+			Element element = this.marshaller.marshall(authnRequest);
+			return SerializeSupport.nodeToString(element);
+		} catch (MarshallingException e) {
+			throw new Saml2Exception(e);
+		}
+	}
 }

+ 19 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

@@ -16,6 +16,7 @@
 
 package org.springframework.security.saml2.provider.service.authentication;
 
+import java.io.ByteArrayInputStream;
 import java.util.function.Consumer;
 import java.util.function.Function;
 
@@ -26,7 +27,11 @@ import org.junit.Test;
 import org.junit.rules.ExpectedException;
 import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.core.AuthnRequest;
+import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
 
+import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 
@@ -37,6 +42,8 @@ import static org.hamcrest.CoreMatchers.containsString;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getParserPool;
+import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getUnmarshallerFactory;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
@@ -56,6 +63,9 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 	private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
 	private RelyingPartyRegistration relyingPartyRegistration;
 
+	private AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) getUnmarshallerFactory()
+			.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
+
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
 
@@ -224,6 +234,14 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 		else {
 			samlRequest = new String(samlDecode(samlRequest), UTF_8);
 		}
-		return (AuthnRequest) OpenSamlImplementation.getInstance().resolve(samlRequest);
+		try {
+			Document document = getParserPool().parse(
+					new ByteArrayInputStream(samlRequest.getBytes(UTF_8)));
+			Element element = document.getDocumentElement();
+			return (AuthnRequest) this.unmarshaller.unmarshall(element);
+		}
+		catch (Exception e) {
+			throw new Saml2Exception(e);
+		}
 	}
 }