浏览代码

Polish OpenSamlAuthenticationRequestFactory

- Refactored to use SAMLMetadataSignatureSigningParametersResolver

Issue gh-7758
Josh Cummings 4 年之前
父节点
当前提交
a36baffb3a

+ 79 - 51
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

@@ -21,11 +21,14 @@ import java.security.PrivateKey;
 import java.security.cert.X509Certificate;
 import java.security.cert.X509Certificate;
 import java.time.Clock;
 import java.time.Clock;
 import java.time.Instant;
 import java.time.Instant;
-import java.util.Collection;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Map;
 import java.util.UUID;
 import java.util.UUID;
 
 
+import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
 import net.shibboleth.utilities.java.support.xml.SerializeSupport;
 import net.shibboleth.utilities.java.support.xml.SerializeSupport;
 import org.joda.time.DateTime;
 import org.joda.time.DateTime;
 import org.opensaml.core.config.ConfigurationService;
 import org.opensaml.core.config.ConfigurationService;
@@ -37,15 +40,18 @@ import org.opensaml.saml.saml2.core.Issuer;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
 import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
 import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
+import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver;
 import org.opensaml.security.SecurityException;
 import org.opensaml.security.SecurityException;
 import org.opensaml.security.credential.BasicCredential;
 import org.opensaml.security.credential.BasicCredential;
 import org.opensaml.security.credential.Credential;
 import org.opensaml.security.credential.Credential;
 import org.opensaml.security.credential.CredentialSupport;
 import org.opensaml.security.credential.CredentialSupport;
 import org.opensaml.security.credential.UsageType;
 import org.opensaml.security.credential.UsageType;
 import org.opensaml.xmlsec.SignatureSigningParameters;
 import org.opensaml.xmlsec.SignatureSigningParameters;
+import org.opensaml.xmlsec.SignatureSigningParametersResolver;
+import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion;
 import org.opensaml.xmlsec.crypto.XMLSigningUtil;
 import org.opensaml.xmlsec.crypto.XMLSigningUtil;
+import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration;
 import org.opensaml.xmlsec.signature.support.SignatureConstants;
 import org.opensaml.xmlsec.signature.support.SignatureConstants;
-import org.opensaml.xmlsec.signature.support.SignatureException;
 import org.opensaml.xmlsec.signature.support.SignatureSupport;
 import org.opensaml.xmlsec.signature.support.SignatureSupport;
 import org.w3c.dom.Element;
 import org.w3c.dom.Element;
 
 
