瀏覽代碼

Refactor OpenSamlAuthenticationProvider

Refactored into collaborators in preparation for introducing setters

Issue gh-8769
Josh Cummings 5 年之前
父節點
當前提交
5bfc6ea25a

+ 424 - 230
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -15,19 +15,20 @@
  */
 package org.springframework.security.saml2.provider.service.authentication;
 
-import java.security.cert.X509Certificate;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
-import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Consumer;
+import java.util.function.Function;
 import javax.annotation.Nonnull;
 
 import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
@@ -38,7 +39,6 @@ import org.opensaml.core.criterion.EntityIdCriterion;
 import org.opensaml.core.xml.XMLObject;
 import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.core.xml.io.Marshaller;
-
 import org.opensaml.core.xml.schema.XSAny;
 import org.opensaml.core.xml.schema.XSBoolean;
 import org.opensaml.core.xml.schema.XSBooleanValue;
@@ -53,7 +53,6 @@ import org.opensaml.saml.criterion.ProtocolCriterion;
 import org.opensaml.saml.metadata.criteria.role.impl.EvaluableProtocolRoleDescriptorCriterion;
 import org.opensaml.saml.saml2.assertion.ConditionValidator;
 import org.opensaml.saml.saml2.assertion.SAML20AssertionValidator;
-import org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters;
 import org.opensaml.saml.saml2.assertion.StatementValidator;
 import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
 import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
@@ -65,9 +64,9 @@ import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedID;
 import org.opensaml.saml.saml2.core.NameID;
 import org.opensaml.saml.saml2.core.Response;
-import org.opensaml.saml.saml2.core.Subject;
 import org.opensaml.saml.saml2.core.SubjectConfirmation;
 import org.opensaml.saml.saml2.encryption.Decrypter;
+import org.opensaml.saml.saml2.encryption.EncryptedElementTypeEncryptedKeyResolver;
 import org.opensaml.saml.security.impl.SAMLSignatureProfileValidator;
 import org.opensaml.security.credential.Credential;
 import org.opensaml.security.credential.CredentialResolver;
@@ -79,7 +78,11 @@ import org.opensaml.security.credential.impl.CollectionCredentialResolver;
 import org.opensaml.security.criteria.UsageCriterion;
 import org.opensaml.security.x509.BasicX509Credential;
 import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
+import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver;
 import org.opensaml.xmlsec.encryption.support.DecryptionException;
+import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver;
+import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver;
+import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver;
 import org.opensaml.xmlsec.keyinfo.KeyInfoCredentialResolver;
 import org.opensaml.xmlsec.keyinfo.impl.StaticKeyInfoCredentialResolver;
 import org.opensaml.xmlsec.signature.support.SignaturePrevalidator;
@@ -87,6 +90,7 @@ import org.opensaml.xmlsec.signature.support.SignatureTrustEngine;
 import org.opensaml.xmlsec.signature.support.impl.ExplicitKeySignatureTrustEngine;
 
 import org.springframework.core.convert.converter.Converter;
+import org.springframework.security.authentication.AbstractAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationProvider;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -96,12 +100,15 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
 import org.springframework.util.StringUtils;
 
+import static java.util.Arrays.asList;
 import static java.util.Collections.singleton;
 import static java.util.Collections.singletonList;
 import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.CLOCK_SKEW;
 import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.COND_VALID_AUDIENCES;
+import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
 import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.DECRYPTION_ERROR;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.INTERNAL_VALIDATION_ERROR;
@@ -112,7 +119,6 @@ import static org.springframework.security.saml2.provider.service.authentication
 import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.SUBJECT_NOT_FOUND;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.UNKNOWN_RESPONSE_CLASS;
