소스 검색

Add ConditionValidator Support

Closes gh-8769
Josh Cummings 5 년 전
부모
커밋
a402c3884a

+ 76 - 28
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -21,13 +21,13 @@ import java.time.Duration;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.function.Consumer;
 import java.util.function.Function;
 import javax.annotation.Nonnull;
 
@@ -193,10 +193,12 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 	private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter =
 			new SignatureTrustEngineConverter();
-	private Converter<Saml2AuthenticationToken, SAML20AssertionValidator> assertionValidatorConverter =
+	private Converter<Tuple, SAML20AssertionValidator> assertionValidatorConverter =
 			new SAML20AssertionValidatorConverter();
-	private Converter<Saml2AuthenticationToken, ValidationContext> validationContextConverter =
-			new ValidationContextConverter(params -> {});
+	private Collection<ConditionValidator> conditionValidators =
+			Collections.singleton(new AudienceRestrictionConditionValidator());
+	private Converter<Tuple, ValidationContext> validationContextConverter =
+			new ValidationContextConverter();
 	private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
 
 	/**
@@ -209,6 +211,33 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.parserPool = this.registry.getParserPool();
 	}
 
+	/**
+	 * Set the the collection of {@link ConditionValidator}s used when validating an assertion.
+	 *
+	 * @param conditionValidators the collection of validators to use
+	 * @since 5.4
+	 */
+	public void setConditionValidators(
+			Collection<ConditionValidator> conditionValidators) {
+
+		Assert.notEmpty(conditionValidators, "conditionValidators cannot be empty");
+		this.conditionValidators = conditionValidators;
+	}
+
+	/**
+	 * Set the strategy for retrieving the {@link ValidationContext} used when
+	 * validating an assertion.
+	 *
+	 * @param validationContextConverter the strategy to use
+	 * @since 5.4
+	 */
+	public void setValidationContextConverter(
+			Converter<Tuple, ValidationContext> validationContextConverter) {
+
+		Assert.notNull(validationContextConverter, "validationContextConverter cannot be empty");
+		this.validationContextConverter = validationContextConverter;
+	}
+
 	/**
 	 * Sets the {@link Converter} used for extracting assertion attributes that
 	 * can be mapped to authorities.
@@ -238,8 +267,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	 */
 	public void setResponseTimeValidationSkew(Duration responseTimeValidationSkew) {
 		this.responseTimeValidationSkew = responseTimeValidationSkew;
-		this.validationContextConverter = new ValidationContextConverter(
-				params -> params.put(CLOCK_SKEW, responseTimeValidationSkew.toMillis()));
 	}
 
 	/**
@@ -303,7 +330,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 			throw authException(INVALID_SIGNATURE, "Either the response or one of the assertions is unsigned. " +
 					"Please either sign the response or all of the assertions.");
 		}
-		validationExceptions.putAll(validateAssertions(token, assertions));
+		validationExceptions.putAll(validateAssertions(token, response));
 
 		Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
 		NameID nameId = decryptPrincipal(decrypter, firstAssertion);
@@ -392,7 +419,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	}
 
 	private Map<String, Saml2AuthenticationException> validateAssertions
-			(Saml2AuthenticationToken token, List<Assertion> assertions) {
+			(Saml2AuthenticationToken token, Response response) {
+		List<Assertion> assertions = response.getAssertions();
 		if (assertions.isEmpty()) {
 			throw authException(MALFORMED_RESPONSE_DATA, "No assertions found in response.");
 		}
@@ -401,14 +429,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		if (logger.isDebugEnabled()) {
 			logger.debug("Validating " + assertions.size() + " assertions");
 		}
+
+		Tuple tuple = new Tuple(token, response);
+		SAML20AssertionValidator validator = this.assertionValidatorConverter.convert(tuple);
+		ValidationContext context = this.validationContextConverter.convert(tuple);
 		for (Assertion assertion : assertions) {
 			if (logger.isTraceEnabled()) {
 				logger.trace("Validating assertion " + assertion.getID());
 			}
 			try {
-				ValidationContext context = this.validationContextConverter.convert(token);
-				ValidationResult result = this.assertionValidatorConverter.convert(token).validate(assertion, context);
-				if (result != ValidationResult.VALID) {
+				if (validator.validate(assertion, context) != ValidationResult.VALID) {
 					String message = String.format("Invalid assertion [%s] for SAML response [%s]: %s",
 							assertion.getID(), ((Response) assertion.getParent()).getID(),
 							context.getValidationFailureMessage());
@@ -512,6 +542,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	}
 
 	private static class SignatureTrustEngineConverter implements Converter<Saml2AuthenticationToken, SignatureTrustEngine> {
+
 		@Override
 		public SignatureTrustEngine convert(Saml2AuthenticationToken token) {
 			Set<Credential> credentials = new HashSet<>();
@@ -530,35 +561,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		}
 	}
 
-	private static class ValidationContextConverter implements Converter<Saml2AuthenticationToken, ValidationContext> {
-		Consumer<Map<String, Object>> validationContextParametersConverter;
-
-		ValidationContextConverter(Consumer<Map<String, Object>> validationContextParametersConverter) {
-			this.validationContextParametersConverter = validationContextParametersConverter;
-		}
+	private class ValidationContextConverter implements Converter<Tuple, ValidationContext> {
 
 		@Override
-		public ValidationContext convert(Saml2AuthenticationToken token) {
-			String audience = token.getRelyingPartyRegistration().getEntityId();
-			String recipient = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
+		public ValidationContext convert(Tuple tuple) {
+			String audience = tuple.authentication.getRelyingPartyRegistration().getEntityId();
+			String recipient = tuple.authentication.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
 			Map<String, Object> params = new HashMap<>();
-			params.put(CLOCK_SKEW, Duration.ofMinutes(5).toMillis());
+			params.put(CLOCK_SKEW, OpenSamlAuthenticationProvider.this.responseTimeValidationSkew.toMillis());
 			params.put(COND_VALID_AUDIENCES, singleton(audience));
 			params.put(SC_VALID_RECIPIENTS, singleton(recipient));
 			params.put(SIGNATURE_REQUIRED, false); // this verification is performed earlier
-			this.validationContextParametersConverter.accept(params);
 			return new ValidationContext(params);
 		}
 	}
 
-	private class SAML20AssertionValidatorConverter implements Converter<Saml2AuthenticationToken, SAML20AssertionValidator> {
-		private final Collection<ConditionValidator> conditions = new ArrayList<>();
+	private class SAML20AssertionValidatorConverter implements Converter<Tuple, SAML20AssertionValidator> {
 		private final Collection<SubjectConfirmationValidator> subjects = new ArrayList<>();
 		private final Collection<StatementValidator> statements = new ArrayList<>();
 		private final SignaturePrevalidator validator = new SAMLSignatureProfileValidator();
 
 		SAML20AssertionValidatorConverter() {
-			this.conditions.add(new AudienceRestrictionConditionValidator());
 			this.subjects.add(new BearerSubjectConfirmationValidator() {
 				@Nonnull
 				@Override
@@ -571,9 +594,11 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		}
 
 		@Override
-		public SAML20AssertionValidator convert(Saml2AuthenticationToken token) {
-			return new SAML20AssertionValidator(this.conditions, this.subjects, this.statements,
-					OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(token),
+		public SAML20AssertionValidator convert(Tuple tuple) {
+			Collection<ConditionValidator> conditions =
+					OpenSamlAuthenticationProvider.this.conditionValidators;
+			return new SAML20AssertionValidator(conditions, this.subjects, this.statements,
+					OpenSamlAuthenticationProvider.this.signatureTrustEngineConverter.convert(tuple.authentication),
 					this.validator);
 		}
 	}
@@ -616,4 +641,27 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 		return new Saml2AuthenticationException(validationError(code, description), cause);
 	}
+
+	/**
+	 * A tuple containing the authentication token and the associated OpenSAML {@link Response}.
+	 *
+	 * @since 5.4
+	 */
+	public static class Tuple {
+		private final Saml2AuthenticationToken authentication;
+		private final Response response;
+
+		private Tuple(Saml2AuthenticationToken authentication, Response response) {
+			this.authentication = authentication;
+			this.response = response;
+		}
+
+		public Saml2AuthenticationToken getAuthentication() {
+			return this.authentication;
+		}
+
+		public Response getResponse() {
+			return this.response;
+		}
+	}
 }

