瀏覽代碼

Generalize SAML 2.0 Assertion Validation Support

Closes gh-8970
Josh Cummings 5 年之前
父節點
當前提交
7b3dda161b

+ 121 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResult.java

@@ -0,0 +1,121 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.core;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+
+import org.springframework.util.Assert;
+
+/**
+ * A result emitted from a SAML 2.0 Response validation attempt
+ *
+ * @author Josh Cummings
+ * @since 5.4
+ */
+public final class Saml2ResponseValidatorResult {
+	static final Saml2ResponseValidatorResult NO_ERRORS = new Saml2ResponseValidatorResult(Collections.emptyList());
+
+	private final Collection<Saml2Error> errors;
+
+	private Saml2ResponseValidatorResult(Collection<Saml2Error> errors) {
+		Assert.notNull(errors, "errors cannot be null");
+		this.errors = new ArrayList<>(errors);
+	}
+
+	/**
+	 * Say whether this result indicates success
+	 *
+	 * @return whether this result has errors
+	 */
+	public boolean hasErrors() {
+		return !this.errors.isEmpty();
+	}
+
+	/**
+	 * Return error details regarding the validation attempt
+	 *
+	 * @return the collection of results in this result, if any; returns an empty list otherwise
+	 */
+	public Collection<Saml2Error> getErrors() {
+		return Collections.unmodifiableCollection(this.errors);
+	}
+
+	/**
+	 * Return a new {@link Saml2ResponseValidatorResult} that contains
+	 * both the given {@link Saml2Error} and the errors from the result
+	 *
+	 * @param error the {@link Saml2Error} to append
+	 * @return a new {@link Saml2ResponseValidatorResult} for further reporting
+	 */
+	public Saml2ResponseValidatorResult concat(Saml2Error error) {
+		Assert.notNull(error, "error cannot be null");
+		Collection<Saml2Error> errors = new ArrayList<>(this.errors);
+		errors.add(error);
+		return failure(errors);
+	}
+
+	/**
+	 * Return a new {@link Saml2ResponseValidatorResult} that contains
+	 * the errors from the given {@link Saml2ResponseValidatorResult} as well
+	 * as this result.
+	 *
+	 * @param result the {@link Saml2ResponseValidatorResult} to merge with this one
+	 * @return a new {@link Saml2ResponseValidatorResult} for further reporting
+	 */
+	public Saml2ResponseValidatorResult concat(Saml2ResponseValidatorResult result) {
+		Assert.notNull(result, "result cannot be null");
+		Collection<Saml2Error> errors = new ArrayList<>(this.errors);
+		errors.addAll(result.getErrors());
+		return failure(errors);
+	}
+
+	/**
+	 * Construct a successful {@link Saml2ResponseValidatorResult}
+	 *
+	 * @return an {@link Saml2ResponseValidatorResult} with no errors
+	 */
+	public static Saml2ResponseValidatorResult success() {
+		return NO_ERRORS;
+	}
+
+	/**
+	 * Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail
+	 *
+	 * @param errors the list of errors
+	 * @return an {@link Saml2ResponseValidatorResult} with the errors specified
+	 */
+	public static Saml2ResponseValidatorResult failure(Saml2Error... errors) {
+		return failure(Arrays.asList(errors));
+	}
+
+	/**
+	 * Construct a failure {@link Saml2ResponseValidatorResult} with the provided detail
+	 *
+	 * @param errors the list of errors
+	 * @return an {@link Saml2ResponseValidatorResult} with the errors specified
+	 */
+	public static Saml2ResponseValidatorResult failure(Collection<Saml2Error> errors) {
+		if (errors.isEmpty()) {
+			return NO_ERRORS;
+		}
+
+		return new Saml2ResponseValidatorResult(errors);
+	}
+}

+ 188 - 108
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -30,6 +30,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
 import javax.annotation.Nonnull;
+import javax.xml.namespace.QName;
 
 import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
 import net.shibboleth.utilities.java.support.xml.ParserPool;
@@ -61,11 +62,14 @@ import org.opensaml.saml.saml2.assertion.StatementValidator;
 import org.opensaml.saml.saml2.assertion.SubjectConfirmationValidator;
 import org.opensaml.saml.saml2.assertion.impl.AudienceRestrictionConditionValidator;
 import org.opensaml.saml.saml2.assertion.impl.BearerSubjectConfirmationValidator;
