浏览代码

Polish Configurable SAML Decryption Support

- Renamed to setResponseElementsDecrypter and
setAssertionElementsDecrypter to align with ResponseToken and
AssertionToken
- Changed contract of setAssertionElementsDecrypter to use
AssertionToken
- Changed assertions in unit test to use isEqualTo

Issue gh-9044
Josh Cummings 4 年之前
父节点
当前提交
d0581c9a26

+ 248 - 190
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -88,7 +88,6 @@ import org.opensaml.security.criteria.UsageCriterion;
 import org.opensaml.security.x509.BasicX509Credential;
 import org.opensaml.xmlsec.config.impl.DefaultSecurityConfigurationBootstrap;
 import org.opensaml.xmlsec.encryption.support.ChainingEncryptedKeyResolver;
-import org.opensaml.xmlsec.encryption.support.DecryptionException;
 import org.opensaml.xmlsec.encryption.support.EncryptedKeyResolver;
 import org.opensaml.xmlsec.encryption.support.InlineEncryptedKeyResolver;
 import org.opensaml.xmlsec.encryption.support.SimpleRetrievalMethodEncryptedKeyResolver;
@@ -185,58 +184,23 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 	private Duration responseTimeValidationSkew = Duration.ofMinutes(5);
 
-	private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = (
-			responseToken) -> {
-		Response response = responseToken.response;
-		Saml2AuthenticationToken token = responseToken.token;
-		Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
-		String username = assertion.getSubject().getNameID().getValue();
-		Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
-		return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes),
-				token.getSaml2Response(), this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
-	};
-
-	private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionValidator(
-			Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
-				SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token);
-				return SAML20AssertionValidators.createSignatureValidator(engine);
-			}, (assertionToken) -> new ValidationContext(
-					Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false)));
-
-	private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = createDefaultAssertionValidator(
-			Saml2ErrorCodes.INVALID_ASSERTION, (assertionToken) -> SAML20AssertionValidators.attributeValidator,
-			(assertionToken) -> createValidationContext(assertionToken, (params) -> params
-					.put(SAML2AssertionValidationParameters.CLOCK_SKEW, this.responseTimeValidationSkew.toMillis())));
+	private Converter<ResponseToken, Saml2ResponseValidatorResult> responseSignatureValidator = createDefaultResponseSignatureValidator();
 
-	private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter = new SignatureTrustEngineConverter();
+	private Consumer<ResponseToken> responseElementsDecrypter = createDefaultResponseElementsDecrypter();
 
-	private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
+	private Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
 
-	private Consumer<ResponseToken> assertionDecrypter = (responseToken) -> {
-		List<Assertion> assertions = new ArrayList<>();
-		for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
-			try {
-				Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
-				Assertion assertion = decrypter.decrypt(encryptedAssertion);
-				assertions.add(assertion);
-			}
-			catch (DecryptionException ex) {
-				throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
-			}
-		}
-		responseToken.getResponse().getAssertions().addAll(assertions);
-	};
+	private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionSignatureValidator();
 