-import static org.springframework.security.saml2.provider.service.authentication.Saml2ErrorCodes.USERNAME_NOT_FOUND;
 import static org.springframework.util.Assert.notNull;
 
 /**
@@ -155,26 +161,31 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 	private static Log logger = LogFactory.getLog(OpenSamlAuthenticationProvider.class);
 
-	private final List<ConditionValidator> conditions = Collections.singletonList(new AudienceRestrictionConditionValidator());
-	private final SubjectConfirmationValidator subjectConfirmationValidator = new BearerSubjectConfirmationValidator() {
-		@Nonnull
-		@Override
-		protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation,
-				@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
-			// skipping address validation - gh-7514
-			return ValidationResult.VALID;
-		}
-	};
-	private final List<SubjectConfirmationValidator> subjects = Collections.singletonList(this.subjectConfirmationValidator);
-	private final List<StatementValidator> statements = Collections.emptyList();
-	private final SignaturePrevalidator signaturePrevalidator = new SAMLSignatureProfileValidator();
-
 	private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
+
 	private Converter<Assertion, Collection<? extends GrantedAuthority>> authoritiesExtractor =
 			(a -> singletonList(new SimpleGrantedAuthority("ROLE_USER")));
 	private GrantedAuthoritiesMapper authoritiesMapper = (a -> a);
 	private Duration responseTimeValidationSkew = Duration.ofMinutes(5);
 
+	private Function<Saml2AuthenticationToken, Converter<Response, Map<String, Saml2AuthenticationException>>> responseValidator
+			= validator(Arrays.asList(new ResponseSignatureValidator(), new ResponseValidator()));
+	private Function<Saml2AuthenticationToken, Converter<EncryptedAssertion, Assertion>> assertionDecrypter
+			= new AssertionDecrypter();
+	private Function<Saml2AuthenticationToken, Converter<Assertion, Map<String, Saml2AuthenticationException>>> assertionValidator
+			= validator(Arrays.asList(new AssertionSignatureValidator(), new AssertionValidator.Builder().build()));
+	private Function<Saml2AuthenticationToken, Converter<EncryptedID, NameID>> principalDecrypter
+			= new PrincipalDecrypter();
+	private Function<Saml2AuthenticationToken, Converter<Response, AbstractAuthenticationToken>> authenticationConverter =
+			token -> response -> {
+				Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
+				String username = assertion.getSubject().getNameID().getValue();
+				Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
+				return new Saml2Authentication(
+						new SimpleSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
+						this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
+			};
+
 	/**
 	 * Sets the {@link Converter} used for extracting assertion attributes that
 	 * can be mapped to authorities.
@@ -204,6 +215,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	 */
 	public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
 		this.responseTimeValidationSkew = responseTimeValidationSkew;