@@ -58,6 +64,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponentsBuilder;
 import org.springframework.web.util.UriUtils;
 import org.springframework.web.util.UriUtils;
 
 
 /**
 /**
@@ -105,9 +112,17 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 				request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null));
 				request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null));
 		for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) {
 		for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) {
 			if (credential.isSigningCredential()) {
 			if (credential.isSigningCredential()) {
-				Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
-						request.getIssuer());
-				return serialize(sign(authnRequest, cred));
+				X509Certificate certificate = credential.getCertificate();
+				PrivateKey privateKey = credential.getPrivateKey();
+				BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey);
+				cred.setEntityId(request.getIssuer());
+				cred.setUsageType(UsageType.SIGNING);
+				SignatureSigningParameters parameters = new SignatureSigningParameters();
+				parameters.setSigningCredential(cred);
+				parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
+				parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
+				parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
+				return serialize(sign(authnRequest, parameters));
 			}
 			}
 		}
 		}
 		throw new IllegalArgumentException("No signing credential provided");
 		throw new IllegalArgumentException("No signing credential provided");
@@ -132,16 +147,13 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 		String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
 		String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
 		result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
 		result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
 		if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
 		if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
-			Collection<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration()
-					.getSigningX509Credentials();
-			for (Saml2X509Credential credential : signingCredentials) {
-				Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), "");
-				Map<String, String> signedParams = signQueryParameters(cred, deflatedAndEncoded,
-						context.getRelayState());
-				return result.samlRequest(signedParams.get("SAMLRequest")).relayState(signedParams.get("RelayState"))
-						.sigAlg(signedParams.get("SigAlg")).signature(signedParams.get("Signature")).build();
+			Map<String, String> parameters = new LinkedHashMap<>();
+			parameters.put("SAMLRequest", deflatedAndEncoded);
+			if (StringUtils.hasText(context.getRelayState())) {
+				parameters.put("RelayState", context.getRelayState());
 			}
 			}
-			throw new Saml2Exception("No signing credential provided");
+			sign(parameters, context.getRelyingPartyRegistration());
+			return result.sigAlg(parameters.get("SigAlg")).signature(parameters.get("Signature")).build();
 		}
 		}
 		return result.build();
 		return result.build();
 	}
 	}
@@ -211,59 +223,39 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	}
 	}
 
 
 	private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
 	private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
-		for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) {
-			Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
-					relyingPartyRegistration.getEntityId());
-			return sign(authnRequest, cred);
-		}
-		throw new IllegalArgumentException("No signing credential provided");
+		SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
+		return sign(authnRequest, parameters);
 	}
 	}
 
 
-	private AuthnRequest sign(AuthnRequest authnRequest, Credential credential) {
-		SignatureSigningParameters parameters = new SignatureSigningParameters();
-		parameters.setSigningCredential(credential);
-		parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
-		parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
-		parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
+	private AuthnRequest sign(AuthnRequest authnRequest, SignatureSigningParameters parameters) {
 		try {
 		try {
 			SignatureSupport.signObject(authnRequest, parameters);
 			SignatureSupport.signObject(authnRequest, parameters);
 			return authnRequest;
 			return authnRequest;
 		}
 		}
-		catch (MarshallingException | SignatureException | SecurityException ex) {
+		catch (Exception ex) {
 			throw new Saml2Exception(ex);
 			throw new Saml2Exception(ex);
 		}
 		}
 	}
 	}
 
 
-	private Credential getSigningCredential(X509Certificate certificate, PrivateKey privateKey, String entityId) {
-		BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey);
-		cred.setEntityId(entityId);
-		cred.setUsageType(UsageType.SIGNING);
-		return cred;
+	private void sign(Map<String, String> components, RelyingPartyRegistration relyingPartyRegistration) {
+		SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
+		sign(components, parameters);
 	}
 	}
 
 
-	private Map<String, String> signQueryParameters(Credential credential, String samlRequest, String relayState) {
-		Assert.notNull(samlRequest, "samlRequest cannot be null");
-		String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256;
-		StringBuilder queryString = new StringBuilder();
-		queryString.append("SAMLRequest").append("=").append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1))
-				.append("&");
-		if (StringUtils.hasText(relayState)) {
-			queryString.append("RelayState").append("=")
-					.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&");
+	private void sign(Map<String, String> components, SignatureSigningParameters parameters) {
+		Credential credential = parameters.getSigningCredential();
+		String algorithmUri = parameters.getSignatureAlgorithm();
+		components.put("SigAlg", algorithmUri);
+		UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
+		for (Map.Entry<String, String> component : components.entrySet()) {
+			builder.queryParam(component.getKey(), UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1));
 		}
 		}
-		queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1));
+		String queryString = builder.build(true).toString().substring(1);
 		try {
 		try {
 			byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
 			byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
-					queryString.toString().getBytes(StandardCharsets.UTF_8));
+					queryString.getBytes(StandardCharsets.UTF_8));
 			String b64Signature = Saml2Utils.samlEncode(rawSignature);
 			String b64Signature = Saml2Utils.samlEncode(rawSignature);
-			Map<String, String> result = new LinkedHashMap<>();
-			result.put("SAMLRequest", samlRequest);
-			if (StringUtils.hasText(relayState)) {
-				result.put("RelayState", relayState);
-			}
-			result.put("SigAlg", algorithmUri);
-			result.put("Signature", b64Signature);
-			return result;
+			components.put("Signature", b64Signature);
 		}
 		}
 		catch (SecurityException ex) {
 		catch (SecurityException ex) {
 			throw new Saml2Exception(ex);
 			throw new Saml2Exception(ex);
@@ -280,4 +272,40 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 		}
 		}
 	}
 	}
 
 
+	private SignatureSigningParameters resolveSigningParameters(RelyingPartyRegistration relyingPartyRegistration) {
+		List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
+		List<String> algorithms = Collections.singletonList(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
+		List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
+		String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
+		SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
+		CriteriaSet criteria = new CriteriaSet();
+		BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration();
+		signingConfiguration.setSigningCredentials(credentials);
+		signingConfiguration.setSignatureAlgorithms(algorithms);
+		signingConfiguration.setSignatureReferenceDigestMethods(digests);
+		signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization);
+		criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration));
+		try {
+			SignatureSigningParameters parameters = resolver.resolveSingle(criteria);
+			Assert.notNull(parameters, "Failed to resolve any signing credential");
+			return parameters;
+		}
+		catch (Exception ex) {
+			throw new Saml2Exception(ex);
+		}
+	}
+
+	private List<Credential> resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) {
+		List<Credential> credentials = new ArrayList<>();
+		for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) {
+			X509Certificate certificate = x509Credential.getCertificate();
+			PrivateKey privateKey = x509Credential.getPrivateKey();
+			BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey);
+			credential.setEntityId(relyingPartyRegistration.getEntityId());
+			credential.setUsageType(UsageType.SIGNING);
+			credentials.add(credential);
+		}
+		return credentials;
+	}
+
 }
 }

+ 38 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

@@ -26,16 +26,20 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
+import org.opensaml.xmlsec.signature.support.SignatureConstants;
 import org.w3c.dom.Document;
 import org.w3c.dom.Document;
 import org.w3c.dom.Element;
 import org.w3c.dom.Element;
 
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
 import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
@@ -110,6 +114,28 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 		assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
 		assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
 	}
 	}
 
 
+	@Test
+	public void createRedirectAuthenticationRequestWhenSignRequestThenSignatureIsPresent() {
+		this.context = this.contextBuilder.relayState("Relay State Value")
+				.relyingPartyRegistration(this.relyingPartyRegistration).build();
+		Saml2RedirectAuthenticationRequest request = this.factory.createRedirectAuthenticationRequest(this.context);
+		assertThat(request.getRelayState()).isEqualTo("Relay State Value");
+		assertThat(request.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
+		assertThat(request.getSignature()).isNotNull();
+	}
+
+	@Test
+	public void createRedirectAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
+		Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
+				.relyingPartyVerifyingCredential();
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
+				.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
+		this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
+				.build();
+		assertThatExceptionOfType(Saml2Exception.class)
+				.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
+	}
+
 	@Test
 	@Test
 	public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() {
 	public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() {
 		this.context = this.contextBuilder.relayState("Relay State Value")
 		this.context = this.contextBuilder.relayState("Relay State Value")
@@ -139,6 +165,18 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 				.contains("ds:Signature");
 				.contains("ds:Signature");
 	}
 	}
 
 
+	@Test
+	public void createPostAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
+		Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
+				.relyingPartyVerifyingCredential();
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
+				.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
+		this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
+				.build();
+		assertThatExceptionOfType(Saml2Exception.class)
+				.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
+	}
+
 	@Test
 	@Test
 	public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() {
 	public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() {
 		AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);
 		AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);