+ 67 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -23,9 +23,11 @@ import java.io.StringReader;
 import java.time.Instant;
 import java.util.Arrays;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import javax.xml.namespace.QName;
 import javax.xml.parsers.DocumentBuilder;
 import javax.xml.parsers.DocumentBuilderFactory;
 
@@ -42,12 +44,17 @@ import org.opensaml.core.xml.XMLObject;
 import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.core.xml.io.Marshaller;
 import org.opensaml.core.xml.io.MarshallingException;
+import org.opensaml.saml.common.assertion.ValidationContext;
+import org.opensaml.saml.common.assertion.ValidationResult;
+import org.opensaml.saml.saml2.assertion.impl.OneTimeUseConditionValidator;
 import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.AttributeStatement;
 import org.opensaml.saml.saml2.core.AttributeValue;
+import org.opensaml.saml.saml2.core.Condition;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedID;
 import org.opensaml.saml.saml2.core.NameID;
+import org.opensaml.saml.saml2.core.OneTimeUse;
 import org.opensaml.saml.saml2.core.Response;
 import org.w3c.dom.Document;
 import org.w3c.dom.Element;
@@ -57,7 +64,9 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 
+import static java.util.Collections.singleton;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
@@ -65,6 +74,8 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getBuilderFactory;
 import static org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport.getMarshallerFactory;