-	private Consumer<ResponseToken> principalDecrypter = (responseToken) -> {
-		try {
-			Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
-			Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
-			assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
-		}
-		catch (DecryptionException ex) {
-			throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
-		}
-	};
+	private Consumer<AssertionToken> assertionElementsDecrypter = createDefaultAssertionElementsDecrypter();
+
+	private Converter<AssertionToken, Saml2ResponseValidatorResult> assertionValidator = createCompatibleAssertionValidator();
+
+	private Converter<ResponseToken, ? extends AbstractAuthenticationToken> responseAuthenticationConverter = createCompatibleResponseAuthenticationConverter();
+
+	private Converter<Saml2AuthenticationToken, SignatureTrustEngine> signatureTrustEngineConverter = new SignatureTrustEngineConverter();
+
+	private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
 
 	/**
 	 * Creates an {@link OpenSamlAuthenticationProvider}
@@ -248,12 +212,60 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.parserPool = this.registry.getParserPool();
 	}
 
+	/**
+	 * Set the {@link Consumer} strategy to use for decrypting elements of a validated
+	 * {@link Response}. The default strategy decrypts all {@link EncryptedAssertion}s
+	 * using OpenSAML's {@link Decrypter}, adding the results to
+	 * {@link Response#getAssertions()}.
+	 *
+	 * You can use this method to configure the {@link Decrypter} instance like so:
+	 *
+	 * <pre>
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setResponseElementsDecrypter((responseToken) -> {
+	 *	    DecrypterParameters parameters = new DecrypterParameters();
+	 *	    // ... set parameters as needed
+	 *	    Decrypter decrypter = new Decrypter(parameters);
+	 *		Response response = responseToken.getResponse();
+	 *  	EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
+	 *  	try {
+	 *  		Assertion assertion = decrypter.decrypt(encrypted);
+	 *  		response.getAssertions().add(assertion);
+	 *  	} catch (Exception e) {
+	 *  	 	throw new Saml2AuthenticationException(...);
+	 *  	}
+	 *	});
+	 * </pre>
+	 *
+	 * Or, in the event that you have your own custom decryption interface, the same
+	 * pattern applies:
+	 *
+	 * <pre>
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	Converter&lt;EncryptedAssertion, Assertion&gt; myService = ...
+	 *	provider.setResponseDecrypter((responseToken) -> {
+	 *	   Response response = responseToken.getResponse();
+	 *	   response.getEncryptedAssertions().stream()
+	 *	   		.map(service::decrypt).forEach(response.getAssertions()::add);
+	 *	});
+	 * </pre>
+	 *
+	 * This is valuable when using an external service to perform the decryption.
+	 * @param responseElementsDecrypter the {@link Consumer} for decrypting response
+	 * elements
+	 * @since 5.5
+	 */
+	public void setResponseElementsDecrypter(Consumer<ResponseToken> responseElementsDecrypter) {
+		Assert.notNull(responseElementsDecrypter, "responseElementsDecrypter cannot be null");
+		this.responseElementsDecrypter = responseElementsDecrypter;
+	}
+
 	/**
 	 * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML
 	 * 2.0 Response.
 	 *
 	 * You can still invoke the default validator by delgating to
-	 * {@link #createDefaultAssertionValidator}, like so:
+	 * {@link #createAssertionValidator}, like so:
 	 *
 	 * <pre>
 	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
@@ -294,6 +306,49 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.assertionValidator = assertionValidator;
 	}
 
+	/**
+	 * Set the {@link Consumer} strategy to use for decrypting elements of a validated
+	 * {@link Assertion}.
+	 *
+	 * You can use this method to configure the {@link Decrypter} used like so:
+	 *
+	 * <pre>
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setResponseDecrypter((assertionToken) -> {
+	 *	    DecrypterParameters parameters = new DecrypterParameters();
+	 *	    // ... set parameters as needed
+	 *	    Decrypter decrypter = new Decrypter(parameters);
+	 *		Assertion assertion = assertionToken.getAssertion();
+	 *  	EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *  	try {
+	 *  		NameID name = decrypter.decrypt(encrypted);
+	 *  		assertion.getSubject().setNameID(name);
+	 *  	} catch (Exception e) {
+	 *  	 	throw new Saml2AuthenticationException(...);
+	 *  	}
+	 *	});
+	 * </pre>
+	 *
+	 * Or, in the event that you have your own custom interface, the same pattern applies:
+	 *
+	 * <pre>
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	MyDecryptionService myService = ...
+	 *	provider.setResponseDecrypter((responseToken) -> {
+	 *	   	Assertion assertion = assertionToken.getAssertion();
+	 *	   	EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *		NameID name = myService.decrypt(encrypted);
+	 *		assertion.getSubject().setNameID(name);
+	 *	});
+	 * </pre>
+	 * @param assertionDecrypter the {@link Consumer} for decrypting assertion elements
+	 * @since 5.5
+	 */
+	public void setAssertionElementsDecrypter(Consumer<AssertionToken> assertionDecrypter) {
+		Assert.notNull(assertionDecrypter, "assertionDecrypter cannot be null");
+		this.assertionElementsDecrypter = assertionDecrypter;
+	}
+
 	/**
 	 * Set the {@link Converter} to use for converting a validated {@link Response} into
 	 * an {@link AbstractAuthenticationToken}.
@@ -359,52 +414,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.responseTimeValidationSkew = responseTimeValidationSkew;
 	}
 
-	/**
-	 * Sets the assertion response custom decrypter.
-	 *
-	 * You can use this method like so:
-	 *
-	 * <pre>
-	 *	YourDecrypter decrypter = // ... your custom decrypter
-	 *
-	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-	 *	provider.setAssertionDecrypter((responseToken) -> {
-	 *		Response response = responseToken.getResponse();
-	 *  	EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
-	 *  	Assertion assertion = decrypter.decrypt(encrypted);
-	 *  	response.getAssertions().add(assertion);
-	 *	});
-	 * </pre>
-	 * @param assertionDecrypter response token consumer
-	 */
-	public void setAssertionDecrypter(Consumer<ResponseToken> assertionDecrypter) {
-		Assert.notNull(assertionDecrypter, "Consumer<ResponseToken> required");
-		this.assertionDecrypter = assertionDecrypter;
-	}
-
-	/**
-	 * Sets the principal custom decrypter.
-	 *
-	 * You can use this method like so:
-	 *
-	 * <pre>
-	 *	YourDecrypter decrypter = // ... your custom decrypter
-	 *
-	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
-	 *	provider.setAssertionDecrypter((responseToken) -> {
-	 *		Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
-	 *		EncryptedID encrypted = assertion.getSubject().getEncryptedID();
-	 *		NameID name = decrypter.decrypt(encrypted);
-	 *		assertion.getSubject().setNameID(name)
-	 *	});
-	 * </pre>
-	 * @param principalDecrypter response token consumer
-	 */
-	public void setPrincipalDecrypter(Consumer<ResponseToken> principalDecrypter) {
-		Assert.notNull(principalDecrypter, "Consumer<ResponseToken> required");
-		this.principalDecrypter = principalDecrypter;
-	}
-
 	/**
 	 * Construct a default strategy for validating each SAML 2.0 Assertion and associated
 	 * {@link Authentication} token
@@ -413,7 +422,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	 */
 	public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator() {
 
-		return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
+		return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
 				(assertionToken) -> SAML20AssertionValidators.attributeValidator,
 				(assertionToken) -> createValidationContext(assertionToken, (params) -> {
 				}));