+		this.assertionValidator = validator(Arrays.asList(
+				new AssertionSignatureValidator(),
+				new AssertionValidator.Builder()
+					.validationContext(params -> params
+							.put(CLOCK_SKEW, responseTimeValidationSkew.toMillis()))
+					.build()));
 	}
 
 	/**
@@ -217,14 +234,10 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	public Authentication authenticate(Authentication authentication) throws AuthenticationException {
 		try {
 			Saml2AuthenticationToken token = (Saml2AuthenticationToken) authentication;
-			Response response = parse(token.getSaml2Response());
-			List<Assertion> validAssertions = validateResponse(token, response);
-			Assertion assertion = validAssertions.get(0);
-			String username = getUsername(token, assertion);
-			Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
-			return new Saml2Authentication(
-					new SimpleSaml2AuthenticatedPrincipal(username, attributes), token.getSaml2Response(),
-					this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
+			String serializedResponse = token.getSaml2Response();
+			Response response = parse(serializedResponse);
+			process(token, response);
+			return this.authenticationConverter.apply(token).convert(response);
 		} catch (Saml2AuthenticationException e) {
 			throw e;
 		} catch (Exception e) {
@@ -259,88 +272,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 	}
 
-	private List<Assertion> validateResponse(Saml2AuthenticationToken token, Response response)
-			throws Saml2AuthenticationException {
-
-		List<Assertion> validAssertions = new ArrayList<>();
+	private void process(Saml2AuthenticationToken token, Response response) {
 		String issuer = response.getIssuer().getValue();
 		if (logger.isDebugEnabled()) {
-			logger.debug("Validating SAML response from " + issuer);
+			logger.debug("Processing SAML response from " + issuer);
 		}
 
-		List<Assertion> assertions = new ArrayList<>(response.getAssertions());
-		for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
-			Assertion assertion = decrypt(token, encryptedAssertion);
-			assertions.add(assertion);
-		}
-		if (assertions.isEmpty()) {
-			throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
-		}
+		boolean responseSigned = response.isSigned();
+		Map<String, Saml2AuthenticationException> validationExceptions = validateResponse(token, response);
 
-		if (!isSigned(response, assertions)) {
+		List<Assertion> assertions = decryptAssertions(token, response);
+		if (!isSigned(responseSigned, assertions)) {
 			throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
 					"Please either sign the response or all of the assertions.");
 		}
+		validationExceptions.putAll(validateAssertions(token, assertions));
 
-		SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine(token);
-
-		Map<String, Saml2AuthenticationException> validationExceptions = new HashMap<>();
-		if (response.isSigned()) {
-			SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
-			try {
-				profileValidator.validate(response.getSignature());
-			} catch (Exception e) {
-				validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
-						"Invalid signature for SAML Response [" + response.getID() + "]", e));
-			}
-
-			try {
-				CriteriaSet criteriaSet = new CriteriaSet();
-				criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
-				criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
-				criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
-				if (!signatureTrustEngine.validate(response.getSignature(), criteriaSet)) {
-					validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
-							"Invalid signature for SAML Response [" + response.getID() + "]"));
-				}
-			} catch (Exception e) {
-				validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
-						"Invalid signature for SAML Response [" + response.getID() + "]", e));
-			}
-		}
-
-		String destination = response.getDestination();
-		if (StringUtils.hasText(destination) && !destination.equals(token.getRecipientUri())) {
-			String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
-			validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
-		}
-
-		if (!StringUtils.hasText(issuer) || !issuer.equals(token.getIdpEntityId())) {
-			String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
-			validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
-		}
-
-		SAML20AssertionValidator validator = buildSamlAssertionValidator(signatureTrustEngine);
-		ValidationContext context = buildValidationContext(token, response);
-
-		if (logger.isDebugEnabled()) {
-			logger.debug("Validating " + assertions.size() + " assertions");
-		}
-		for (Assertion assertion : assertions) {
-			if (logger.isTraceEnabled()) {
-				logger.trace("Validating assertion " + assertion.getID());
-			}
-			try {
-				validAssertions.add(validateAssertion(assertion, validator, context));
-			} catch (Exception e) {
-				String message = String.format("Invalid assertion [%s] for SAML response [%s]", assertion.getID(), response.getID());
-				validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e));
-			}
+		Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
+		NameID nameId = decryptPrincipal(token, firstAssertion);
+		if (nameId == null || nameId.getValue() == null) {
+			validationExceptions.put(SUBJECT_NOT_FOUND, authException(SUBJECT_NOT_FOUND,
+					"Assertion [" + firstAssertion.getID() + "] is missing a subject"));
 		}
 
 		if (validationExceptions.isEmpty()) {
 			if (logger.isDebugEnabled()) {
-				logger.debug("Successfully validated SAML Response [" + response.getID() + "]");
+				logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
 			}
 		} else {
 			if (logger.isTraceEnabled()) {
@@ -354,161 +311,71 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		if (!validationExceptions.isEmpty()) {
 			throw validationExceptions.values().iterator().next();
 		}
-		if (validAssertions.isEmpty()) {
-			throw authException(MALFORMED_RESPONSE_DATA, "No valid assertions found in response.");
-		}
-
-		return validAssertions;
 	}
 
-	private boolean isSigned(Response samlResponse, List<Assertion> assertions) {
-		if (samlResponse.isSigned()) {
-			return true;
-		}
+	private Map<String, Saml2AuthenticationException> validateResponse
+			(Saml2AuthenticationToken token, Response response) {
 
-		for (Assertion assertion : assertions) {
-			if (!assertion.isSigned()) {
-				return false;
-			}
-		}
-
-		return true;
-	}
-
-	private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
-		Set<Credential> credentials = new HashSet<>();
-		for (X509Certificate key : getVerificationCertificates(token)) {
-			BasicX509Credential cred = new BasicX509Credential(key);
-			cred.setUsageType(UsageType.SIGNING);
-			cred.setEntityId(token.getIdpEntityId());
-			credentials.add(cred);
-		}
-		CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
-		return new ExplicitKeySignatureTrustEngine(
-				credentialsResolver,
-				DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()
-		);
+		Map<String, Saml2AuthenticationException> validationExceptions = new HashMap<>();
+		validationExceptions.putAll(this.responseValidator.apply(token).convert(response));
+		return validationExceptions;
 	}
 
-	private ValidationContext buildValidationContext(Saml2AuthenticationToken token, Response response) {
-		Map<String, Object> validationParams = new HashMap<>();
-		validationParams.put(SIGNATURE_REQUIRED, !response.isSigned());
-		validationParams.put(CLOCK_SKEW, this.responseTimeValidationSkew.toMillis());
-		validationParams.put(COND_VALID_AUDIENCES, singleton(token.getLocalSpEntityId()));
-		if (StringUtils.hasText(token.getRecipientUri())) {
-			validationParams.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, singleton(token.getRecipientUri()));
+	private List<Assertion> decryptAssertions
+			(Saml2AuthenticationToken token, Response response) {
+		List<Assertion> assertions = new ArrayList<>();
+		for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
+			Assertion assertion = this.assertionDecrypter.apply(token).convert(encryptedAssertion);
+			assertions.add(assertion);
 		}
-		return new ValidationContext(validationParams);
+		response.getAssertions().addAll(assertions);
+		return response.getAssertions();
 	}
 
-	private SAML20AssertionValidator buildSamlAssertionValidator(SignatureTrustEngine signatureTrustEngine) {
-		return new SAML20AssertionValidator(
-				this.conditions, this.subjects, this.statements, signatureTrustEngine, this.signaturePrevalidator);
-	}
-
-	private Assertion validateAssertion(Assertion assertion,
-			SAML20AssertionValidator validator, ValidationContext context) {
-
-		ValidationResult result;
-		try {
-			result = validator.validate(assertion, context);
-		} catch (Exception e) {
-			throw new Saml2Exception("An error occurred while validation the assertion", e);
-		}
-		if (result != ValidationResult.VALID) {
-			throw new Saml2Exception("An error occurred while validating the assertion: " +
-					context.getValidationFailureMessage());
+	private Map<String, Saml2AuthenticationException> validateAssertions
+			(Saml2AuthenticationToken token, List<Assertion> assertions) {
+		if (assertions.isEmpty()) {
+			throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
 		}
-		return assertion;
-	}
 
-	private Assertion decrypt(Saml2AuthenticationToken token, EncryptedAssertion assertion)
-			throws Saml2AuthenticationException {
-
-		Saml2AuthenticationException last = null;
-		List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token);
-		if (decryptionCredentials.isEmpty()) {
-			throw authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
+		Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
+		if (logger.isDebugEnabled()) {
+			logger.debug("Validating " + assertions.size() + " assertions");
 		}
-		for (Saml2X509Credential key : decryptionCredentials) {
-			Decrypter decrypter = getDecrypter(key);
-			try {
-				return decrypter.decrypt(assertion);
-			}
-			catch (DecryptionException e) {
-				last = authException(DECRYPTION_ERROR, e.getMessage(), e);
+		for (Assertion assertion : assertions) {
+			if (logger.isTraceEnabled()) {
+				logger.trace("Validating assertion " + assertion.getID());
 			}
+			validationExceptions.putAll(this.assertionValidator.apply(token).convert(assertion));
 		}
-		throw last;
-	}
 
-	private Decrypter getDecrypter(Saml2X509Credential key) {
-		Credential credential = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey());
-		KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver(credential);
-		Decrypter decrypter = new Decrypter(null, resolver, this.saml.getEncryptedKeyResolver());
-		decrypter.setRootInNewDocument(true);
-		return decrypter;
+		return validationExceptions;
 	}
 
-	private List<Saml2X509Credential> getDecryptionCredentials(Saml2AuthenticationToken token) {
-		List<Saml2X509Credential> result = new LinkedList<>();
-		for (Saml2X509Credential c : token.getX509Credentials()) {
-			if (c.isDecryptionCredential()) {
-				result.add(c);
-			}
+	private boolean isSigned(boolean responseSigned, List<Assertion> assertions) {
+		if (responseSigned) {
+			return true;
 		}
-		return result;
-	}
 
-	private List<X509Certificate> getVerificationCertificates(Saml2AuthenticationToken token) {
-		List<X509Certificate> result = new LinkedList<>();
-		for (Saml2X509Credential c : token.getX509Credentials()) {
-			if (c.isSignatureVerficationCredential()) {
-				result.add(c.getCertificate());
+		for (Assertion assertion : assertions) {
+			if (!assertion.isSigned()) {
+				return false;
 			}
 		}
-		return result;
-	}
 
-	private String getUsername(Saml2AuthenticationToken token, Assertion assertion)
-			throws Saml2AuthenticationException {
-
-		String username = null;
-		Subject subject = assertion.getSubject();
-		if (subject == null) {
-			throw authException(SUBJECT_NOT_FOUND, "Assertion [" + assertion.getID() + "] is missing a subject");
-		}
-		if (subject.getNameID() != null) {
-			username = subject.getNameID().getValue();
-		}
-		else if (subject.getEncryptedID() != null) {
-			NameID nameId = decrypt(token, subject.getEncryptedID());
-			username = nameId.getValue();
-		}
-		if (username == null) {
-			throw authException(USERNAME_NOT_FOUND, "Assertion [" + assertion.getID() + "] is missing a user identifier");
-		}
-		return username;
+		return true;
 	}
 
-	private NameID decrypt(Saml2AuthenticationToken token, EncryptedID assertion)
-			throws Saml2AuthenticationException {
-
-		Saml2AuthenticationException last = null;
-		List<Saml2X509Credential> decryptionCredentials = getDecryptionCredentials(token);
-		if (decryptionCredentials.isEmpty()) {
-			throw authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
+	private NameID decryptPrincipal(Saml2AuthenticationToken token, Assertion assertion) {
+		if (assertion.getSubject() == null) {
+			return null;
 		}
-		for (Saml2X509Credential key : decryptionCredentials) {
-			Decrypter decrypter = getDecrypter(key);
-			try {
-				return (NameID) decrypter.decrypt(assertion);
-			}
-			catch (DecryptionException e) {
-				last = authException(DECRYPTION_ERROR, e.getMessage(), e);
-			}
+		if (assertion.getSubject().getEncryptedID() == null) {
+			return assertion.getSubject().getNameID();
 		}
-		throw last;
+		NameID nameId = this.principalDecrypter.apply(token).convert(assertion.getSubject().getEncryptedID());
+		assertion.getSubject().setNameID(nameId);
+		return nameId;
 	}
 
 	private Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
@@ -562,17 +429,344 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return xsAny.getTextContent();
 	}
 
-	private Saml2Error validationError(String code, String description) {
+	private static <T extends XMLObject> Function<Saml2AuthenticationToken, Converter<T, Map<String, Saml2AuthenticationException>>>
+			validator(Collection<Function<Saml2AuthenticationToken, Converter<T, Map<String, Saml2AuthenticationException>>>> validators) {
+		return token -> response -> {
+			Map<String, Saml2AuthenticationException> errors = new LinkedHashMap<>();
+			for (Function<Saml2AuthenticationToken, Converter<T, Map<String, Saml2AuthenticationException>>> validator : validators) {
+				errors.putAll(validator.apply(token).convert(response));
+			}
+			return errors;
+		};
+	}
+
+	private static class ResponseSignatureValidator implements
+			Function<Saml2AuthenticationToken, Converter<Response, Map<String, Saml2AuthenticationException>>> {
+
+		private final SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
+
+		@Override
+		public Converter<Response, Map<String, Saml2AuthenticationException>> apply(Saml2AuthenticationToken token) {
+			return response -> {
+				Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
+				String issuer = response.getIssuer().getValue();
+				if (response.isSigned()) {
+					try {
+						this.profileValidator.validate(response.getSignature());
+					} catch (Exception e) {
+						validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
+								"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
+					}
+
+					try {
+						CriteriaSet criteriaSet = new CriteriaSet();
+						criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
+						criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
+						criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
+						if (!buildSignatureTrustEngine(token).validate(response.getSignature(), criteriaSet)) {
+							validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
+									"Invalid signature for SAML Response [" + response.getID() + "]"));
+						}
+					} catch (Exception e) {
+						validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
+								"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
+					}
+				}
+
+				return validationExceptions;
+			};
+		}
+
+		private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
+			Set<Credential> credentials = new HashSet<>();
+			for (Saml2X509Credential key : token.getX509Credentials()) {
+				if (!key.isSignatureVerficationCredential()) {
+					continue;
+				}
+				BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
+				cred.setUsageType(UsageType.SIGNING);
+				cred.setEntityId(token.getIdpEntityId());
+				credentials.add(cred);
+			}
+			CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
+			return new ExplicitKeySignatureTrustEngine(
+					credentialsResolver,
+					DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()
+			);
+		}
+	}
+
+	private static class ResponseValidator
+			implements Function<Saml2AuthenticationToken, Converter<Response, Map<String, Saml2AuthenticationException>>> {
+
+		@Override
+		public Converter<Response, Map<String, Saml2AuthenticationException>> apply(Saml2AuthenticationToken token) {
+			return response -> {
+				Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
+
+				String destination = response.getDestination();
+				if (StringUtils.hasText(destination) && !destination.equals(token.getRecipientUri())) {
+					String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
+					validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
+				}
+
+				String issuer = response.getIssuer().getValue();
+				String assertingPartyEntityId = token.getIdpEntityId();
+				if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
+					String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
+					validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
+				}
+
+				return validationExceptions;
+			};
+		}
+	}
+
+	private static class AssertionDecrypter
+			implements Function<Saml2AuthenticationToken, Converter<EncryptedAssertion, Assertion>> {
+		private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
+				asList(
+						new InlineEncryptedKeyResolver(),
+						new EncryptedElementTypeEncryptedKeyResolver(),
+						new SimpleRetrievalMethodEncryptedKeyResolver()
+				)
+		);
+
+		@Override
+		public Converter<EncryptedAssertion, Assertion> apply(Saml2AuthenticationToken token) {
+			return encrypted -> {
+				Saml2AuthenticationException last =
+						authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
+				List<Saml2X509Credential> decryptionCredentials = token.getX509Credentials();
+				for (Saml2X509Credential key : decryptionCredentials) {
+					if (!key.isDecryptionCredential()) {
+						continue;
+					}
+					Decrypter decrypter = getDecrypter(key);
+					try {
+						return decrypter.decrypt(encrypted);
+					} catch (DecryptionException e) {
+						last = authException(DECRYPTION_ERROR, e.getMessage(), e);
+					}
+				}
+				throw last;
+			};
+		}
+
+		private Decrypter getDecrypter(Saml2X509Credential key) {
+			Credential credential = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey());
+			KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver(credential);
+			Decrypter decrypter = new Decrypter(null, resolver, this.encryptedKeyResolver);
+			decrypter.setRootInNewDocument(true);
+			return decrypter;
+		}
+	}
+
+	private static class AssertionSignatureValidator
+			implements Function<Saml2AuthenticationToken, Converter<Assertion, Map<String, Saml2AuthenticationException>>> {
+
+		private final SignaturePrevalidator signaturePrevalidator = new SAMLSignatureProfileValidator();
+
+		@Override
+		public Converter<Assertion, Map<String, Saml2AuthenticationException>> apply(Saml2AuthenticationToken token) {
+			return assertion -> {
+				Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
+				try {
+					ValidationContext context = buildValidationContext();
+					ValidationResult result = buildSamlAssertionValidator(token).validate(assertion, context);
+					if (result != ValidationResult.VALID) {
+						String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
+								assertion.getID(), ((Response) assertion.getParent()).getID(),
+								context.getValidationFailureMessage());
+						validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message));
+					}
+				} catch (Exception e) {
+					String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
+							assertion.getID(), ((Response) assertion.getParent()).getID(),
+							e.getMessage());
+					validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e));
+				}
+				return validationExceptions;
+			};
+		}
+
+		private ValidationContext buildValidationContext() {
+			Map<String, Object> validationParams = new HashMap<>();
+			validationParams.put(SIGNATURE_REQUIRED, Boolean.FALSE); // this requirement is checked earlier
+			return new ValidationContext(validationParams);
+		}
+
+		private SAML20AssertionValidator buildSamlAssertionValidator(Saml2AuthenticationToken token) {
+			SignatureTrustEngine signatureTrustEngine = buildSignatureTrustEngine(token);
+			return new SAML20AssertionValidator(
+					Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), signatureTrustEngine, signaturePrevalidator) {
+				@Nonnull
+				@Override
+				protected ValidationResult validateConditions(@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
+					return ValidationResult.VALID;
+				}
+
+				@Nonnull
+				@Override
+				protected ValidationResult validateSubjectConfirmation(@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
+					return ValidationResult.VALID;
+				}
+
+				@Nonnull
+				@Override
+				protected ValidationResult validateStatements(@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
+					return ValidationResult.VALID;
+				}
+			};
+		}
+
+		private SignatureTrustEngine buildSignatureTrustEngine(Saml2AuthenticationToken token) {
+			Set<Credential> credentials = new HashSet<>();
+			for (Saml2X509Credential key : token.getX509Credentials()) {
+				if (!key.isSignatureVerficationCredential()) continue;
+				BasicX509Credential cred = new BasicX509Credential(key.getCertificate());
+				cred.setUsageType(UsageType.SIGNING);
+				cred.setEntityId(token.getIdpEntityId());
+				credentials.add(cred);
+			}
+			CredentialResolver credentialsResolver = new CollectionCredentialResolver(credentials);
+			return new ExplicitKeySignatureTrustEngine(
+					credentialsResolver,
+					DefaultSecurityConfigurationBootstrap.buildBasicInlineKeyInfoCredentialResolver()
+			);
+		}
+	}
+
+	private static class AssertionValidator
+			implements Function<Saml2AuthenticationToken, Converter<Assertion, Map<String, Saml2AuthenticationException>>> {
+
+		private final Function<Saml2AuthenticationToken, ValidationContext> validationContextResolver;
+		private final Function<Saml2AuthenticationToken, SAML20AssertionValidator> assertionValidatorResolver;
+
+		AssertionValidator(Function<Saml2AuthenticationToken, SAML20AssertionValidator> assertionValidatorResolver,
+			Function<Saml2AuthenticationToken, ValidationContext> validationContextResolver) {
+
+			this.validationContextResolver = validationContextResolver;
+			this.assertionValidatorResolver = assertionValidatorResolver;
+		}
+
+		@Override
+		public Converter<Assertion, Map<String, Saml2AuthenticationException>> apply(Saml2AuthenticationToken token) {
+			return assertion -> {
+				Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
+				try {
+					ValidationContext context = this.validationContextResolver.apply(token);
+					ValidationResult result = this.assertionValidatorResolver.apply(token).validate(assertion, context);
+					if (result != ValidationResult.VALID) {
+						String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
+								assertion.getID(), ((Response) assertion.getParent()).getID(),
+								context.getValidationFailureMessage());
+						validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message));
+					}
+				} catch (Exception e) {
+					String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
+							assertion.getID(), ((Response) assertion.getParent()).getID(),
+							e.getMessage());
+					validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e));
+				}
+				return validationExceptions;
+			};
+		}
+
+		private static class Builder {
+			private final Collection<ConditionValidator> conditions = new ArrayList<>();
+			private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
+			private final Collection<StatementValidator> statements = new ArrayList<>();
+			private final Map<String, Object> validationContextParameters = new HashMap<>();
+
+			Builder() {
+				this.conditions.add(new AudienceRestrictionConditionValidator());
+				this.subjects.add(new BearerSubjectConfirmationValidator() {
+					@Nonnull
+					@Override
+					protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation,
+							@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
+						// skipping address validation - gh-7514
+						return ValidationResult.VALID;
+					}
+				});
+			}
+
+			public AssertionValidator.Builder validationContext(
+					Consumer<Map<String, Object>> validationContextParameters) {
+				validationContextParameters.accept(this.validationContextParameters);
+				return this;
+			}
+
+			public AssertionValidator build() {
+				return new AssertionValidator(
+						token -> new SAML20AssertionValidator(this.conditions, this.subjects, this.statements, null, null) {
+							@Nonnull
+							@Override
+							protected ValidationResult validateSignature(@Nonnull Assertion token, @Nonnull ValidationContext context) {
+								return ValidationResult.VALID;
+							}
+						},
+						token -> {
+							Map<String, Object> params = new HashMap<>();
+							params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis());
+							params.put(COND_VALID_AUDIENCES, singleton(token.getIdpEntityId()));
+							params.put(SC_VALID_RECIPIENTS, singleton(token.getRecipientUri()));
+							params.putAll(this.validationContextParameters);
+							return new ValidationContext(params);
+						});
+			}
+		}
+	}
+
+	private static class PrincipalDecrypter implements Function<Saml2AuthenticationToken, Converter<EncryptedID, NameID>> {
+		private final EncryptedKeyResolver encryptedKeyResolver = new ChainingEncryptedKeyResolver(
+				asList(
+						new InlineEncryptedKeyResolver(),
+						new EncryptedElementTypeEncryptedKeyResolver(),
+						new SimpleRetrievalMethodEncryptedKeyResolver()
+				)
+		);
+
+		@Override
+		public Converter<EncryptedID, NameID> apply(Saml2AuthenticationToken token) {
+			return encrypted -> {
+				Saml2AuthenticationException last =
+						authException(DECRYPTION_ERROR, "No valid decryption credentials found.");
+				List<Saml2X509Credential> decryptionCredentials = token.getX509Credentials();
+				for (Saml2X509Credential key : decryptionCredentials) {
+					if (!key.isDecryptionCredential()) continue;
+					Decrypter decrypter = getDecrypter(key);
+					try {
+						return (NameID) decrypter.decrypt(encrypted);
+					} catch (DecryptionException e) {
+						last = authException(DECRYPTION_ERROR, e.getMessage(), e);
+					}
+				}
+				throw last;
+			};
+		}
+
+		private Decrypter getDecrypter(Saml2X509Credential key) {
+			Credential credential = CredentialSupport.getSimpleCredential(key.getCertificate(), key.getPrivateKey());
+			KeyInfoCredentialResolver resolver = new StaticKeyInfoCredentialResolver(credential);
+			Decrypter decrypter = new Decrypter(null, resolver, this.encryptedKeyResolver);
+			decrypter.setRootInNewDocument(true);
+			return decrypter;
+		}
+	}
+
+	private static Saml2Error validationError(String code, String description) {
 		return new Saml2Error(code, description);
 	}
 
-	private Saml2AuthenticationException authException(String code, String description)
+	private static Saml2AuthenticationException authException(String code, String description)
 			throws Saml2AuthenticationException {
 
 		return new Saml2AuthenticationException(validationError(code, description));
 	}
 
-	private Saml2AuthenticationException authException(String code, String description, Exception cause)
+	private static Saml2AuthenticationException authException(String code, String description, Exception cause)
 			throws Saml2AuthenticationException {
 
 		return new Saml2AuthenticationException(validationError(code, description), cause);

+ 1 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -186,7 +186,7 @@ public class OpenSamlAuthenticationProviderTests {
 
 	@Test
 	public void authenticateWhenUsernameMissingThenThrowAuthenticationException() throws Exception {
-		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.USERNAME_NOT_FOUND));
+		this.exception.expect(authenticationMatcher(Saml2ErrorCodes.SUBJECT_NOT_FOUND));
 
 		Response response = response();
 		Assertion assertion = assertion();

+ 17 - 16
samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java

@@ -16,6 +16,18 @@
 
 package org.springframework.security.saml2.provider.service.authentication;
 
+import java.io.ByteArrayInputStream;
+import java.net.URLDecoder;
+import java.nio.charset.StandardCharsets;
+import java.security.KeyException;
+import java.security.PrivateKey;
+import java.security.PublicKey;
+import java.security.cert.CertificateException;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.util.UUID;
+import javax.servlet.http.HttpSession;
+
 import net.shibboleth.utilities.java.support.component.ComponentInitializationException;
 import net.shibboleth.utilities.java.support.xml.BasicParserPool;
 import net.shibboleth.utilities.java.support.xml.SerializeSupport;
@@ -47,6 +59,9 @@ import org.opensaml.xmlsec.SignatureSigningParameters;
 import org.opensaml.xmlsec.signature.support.SignatureConstants;
 import org.opensaml.xmlsec.signature.support.SignatureException;
 import org.opensaml.xmlsec.signature.support.SignatureSupport;
+import org.w3c.dom.Document;
+import org.w3c.dom.Element;
+
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.boot.SpringBootConfiguration;
 import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
@@ -60,20 +75,6 @@ import org.springframework.test.web.servlet.ResultActions;
 import org.springframework.test.web.servlet.ResultMatcher;
 import org.springframework.util.MultiValueMap;
 import org.springframework.web.util.UriComponentsBuilder;
-import org.w3c.dom.Document;
-import org.w3c.dom.Element;
-
-import javax.servlet.http.HttpSession;
-import java.io.ByteArrayInputStream;
-import java.net.URLDecoder;
-import java.nio.charset.StandardCharsets;
-import java.security.KeyException;
-import java.security.PrivateKey;
-import java.security.PublicKey;
-import java.security.cert.CertificateException;
-import java.security.cert.CertificateFactory;
-import java.security.cert.X509Certificate;
-import java.util.UUID;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.hamcrest.Matchers.containsString;
@@ -287,9 +288,9 @@ public class Saml2LoginIntegrationTests {
 				.andExpect(unauthenticated())
 				.andExpect(
 						saml2AuthenticationExceptionMatcher(
-								"invalid_issuer",
+								"invalid_signature",
 								containsString(
-										"Invalid issuer [invalid issuer] for SAML response"
+										"Invalid signature for SAML Response"
 								)
 						)
 				);