+import org.opensaml.saml.saml2.assertion.impl.DelegationRestrictionConditionValidator;
 import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.Attribute;
 import org.opensaml.saml.saml2.core.AttributeStatement;
+import org.opensaml.saml.saml2.core.Condition;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.NameID;
+import org.opensaml.saml.saml2.core.OneTimeUse;
 import org.opensaml.saml.saml2.core.Response;
 import org.opensaml.saml.saml2.core.SubjectConfirmation;
 import org.opensaml.saml.saml2.core.impl.ResponseUnmarshaller;
@@ -106,6 +110,7 @@ import org.springframework.security.core.authority.mapping.GrantedAuthoritiesMap
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.Saml2Error;
+import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
 import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
@@ -126,6 +131,8 @@ import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_IS
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.MALFORMED_RESPONSE_DATA;
 import static org.springframework.security.saml2.core.Saml2ErrorCodes.SUBJECT_NOT_FOUND;
+import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.failure;
+import static org.springframework.security.saml2.core.Saml2ResponseValidatorResult.success;
 import static org.springframework.util.Assert.notNull;
 
 /**
@@ -191,16 +198,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 						this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
 			};
 
+	private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = assertionToken -> {
+		ValidationContext context = createValidationContext(assertionToken);
+		return createDefaultAssertionValidator(context).convert(assertionToken);
+	};
+
 	private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter =
 			new SignatureTrustEngineConverter();
-	private Converter<Tuple, SAML20AssertionValidator> assertionValidatorConverter =
-			new SAML20AssertionValidatorConverter();
-	private Collection<ConditionValidator> conditionValidators =
-			Collections.singleton(new AudienceRestrictionConditionValidator());
-	private Converter<Tuple, ValidationContext> validationContextConverter =
-			new ValidationContextConverter();
 	private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
 
+
 	/**
 	 * Creates an {@link OpenSamlAuthenticationProvider}
 	 */
@@ -212,30 +219,43 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	}
 
 	/**
-	 * Set the the collection of {@link ConditionValidator}s used when validating an assertion.
+	 * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML 2.0 Response.
 	 *
-	 * @param conditionValidators the collection of validators to use
-	 * @since 5.4
-	 */
-	public void setConditionValidators(
-			Collection<ConditionValidator> conditionValidators) {
-
-		Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty");
-		this.conditionValidators = conditionValidators;
-	}
-
-	/**
-	 * Set the strategy for retrieving the {@link ValidationContext} used when
-	 * validating an assertion.
+	 * You can still invoke the default validator by delgating to
+	 * {@link #createDefaultAssertionValidator(ValidationContext)}, like so:
+	 *
+	 * <pre>
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *  provider.setAssertionValidator(assertionToken -> {
+	 *		ValidationContext context = // ... build using authentication token
+	 *		Saml2ResponseValidatorResult result = createDefaultAssertionValidator(context)
+	 *			.convert(assertionToken)
+	 *		return result.concat(myCustomValiator.convert(assertionToken));
+	 *  });
+	 * </pre>
+	 *
+	 * Consider taking a look at {@link #createValidationContext(AssertionToken)} to see how it
+	 * constructs a {@link ValidationContext}.
 	 *
-	 * @param validationContextConverter the strategy to use
+	 * You can also use this method to configure the provider to use a different
+	 * {@link ValidationContext} from the default, like so:
+	 *
+	 * <pre>
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	ValidationContext context = // ...
+	 *	provider.setAssertionValidator(createDefaultAssertionValidator(context));
+	 * </pre>
+	 *
+	 * It is not necessary to delegate to the default validator. You can safely replace it
+	 * entirely with your own. Note that signature verification is performed as a separate
+	 * step from this validator.
+	 *
+	 * @param assertionValidator
 	 * @since 5.4
 	 */