@@ -430,7 +439,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 	public static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator(
 			Converter<AssertionToken, ValidationContext> contextConverter) {
 
-		return createDefaultAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
+		return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
 				(assertionToken) -> SAML20AssertionValidators.attributeValidator, contextConverter);
 	}
 
@@ -480,10 +489,6 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return authentication != null && Saml2AuthenticationToken.class.isAssignableFrom(authentication);
 	}
 
-	private Collection<? extends GrantedAuthority> getAssertionAuthorities(Assertion assertion) {
-		return this.authoritiesExtractor.convert(assertion);
-	}
-
 	private Response parse(String response) throws Saml2Exception, Saml2AuthenticationException {
 		try {
 			Document document = this.parserPool
@@ -500,20 +505,30 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		String issuer = response.getIssuer().getValue();
 		logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
 		boolean responseSigned = response.isSigned();
-		Saml2ResponseValidatorResult result = validateResponse(token, response);
 
 		ResponseToken responseToken = new ResponseToken(response, token);
-		List<Assertion> assertions = decryptAssertions(responseToken);
-		if (!isSigned(responseSigned, assertions)) {
+		Saml2ResponseValidatorResult result = this.responseSignatureValidator.convert(responseToken);
+		if (responseSigned) {
+			this.responseElementsDecrypter.accept(responseToken);
+		}
+		result = result.concat(this.responseValidator.convert(responseToken));
+		boolean allAssertionsSigned = true;
+		for (Assertion assertion : response.getAssertions()) {
+			AssertionToken assertionToken = new AssertionToken(assertion, token);
+			result = result.concat(this.assertionSignatureValidator.convert(assertionToken));
+			allAssertionsSigned = allAssertionsSigned && assertion.isSigned();
+			if (responseSigned || assertion.isSigned()) {
+				this.assertionElementsDecrypter.accept(new AssertionToken(assertion, token));
+			}
+			result = result.concat(this.assertionValidator.convert(assertionToken));
+		}
+		if (!responseSigned && !allAssertionsSigned) {
 			String description = "Either the response or one of the assertions is unsigned. "
 					+ "Please either sign the response or all of the assertions.";
 			throw createAuthenticationException(Saml2ErrorCodes.INVALID_SIGNATURE, description, null);
 		}
-		result = result.concat(validateAssertions(token, response));
-
 		Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
-		NameID nameId = decryptPrincipal(responseToken);
-		if (nameId == null || nameId.getValue() == null) {
+		if (!hasName(firstAssertion)) {
 			Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
 					"Assertion [" + firstAssertion.getID() + "] is missing a subject");
 			result = result.concat(error);
@@ -539,107 +554,150 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		}
 	}
 
-	private Saml2ResponseValidatorResult validateResponse(Saml2AuthenticationToken token, Response response) {
-
-		Collection<Saml2Error> errors = new ArrayList<>();
-		String issuer = response.getIssuer().getValue();
-		if (response.isSigned()) {
-			SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
-			try {
-				profileValidator.validate(response.getSignature());
-			}
-			catch (Exception ex) {
-				errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
-						"Invalid signature for SAML Response [" + response.getID() + "]: "));
-			}
+	private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseSignatureValidator() {
+		return (responseToken) -> {
+			Response response = responseToken.getResponse();
+			Saml2AuthenticationToken token = responseToken.getToken();
+			Collection<Saml2Error> errors = new ArrayList<>();
+			String issuer = response.getIssuer().getValue();
+			if (response.isSigned()) {
+				SAMLSignatureProfileValidator profileValidator = new SAMLSignatureProfileValidator();
+				try {
+					profileValidator.validate(response.getSignature());
+				}
+				catch (Exception ex) {
+					errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+							"Invalid signature for SAML Response [" + response.getID() + "]: "));
+				}
 
-			try {
-				CriteriaSet criteriaSet = new CriteriaSet();
-				criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
-				criteriaSet.add(
-						new EvaluableProtocolRoleDescriptorCriterion(new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
-				criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
-				if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(), criteriaSet)) {
+				try {
+					CriteriaSet criteriaSet = new CriteriaSet();
+					criteriaSet.add(new EvaluableEntityIDCredentialCriterion(new EntityIdCriterion(issuer)));
+					criteriaSet.add(new EvaluableProtocolRoleDescriptorCriterion(
+							new ProtocolCriterion(SAMLConstants.SAML20P_NS)));
+					criteriaSet.add(new EvaluableUsageCredentialCriterion(new UsageCriterion(UsageType.SIGNING)));
+					if (!this.signatureTrustEngineConverter.convert(token).validate(response.getSignature(),
+							criteriaSet)) {
+						errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
+								"Invalid signature for SAML Response [" + response.getID() + "]"));
+					}
+				}
+				catch (Exception ex) {
 					errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
-							"Invalid signature for SAML Response [" + response.getID() + "]"));
+							"Invalid signature for SAML Response [" + response.getID() + "]: "));
 				}
 			}