+import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS;
+import static org.opensaml.saml.saml2.assertion.SAML2AssertionValidationParameters.SIGNATURE_REQUIRED;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyEncryptingCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartyPrivateCredential;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.assertingPartySigningCredential;
@@ -353,6 +364,62 @@ public class OpenSamlAuthenticationProviderTests {
 		objectOutputStream.flush();
 	}
 
+	@Test
+	public void authenticateWhenConditionValidatorsCustomizedThenUses() throws Exception {
+		OneTimeUseConditionValidator validator = mock(OneTimeUseConditionValidator.class);
+		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+		provider.setConditionValidators(Collections.singleton(validator));
+		Response response = response();
+		Assertion assertion = assertion();
+		OneTimeUse oneTimeUse = build(OneTimeUse.DEFAULT_ELEMENT_NAME);
+		assertion.getConditions().getConditions().add(oneTimeUse);
+		response.getAssertions().add(assertion);
+		signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+		when(validator.getServicedCondition()).thenReturn(OneTimeUse.DEFAULT_ELEMENT_NAME);
+		when(validator.validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class)))
+				.thenReturn(ValidationResult.VALID);
+		provider.authenticate(token);
+		verify(validator).validate(any(Condition.class), any(Assertion.class), any(ValidationContext.class));
+	}
+
+	@Test
+	public void authenticateWhenValidationContextCustomizedThenUsers() {
+		Map<String, Object> parameters = new HashMap<>();
+		parameters.put(SC_VALID_RECIPIENTS, singleton(DESTINATION));
+		parameters.put(SIGNATURE_REQUIRED, false);
+		ValidationContext context = mock(ValidationContext.class);
+		when(context.getStaticParameters()).thenReturn(parameters);
+		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+		provider.setValidationContextConverter(tuple -> context);
+		Response response = response();
+		Assertion assertion = assertion();
+		response.getAssertions().add(assertion);
+		signed(response, assertingPartySigningCredential(), ASSERTING_PARTY_ENTITY_ID);
+		Saml2AuthenticationToken token = token(response, relyingPartyVerifyingCredential());
+		provider.authenticate(token);
+		verify(context, atLeastOnce()).getStaticParameters();
+	}
+
+	@Test
+	public void setValidationContextConverterWhenNullThenIllegalArgument() {
+		assertThatCode(() -> this.provider.setValidationContextConverter(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setConditionValidatorsWhenNullOrEmptyThenIllegalArgument() {
+		assertThatCode(() -> this.provider.setConditionValidators(null))
+				.isInstanceOf(IllegalArgumentException.class);
+
+		assertThatCode(() -> this.provider.setConditionValidators(Collections.emptyList()))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	private <T extends XMLObject> T build(QName qName) {
+		return (T) getBuilderFactory().getBuilder(qName).buildObject(qName);
+	}
+
 	private String serialize(XMLObject object) {
 		try {
 			Marshaller marshaller = getMarshallerFactory().getMarshaller(object);