-	public void setValidationContextConverter(
-			Converter<Tuple, ValidationContext> validationContextConverter) {
-
-		Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty");
-		this.validationContextConverter = validationContextConverter;
+	public void setAssertionValidator(Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator) {
+		Assert.notNull(assertionValidator, "assertionValidator cannot be null");
+		this.assertionValidator = assertionValidator;
 	}
 
 	/**
@@ -322,7 +342,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		}
 
 		boolean responseSigned = response.isSigned();
-		Map<String, Saml2AuthenticationException> validationExceptions = validateResponse(token, response);
+		Saml2ResponseValidatorResult result = validateResponse(token, response);
 
 		Decrypter decrypter = this.decrypterConverter.convert(token);
 		List<Assertion> assertions = decryptAssertions(decrypter, response);
@@ -330,37 +350,37 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 			throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
 					"Please either sign the response or all of the assertions.");
 		}
-		validationExceptions.putAll(validateAssertions(token, response));
+		result = result.concat(validateAssertions(token, response));
 
 		Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
 		NameID nameId = decryptPrincipal(decrypter, firstAssertion);
 		if (nameId == null || nameId.getValue() == null) {
-			validationExceptions.put(SUBJECT_NOT_FOUND, authException(SUBJECT_NOT_FOUND,
-					"Assertion [" + firstAssertion.getID() + "] is missing a subject"));
+			Saml2Error error = new Saml2Error(SUBJECT_NOT_FOUND,
+					"Assertion [" + firstAssertion.getID() + "] is missing a subject");
+			result = result.concat(error);
 		}
 
-		if (validationExceptions.isEmpty()) {
-			if (logger.isDebugEnabled()) {
-				logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
-			}
-		} else {
+		if (result.hasErrors()) {
+			Collection<Saml2Error> errors = result.getErrors();
 			if (logger.isTraceEnabled()) {
-				logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]: " +
-						validationExceptions.values());
+				logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]: " +
+						errors);
 			} else if (logger.isDebugEnabled()) {
-				logger.debug("Found " + validationExceptions.size() + " validation errors in SAML response [" + response.getID() + "]");
+				logger.debug("Found " + errors.size() + " validation errors in SAML response [" + response.getID() + "]");
+			}
+			Saml2Error first = errors.iterator().next();
+			throw authException(first.getErrorCode(), first.getDescription());
+		} else {
+			if (logger.isDebugEnabled()) {
+				logger.debug("Successfully processed SAML Response [" + response.getID() + "]");
 			}
-		}
-
-		if (!validationExceptions.isEmpty()) {
-			throw validationExceptions.values().iterator().next();
 		}
 	}
 
-	private Map<String, Saml2AuthenticationException> validateResponse
+	private Saml2ResponseValidatorResult validateResponse
 			(Saml2AuthenticationToken token, Response response) {
 
-		Map<String, Saml2AuthenticationException> validationExceptions = new HashMap<>();
+		Collection<Saml2Error> errors = new ArrayList<>();
 		String issuer = response.getIssuer().getValue();
 
 		if (response.isSigned()) {
@@ -368,8 +388,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 			try {
 				profileValidator.validate(response.getSignature());
 			} catch (Exception e) {
-				validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
-						"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
+				errors.add(new Saml2Error(INVALID_SIGNATURE,
+						"Invalid signature for SAML Response [" + response.getID() + "]: "));
 			}
 
 			try {
@@ -378,12 +398,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 				criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
 				criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
 				if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) {
-					validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
+					errors.add(new Saml2Error(INVALID_SIGNATURE,
 							"Invalid signature for SAML Response [" + response.getID() + "]"));
 				}
 			} catch (Exception e) {
-				validationExceptions.put(INVALID_SIGNATURE, authException(INVALID_SIGNATURE,
-						"Invalid signature for SAML Response [" + response.getID() + "]: ", e));
+				errors.add(new Saml2Error(INVALID_SIGNATURE,
+						"Invalid signature for SAML Response [" + response.getID() + "]: "));
 			}
 		}
 
@@ -391,16 +411,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
 		if (StringUtils.hasText(destination) && !destination.equals(location)) {
 			String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
-			validationExceptions.put(INVALID_DESTINATION, authException(INVALID_DESTINATION, message));
+			errors.add(new Saml2Error(INVALID_DESTINATION, message));
 		}
 
 		String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
 		if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
 			String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
-			validationExceptions.put(INVALID_ISSUER, authException(INVALID_ISSUER, message));
+			errors.add(new Saml2Error(INVALID_ISSUER, message));
 		}
 
-		return validationExceptions;
+		return failure(errors);
 	}
 
 	private List<Assertion> decryptAssertions
@@ -418,41 +438,35 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return response.getAssertions();
 	}
 
-	private Map<String, Saml2AuthenticationException> validateAssertions
+	private Saml2ResponseValidatorResult validateAssertions
 			(Saml2AuthenticationToken token, Response response) {
 		List<Assertion> assertions = response.getAssertions();
 		if (assertions.isEmpty()) {
 			throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
 		}
 
-		Map<String, Saml2AuthenticationException> validationExceptions = new LinkedHashMap<>();
+		Saml2ResponseValidatorResult result = success();
 		if (logger.isDebugEnabled()) {
 			logger.debug("Validating " + assertions.size() + " assertions");
 		}
 
-		Tuple tuple = new Tuple(token, response);
-		SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple);
-		ValidationContext context = this.validationContextConverter.convert(tuple);
+		ValidationContext signatureContext = new ValidationContext
+				(Collections.singletonMap(SIGNATURE_REQUIRED, false)); // check already performed
+		SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(token);
+		Converter<AssertionToken, Saml2ResponseValidatorResult> signatureValidator =
+				createDefaultAssertionValidator(INVALID_SIGNATURE,
+						SAML20AssertionValidators.createSignatureValidator(engine), signatureContext);
 		for (Assertion assertion : assertions) {
 			if (logger.isTraceEnabled()) {
 				logger.trace("Validating assertion " + assertion.getID());
 			}
-			try {
-				if (validator.validate(assertion, context) != ValidationResult.VALID) {
-					String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
-							assertion.getID(), ((Response) assertion.getParent()).getID(),
-							context.getValidationFailureMessage());
-					validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message));
-				}
-			} catch (Exception e) {
-				String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
-						assertion.getID(), ((Response) assertion.getParent()).getID(),
-						e.getMessage());
-				validationExceptions.put(INVALID_ASSERTION, authException(INVALID_ASSERTION, message, e));
-			}
+			AssertionToken assertionToken = new AssertionToken(assertion, token);
+			result = result
+					.concat(signatureValidator.convert(assertionToken))
+					.concat(this.assertionValidator.convert(assertionToken));
 		}
 
-		return validationExceptions;
+		return result;
 	}
 
 	private boolean isSigned(boolean responseSigned, List<Assertion> assertions) {
@@ -561,45 +575,111 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		}
 	}
 
-	private class ValidationContextConverter implements Converter<Tuple, ValidationContext> {
+	public static Converter<AssertionToken, Saml2ResponseValidatorResult>
+			createDefaultAssertionValidator(ValidationContext context) {
 
-		@Override
-		public ValidationContext convert(Tuple tuple) {
-			String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
-			String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
-			Map<String, Object> params = new HashMap<>();
-			params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
-			params.put(COND_VALID_AUDIENCES, singleton(audience));
-			params.put(SC_VALID_RECIPIENTS, singleton(recipient));
-			params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier
-			return new ValidationContext(params);
-		}
+		return createDefaultAssertionValidator(INVALID_ASSERTION,
+				SAML20AssertionValidators.createAttributeValidator(), context);
 	}
 
-	private class SAML20AssertionValidatorConverter implements Converter<Tuple, SAML20AssertionValidator> {
-		private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
-		private final Collection<StatementValidator> statements = new ArrayList<>();
-		private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
+	private static Converter<AssertionToken, Saml2ResponseValidatorResult>
+			createDefaultAssertionValidator(String errorCode, SAML20AssertionValidator validator, ValidationContext context) {
+
+		return assertionToken -> {
+			Assertion assertion = assertionToken.assertion;
+			try {
+				ValidationResult result = validator.validate(assertion, context);
+				if (result == ValidationResult.VALID) {
+					return success();
+				}
+			} catch (Exception e) {
+				String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
+						assertion.getID(), ((Response) assertion.getParent()).getID(),
+						e.getMessage());
+				return failure(new Saml2Error(errorCode, message));
+			}
+			String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
+					assertion.getID(), ((Response) assertion.getParent()).getID(),
+					context.getValidationFailureMessage());
+			return failure(new Saml2Error(errorCode, message));
+		};
+	}
+
+	private ValidationContext createValidationContext(AssertionToken assertionToken) {
+		String audience = assertionToken.token.getRelyingPartyRegistration().getEntityId();
+		String recipient = assertionToken.token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
+		Map<String, Object> params = new HashMap<>();
+		params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
+		params.put(COND_VALID_AUDIENCES, singleton(audience));
+		params.put(SC_VALID_RECIPIENTS, singleton(recipient));
+		return new ValidationContext(params);
+	}
+
+	private static class SAML20AssertionValidators {
+		private static final Collection<ConditionValidator> conditions = new ArrayList<>();
+		private static final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
+		private static final Collection<StatementValidator> statements = new ArrayList<>();
+		private static final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
+
+		static {
+			conditions.add(new AudienceRestrictionConditionValidator());
+			conditions.add(new DelegationRestrictionConditionValidator());
+			conditions.add(new ConditionValidator() {
+				@Nonnull
+				@Override
+				public QName getServicedCondition() {
+					return OneTimeUse.DEFAULT_ELEMENT_NAME;
+				}
 
-		SAML20AssertionValidatorConverter() {
-			this.subjects.add(new BearerSubjectConfirmationValidator() {
+				@Nonnull
+				@Override
+				public ValidationResult validate(Condition condition, Assertion assertion, ValidationContext context) {
+					// applications should validate their own OneTimeUse conditions
+					return ValidationResult.VALID;
+				}
+			});
+			subjects.add(new BearerSubjectConfirmationValidator() {
 				@Nonnull
 				@Override
 				protected ValidationResult validateAddress(@Nonnull SubjectConfirmation confirmation,
 						@Nonnull Assertion assertion, @Nonnull ValidationContext context) {
-					// skipping address validation - gh-7514
+					// applications should validate their own addresses - gh-7514
 					return ValidationResult.VALID;
 				}
 			});
 		}
 
-		@Override
-		public SAML20AssertionValidator convert(Tuple tuple) {
-			Collection<ConditionValidator> conditions =
-					OpenSamlAuthenticationProvider.this.conditionValidators;
-			return new SAML20AssertionValidator(conditions, this.subjects, this.statements,
-					OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication),
-					this.validator);
+		static SAML20AssertionValidator createAttributeValidator() {
+			return new SAML20AssertionValidator(conditions, subjects, statements, null, null) {
+				@Nonnull
+				@Override
+				protected ValidationResult validateSignature(Assertion token, ValidationContext context) {
+					return ValidationResult.VALID;
+				}
+			};
+		}
+
+		static SAML20AssertionValidator createSignatureValidator(SignatureTrustEngine engine) {
+			return new SAML20AssertionValidator(new ArrayList<>(), new ArrayList<>(), new ArrayList<>(),
+					engine, validator) {
+				@Nonnull
+				@Override
+				protected ValidationResult validateConditions(Assertion assertion, ValidationContext context) {
+					return ValidationResult.VALID;
+				}
+
+				@Nonnull
+				@Override
+				protected ValidationResult validateSubjectConfirmation(Assertion assertion, ValidationContext context) {
+					return ValidationResult.VALID;
+				}
+
+				@Nonnull
+				@Override
+				protected ValidationResult validateStatements(Assertion assertion, ValidationContext context) {
+					return ValidationResult.VALID;
+				}
+			};
 		}
 	}
 
@@ -643,25 +723,25 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	}
 
 	/**
-	 * A tuple containing the authentication token and the associated OpenSAML {@link Response}.
+	 * A tuple containing an OpenSAML {@link Assertion} and its associated authentication token.
 	 *
 	 * @since 5.4
 	 */
-	public static class Tuple {
-		private final Saml2AuthenticationToken authentication;
-		private final Response response;
+	public static class AssertionToken {
+		private final Saml2AuthenticationToken token;
+		private final Assertion assertion;
 
-		private Tuple(Saml2AuthenticationToken authentication, Response response) {
-			this.authentication = authentication;
-			this.response = response;
+		private AssertionToken(Assertion assertion, Saml2AuthenticationToken token) {
+			this.token = token;
+			this.assertion = assertion;
 		}
 
-		public Saml2AuthenticationToken getAuthentication() {
-			return this.authentication;
+		public Assertion getAssertion() {
+			return this.assertion;
 		}
 
-		public Response getResponse() {
-			return this.response;
+		public Saml2AuthenticationToken getToken() {
+			return this.token;
 		}
 	}
 }

+ 89 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/core/Saml2ResponseValidatorResultTests.java

@@ -0,0 +1,89 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.saml2.core;
+
+import org.junit.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
+/**
+ * Tests for verifying {@link Saml2ResponseValidatorResult}
+ *
+ * @author Josh Cummings
+ */
+public class Saml2ResponseValidatorResultTests {
+	private static final Saml2Error DETAIL = new Saml2Error(
+			"error", "description");
+
+	@Test
+	public void successWhenInvokedThenReturnsSuccessfulResult() {
+		Saml2ResponseValidatorResult success = Saml2ResponseValidatorResult.success();
+		assertThat(success.hasErrors()).isFalse();
+	}
+
+	@Test
+	public void failureWhenInvokedWithDetailReturnsFailureResultIncludingDetail() {
+		Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL);
+
+		assertThat(failure.hasErrors()).isTrue();
+		assertThat(failure.getErrors()).containsExactly(DETAIL);
+	}
+
+	@Test
+	public void failureWhenInvokedWithMultipleDetailsReturnsFailureResultIncludingAll() {
+		Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL, DETAIL);
+
+		assertThat(failure.hasErrors()).isTrue();
+		assertThat(failure.getErrors()).containsExactly(DETAIL, DETAIL);
+	}
+
+	@Test
+	public void concatErrorWhenInvokedThenReturnsCopyContainingAll() {
+		Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL);
+		Saml2ResponseValidatorResult added = failure.concat(DETAIL);
+
+		assertThat(added.hasErrors()).isTrue();
+		assertThat(added.getErrors()).containsExactly(DETAIL, DETAIL);
+		assertThat(failure).isNotSameAs(added);
+	}
+
+	@Test
+	public void concatResultWhenInvokedThenReturnsCopyContainingAll() {
+		Saml2ResponseValidatorResult failure = Saml2ResponseValidatorResult.failure(DETAIL);
+		Saml2ResponseValidatorResult merged = failure
+				.concat(failure)
+				.concat(failure);
+
+		assertThat(merged.hasErrors()).isTrue();
+		assertThat(merged.getErrors()).containsExactly(DETAIL, DETAIL, DETAIL);
+		assertThat(failure).isNotSameAs(merged);
+	}
+
+	@Test
+	public void concatErrorWhenNullThenIllegalArgument() {
+		assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL)
+				.concat((Saml2Error) null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void concatResultWhenNullThenIllegalArgument() {
+		assertThatThrownBy(() -> Saml2ResponseValidatorResult.failure(DETAIL)
+				.concat((Saml2ResponseValidatorResult) null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+}

+ 55 - 22
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -45,12 +45,9 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.core.xml.io.Marshaller;
 import org.opensaml.core.xml.io.MarshallingException;
 import org.opensaml.saml.common.assertion.ValidationContext;
-import org.opensaml.saml.common.assertion.ValidationResult;
-import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator;
 import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.AttributeStatement;
 import org.opensaml.saml.saml2.core.AttributeValue;
-import org.opensaml.saml.saml2.core.Condition;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedID;
 import org.opensaml.saml.saml2.core.NameID;
@@ -60,13 +57,17 @@ import org.w3c.dom.Document;
 import org.w3c.dom.Element;
 import org.xml.sax.InputSource;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2Error;
+import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 
 import static java.util.Collections.singleton;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
@@ -76,11 +77,14 @@ import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getB
 import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
 import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
 import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
+import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_ASSERTION;
+import static org.springframework.security.saml2.core.Saml2ErrorCodes.INVALID_SIGNATURE;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyDecryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
+import static org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.createDefaultAssertionValidator;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.assertion;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.attributeStatements;
 import static org.springframework.security.saml2.provider.service.authentication.TestOpenSamlObjects.encrypted;
@@ -365,10 +369,13 @@ public class OpenSamlAuthenticationProviderTests {
 	}
 
 	@Test
-	public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception {
-		OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class);
+	public void authenticateWhenDelegatingToDefaultAssertionValidatorThenUses() {
 		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-		provider.setConditionValidators(Collections.singleton(validator));
+		provider.setAssertionValidator(assertionToken -> {
+			ValidationContext context = new ValidationContext();
+			return createDefaultAssertionValidator(context).convert(assertionToken)
+					.concat(new Saml2Error("wrong error", "wrong error"));
+		});
 		Response response = response();
 		Assertion assertion = assertion();
 		OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
@@ -376,11 +383,46 @@ public class OpenSamlAuthenticationProviderTests {
 		response.getAssertions().add(assertion);
 		signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
-		when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
-		when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
-				.thenReturn(ValidationResult.VALID);
+		assertThatThrownBy(() -> provider.authenticate(token))
+				.isInstanceOf(Saml2AuthenticationException.class)
+				.hasFieldOrPropertyWithValue("error.errorCode", INVALID_ASSERTION);
+	}
+
+	@Test
+	public void authenticateWhenCustomAssertionValidatorThenUses() {
+		Converter<OpenSamlAuthenticationProvider.AssertionToken, Saml2ResponseValidatorResult> validator =
+				mock(Converter.class);
+		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+		provider.setAssertionValidator(assertionToken -> {
+			ValidationContext context = new ValidationContext(
+					Collections.singletonMap(SC_VALID_RECIPIENTS, singleton(DESTINATION)));
+			return createDefaultAssertionValidator(context).convert(assertionToken)
+					.concat(validator.convert(assertionToken));
+		});
+		Response response = response();
+		Assertion assertion = assertion();
+		response.getAssertions().add(assertion);
+		signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+		when(validator.convert(any(OpenSamlAuthenticationProvider.AssertionToken.class)))
+			.thenReturn(Saml2ResponseValidatorResult.success());
 		provider.authenticate(token);
-		verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class));
+		verify(validator).convert(any(OpenSamlAuthenticationProvider.AssertionToken.class));
+	}
+
+	@Test
+	public void authenticateWhenDefaultConditionValidatorNotUsedThenSignatureStillChecked() {
+		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+		provider.setAssertionValidator(assertionToken -> Saml2ResponseValidatorResult.success());
+		Response response = response();
+		Assertion assertion = assertion();
+		signed(assertion, relyingPartyDecryptingCredential(), RELYING_PARTY_ENTITY_ID); // broken signature
+		response.getAssertions().add(assertion);
+		signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+		assertThatThrownBy(() -> provider.authenticate(token))
+				.isInstanceOf(Saml2AuthenticationException.class)
+				.hasFieldOrPropertyWithValue("error.errorCode", INVALID_SIGNATURE);
 	}
 
 	@Test
@@ -391,7 +433,7 @@ public class OpenSamlAuthenticationProviderTests {
 		ValidationContext context = mock(ValidationContext.class);
 		when(context.getStaticParameters()).thenReturn(parameters);
 		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-		provider.setValidationContextConverter(tuple -> context);
+		provider.setAssertionValidator(assertionToken -> createDefaultAssertionValidator(context).convert(assertionToken));
 		Response response = response();
 		Assertion assertion = assertion();
 		response.getAssertions().add(assertion);
@@ -402,17 +444,8 @@ public class OpenSamlAuthenticationProviderTests {
 	}
 
 	@Test
-	public void setValidationContextConverterWhenNullThenIllegalArgument() {
-		assertThatCode(() -> this.provider.setValidationContextConverter(null))
-				.isInstanceOf(IllegalArgumentException.class);
-	}
-
-	@Test
-	public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() {
-		assertThatCode(() -> this.provider.setConditionValidators(null))
-				.isInstanceOf(IllegalArgumentException.class);
-
-		assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList()))
+	public void setAssertionValidatorWhenNullThenIllegalArgument() {
+		assertThatCode(() -> this.provider.setAssertionValidator(null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 

+ 3 - 3
samples/boot/saml2login/src/integration-test/java/org/springframework/security/saml2/provider/service/authentication/Saml2LoginIntegrationTests.java

@@ -242,7 +242,7 @@ public class Saml2LoginIntegrationTests {
 		sendResponse(response, "/login?error")
 				.andExpect(
 						saml2AuthenticationExceptionMatcher(
-								"invalid_assertion",
+								"invalid_signature",
 								containsString("Invalid assertion [assertion] for SAML response")
 						)
 				);
@@ -288,9 +288,9 @@ public class Saml2LoginIntegrationTests {
 				.andExpect(unauthenticated())
 				.andExpect(
 						saml2AuthenticationExceptionMatcher(
-								"invalid_issuer",
+								"invalid_signature",
 								containsString(
-										"Invalid issuer"
+										"Invalid signature"
 								)
 						)
 				);