-			catch (Exception ex) {
-				errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_SIGNATURE,
-						"Invalid signature for SAML Response [" + response.getID() + "]: "));
-			}
-		}
-		String destination = response.getDestination();
-		String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
-		if (StringUtils.hasText(destination) && !destination.equals(location)) {
-			String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID() + "]";
-			errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
-		}
-		String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails().getEntityId();
-		if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
-			String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
-			errors.add(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
-		}
 
-		return Saml2ResponseValidatorResult.failure(errors);
+			return Saml2ResponseValidatorResult.failure(errors);
+		};
 	}
 
-	private List<Assertion> decryptAssertions(ResponseToken response) {
-		this.assertionDecrypter.accept(response);
-		return response.getResponse().getAssertions();
+	private Consumer<ResponseToken> createDefaultResponseElementsDecrypter() {
+		return (responseToken) -> {
+			Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
+			Response response = responseToken.getResponse();
+			for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
+				try {
+					Assertion assertion = decrypter.decrypt(encryptedAssertion);
+					response.getAssertions().add(assertion);
+				}
+				catch (Exception ex) {
+					throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
+				}
+			}
+		};
 	}
 
-	private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
-		List<Assertion> assertions = response.getAssertions();
-		if (assertions.isEmpty()) {
-			throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
-					"No assertions found in response.", null);
-		}
+	private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
+		return (responseToken) -> {
+			Response response = responseToken.getResponse();
+			Saml2AuthenticationToken token = responseToken.getToken();
+			Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
+			String issuer = response.getIssuer().getValue();
+			String destination = response.getDestination();
+			String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
+			if (StringUtils.hasText(destination) && !destination.equals(location)) {
+				String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
+						+ "]";
+				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
+			}
+			String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails()
+					.getEntityId();
+			if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
+				String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
+				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
+			}
+			if (response.getAssertions().isEmpty()) {
+				throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
+						"No assertions found in response.", null);
+			}
+			return result;
+		};
+	}
 
-		Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
-		if (logger.isDebugEnabled()) {
-			logger.debug("Validating " + assertions.size() + " assertions");
-		}
+	private Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionSignatureValidator() {
+		return createAssertionValidator(Saml2ErrorCodes.INVALID_SIGNATURE, (assertionToken) -> {
+			SignatureTrustEngine engine = this.signatureTrustEngineConverter.convert(assertionToken.token);
+			return SAML20AssertionValidators.createSignatureValidator(engine);
+		}, (assertionToken) -> new ValidationContext(
+				Collections.singletonMap(SAML2AssertionValidationParameters.SIGNATURE_REQUIRED, false)));
+	}
 
-		for (Assertion assertion : assertions) {
-			if (logger.isTraceEnabled()) {
-				logger.trace("Validating assertion " + assertion.getID());
+	private Consumer<AssertionToken> createDefaultAssertionElementsDecrypter() {
+		return (assertionToken) -> {
+			Decrypter decrypter = this.decrypterConverter.convert(assertionToken.getToken());
+			Assertion assertion = assertionToken.getAssertion();
+			if (assertion.getSubject() == null) {
+				return;
 			}
-			AssertionToken assertionToken = new AssertionToken(assertion, token);
-			result = result.concat(this.assertionSignatureValidator.convert(assertionToken))
-					.concat(this.assertionValidator.convert(assertionToken));
-		}
+			if (assertion.getSubject().getEncryptedID() == null) {
+				return;
+			}
+			try {
+				assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
+			}
+			catch (Exception ex) {
+				throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
+			}
+		};
+	}
 
-		return result;
+	private Converter<AssertionToken, Saml2ResponseValidatorResult> createCompatibleAssertionValidator() {
+		return createAssertionValidator(Saml2ErrorCodes.INVALID_ASSERTION,
+				(assertionToken) -> SAML20AssertionValidators.attributeValidator,
+				(assertionToken) -> createValidationContext(assertionToken,
+						(params) -> params.put(SAML2AssertionValidationParameters.CLOCK_SKEW,
+								this.responseTimeValidationSkew.toMillis())));
 	}
 
-	private void addValidationException(Map<String, Saml2AuthenticationException> exceptions, String code,
-			String message, Exception cause) {
-		exceptions.put(code, createAuthenticationException(code, message, cause));
+	private Converter<ResponseToken, Saml2Authentication> createCompatibleResponseAuthenticationConverter() {
+		return (responseToken) -> {
+			Response response = responseToken.response;
+			Saml2AuthenticationToken token = responseToken.token;
+			Assertion assertion = CollectionUtils.firstElement(response.getAssertions());
+			String username = assertion.getSubject().getNameID().getValue();
+			Map<String, List<Object>> attributes = getAssertionAttributes(assertion);
+			return new Saml2Authentication(new DefaultSaml2AuthenticatedPrincipal(username, attributes),
+					token.getSaml2Response(),
+					this.authoritiesMapper.mapAuthorities(getAssertionAuthorities(assertion)));
+		};
 	}
 
-	private boolean isSigned(boolean responseSigned, List<Assertion> assertions) {
-		if (responseSigned) {
-			return true;
-		}
-		for (Assertion assertion : assertions) {
-			if (!assertion.isSigned()) {
-				return false;
-			}
-		}
-		return true;
+	private Collection<? extends GrantedAuthority> getAssertionAuthorities(Assertion assertion) {
+		return this.authoritiesExtractor.convert(assertion);
 	}
 
-	private NameID decryptPrincipal(ResponseToken responseToken) {
-		Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
+	private boolean hasName(Assertion assertion) {
+		if (assertion == null) {
+			return false;
+		}
 		if (assertion.getSubject() == null) {
-			return null;
+			return false;
 		}
-		if (assertion.getSubject().getEncryptedID() == null) {
-			return assertion.getSubject().getNameID();
+		if (assertion.getSubject().getNameID() == null) {
+			return false;
 		}
-		this.principalDecrypter.accept(responseToken);
-		return assertion.getSubject().getNameID();
+		return assertion.getSubject().getNameID().getValue() != null;
 	}
 
 	private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
@@ -688,8 +746,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return new Saml2AuthenticationException(new Saml2Error(code, message), cause);
 	}
 
-	private static Converter<AssertionToken, Saml2ResponseValidatorResult> createDefaultAssertionValidator(
-			String errorCode, Converter<AssertionToken, SAML20AssertionValidator> validatorConverter,
+	private static Converter<AssertionToken, Saml2ResponseValidatorResult> createAssertionValidator(String errorCode,
+			Converter<AssertionToken, SAML20AssertionValidator> validatorConverter,
 			Converter<AssertionToken, ValidationContext> contextConverter) {
 
 		return (assertionToken) -> {

+ 36 - 28
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -47,6 +47,10 @@ import org.opensaml.saml.saml2.core.EncryptedID;
 import org.opensaml.saml.saml2.core.NameID;
 import org.opensaml.saml.saml2.core.OneTimeUse;
 import org.opensaml.saml.saml2.core.Response;
+import org.opensaml.saml.saml2.core.impl.EncryptedAssertionBuilder;
+import org.opensaml.saml.saml2.core.impl.EncryptedIDBuilder;
+import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
+import org.opensaml.xmlsec.encryption.impl.EncryptedDataBuilder;
 import org.w3c.dom.Element;
 
 import org.springframework.core.convert.converter.Converter;
@@ -241,6 +245,8 @@ public class OpenSamlAuthenticationProviderTests {
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
+		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyDecryptingCredential());
 		assertThatExceptionOfType(Saml2AuthenticationException.class)
 				.isThrownBy(() -> this.provider.authenticate(token))
@@ -255,6 +261,8 @@ public class OpenSamlAuthenticationProviderTests {
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion,
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
+		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
 				TestSaml2X509Credentials.relyingPartyDecryptingCredential());
 		this.provider.authenticate(token);
@@ -296,6 +304,8 @@ public class OpenSamlAuthenticationProviderTests {
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
+		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(serialize(response),
 				TestSaml2X509Credentials.relyingPartyVerifyingCredential());
 		assertThatExceptionOfType(Saml2AuthenticationException.class)
@@ -309,6 +319,8 @@ public class OpenSamlAuthenticationProviderTests {
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(TestOpenSamlObjects.assertion(),
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
+		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(serialize(response),
 				TestSaml2X509Credentials.assertingPartyPrivateCredential());
 		assertThatExceptionOfType(Saml2AuthenticationException.class)
@@ -324,6 +336,8 @@ public class OpenSamlAuthenticationProviderTests {
 		EncryptedAssertion encryptedAssertion = TestOpenSamlObjects.encrypted(assertion,
 				TestSaml2X509Credentials.assertingPartyEncryptingCredential());
 		response.getEncryptedAssertions().add(encryptedAssertion);
+		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential(),
 				TestSaml2X509Credentials.relyingPartyDecryptingCredential());
 		Saml2Authentication authentication = (Saml2Authentication) this.provider.authenticate(token);
@@ -473,54 +487,48 @@ public class OpenSamlAuthenticationProviderTests {
 	}
 
 	@Test
-	public void setAssertionDecrypterWhenNullThenIllegalArgument() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null));
+	public void setResponseElementsDecrypterWhenNullThenIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseElementsDecrypter(null));
 	}
 
 	@Test
-	public void setPrincipalDecrypterWhenNullThenIllegalArgument() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null));
+	public void setAssertionElementsDecrypterWhenNullThenIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionElementsDecrypter(null));
 	}
 
 	@Test
-	public void setAssertionDecrypterThenChangesAssertion() {
+	public void authenticateWhenCustomResponseElementsDecrypterThenDecryptsResponse() {
 		Response response = TestOpenSamlObjects.response();
 		Assertion assertion = TestOpenSamlObjects.assertion();
-		assertion.getSubject().getSubjectConfirmations()
-				.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
 		TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
 				RELYING_PARTY_ENTITY_ID);
-		response.getAssertions().add(assertion);
+		response.getEncryptedAssertions().add(new EncryptedAssertionBuilder().buildObject());
+		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter());
-		assertThatExceptionOfType(Saml2AuthenticationException.class)
-				.isThrownBy(() -> this.provider.authenticate(token))
-				.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
-		assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
+		this.provider.setResponseElementsDecrypter((tuple) -> tuple.getResponse().getAssertions().add(assertion));
+		Authentication authentication = this.provider.authenticate(token);
+		assertThat(authentication.getName()).isEqualTo("test@saml.user");
 	}
 
 	@Test
-	public void setPrincipalDecrypterThenChangesAssertion() {
+	public void authenticateWhenCustomAssertionElementsDecrypterThenDecryptsAssertion() {
 		Response response = TestOpenSamlObjects.response();
 		Assertion assertion = TestOpenSamlObjects.assertion();
-		assertion.getSubject().getSubjectConfirmations()
-				.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
+		EncryptedID id = new EncryptedIDBuilder().buildObject();
+		id.setEncryptedData(new EncryptedDataBuilder().buildObject());
+		assertion.getSubject().setEncryptedID(id);
 		TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
 				RELYING_PARTY_ENTITY_ID);
 		response.getAssertions().add(assertion);
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
-		this.provider.authenticate(token);
-		assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
-	}
-
-	private Consumer<ResponseToken> mockAssertionAndPrincipalDecrypter() {
-		return (responseToken) -> {
-			responseToken.getResponse().getAssertions().clear();
-			responseToken.getResponse().getAssertions()
-					.add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
-							TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
-		};
+		this.provider.setAssertionElementsDecrypter((tuple) -> {
+			NameID name = new NameIDBuilder().buildObject();
+			name.setValue("decrypted name");
+			tuple.getAssertion().getSubject().setNameID(name);
+		});
+		Authentication authentication = this.provider.authenticate(token);
+		assertThat(authentication.getName()).isEqualTo("decrypted name");
 	}
 
 	private <T extends XMLObject> T build(QName qName) {