浏览代码

Add Configurable SAML Response Decryption

Closes gh-9044
ryan.cassar 4 年之前
父节点
当前提交
535ae3e27d

+ 83 - 25
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProvider.java

@@ -157,6 +157,7 @@ import org.springframework.util.StringUtils;
  * asserting party, IDP, verification certificates.
  * asserting party, IDP, verification certificates.
  * </p>
  * </p>
  *
  *
+ * @author Ryan Cassar
  * @since 5.2
  * @since 5.2
  * @see <a href=
  * @see <a href=
  * "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2
  * "https://docs.oasis-open.org/security/saml/v2.0/saml-core-2.0-os.pdf#page=38">SAML 2
@@ -211,6 +212,32 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 
 
 	private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
 	private Converter<Saml2AuthenticationToken, Decrypter> decrypterConverter = new DecrypterConverter();
 
 
+	private Consumer<ResponseToken> assertionDecrypter = (responseToken) -> {
+		List<Assertion> assertions = new ArrayList<>();
+		for (EncryptedAssertion encryptedAssertion : responseToken.getResponse().getEncryptedAssertions()) {
+			try {
+				Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
+				Assertion assertion = decrypter.decrypt(encryptedAssertion);
+				assertions.add(assertion);
+			}
+			catch (DecryptionException ex) {
+				throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
+			}
+		}
+		responseToken.getResponse().getAssertions().addAll(assertions);
+	};
+
+	private Consumer<ResponseToken> principalDecrypter = (responseToken) -> {
+		try {
+			Decrypter decrypter = this.decrypterConverter.convert(responseToken.getToken());
+			Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
+			assertion.getSubject().setNameID((NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID()));
+		}
+		catch (DecryptionException ex) {
+			throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
+		}
+	};
+
 	/**
 	/**
 	 * Creates an {@link OpenSamlAuthenticationProvider}
 	 * Creates an {@link OpenSamlAuthenticationProvider}
 	 */
 	 */
@@ -332,6 +359,52 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		this.responseTimeValidationSkew = responseTimeValidationSkew;
 		this.responseTimeValidationSkew = responseTimeValidationSkew;
 	}
 	}
 
 
+	/**
+	 * Sets the assertion response custom decrypter.
+	 *
+	 * You can use this method like so:
+	 *
+	 * <pre>
+	 *	YourDecrypter decrypter = // ... your custom decrypter
+	 *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setAssertionDecrypter((responseToken) -> {
+	 *		Response response = responseToken.getResponse();
+	 *  	EncryptedAssertion encrypted = response.getEncryptedAssertions().get(0);
+	 *  	Assertion assertion = decrypter.decrypt(encrypted);
+	 *  	response.getAssertions().add(assertion);
+	 *	});
+	 * </pre>
+	 * @param assertionDecrypter response token consumer
+	 */
+	public void setAssertionDecrypter(Consumer<ResponseToken> assertionDecrypter) {
+		Assert.notNull(assertionDecrypter, "Consumer<ResponseToken> required");
+		this.assertionDecrypter = assertionDecrypter;
+	}
+
+	/**
+	 * Sets the principal custom decrypter.
+	 *
+	 * You can use this method like so:
+	 *
+	 * <pre>
+	 *	YourDecrypter decrypter = // ... your custom decrypter
+	 *
+	 *	OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
+	 *	provider.setAssertionDecrypter((responseToken) -> {
+	 *		Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
+	 *		EncryptedID encrypted = assertion.getSubject().getEncryptedID();
+	 *		NameID name = decrypter.decrypt(encrypted);
+	 *		assertion.getSubject().setNameID(name)
+	 *	});
+	 * </pre>
+	 * @param principalDecrypter response token consumer
+	 */
+	public void setPrincipalDecrypter(Consumer<ResponseToken> principalDecrypter) {
+		Assert.notNull(principalDecrypter, "Consumer<ResponseToken> required");
+		this.principalDecrypter = principalDecrypter;
+	}
+
 	/**
 	/**
 	 * Construct a default strategy for validating each SAML 2.0 Assertion and associated
 	 * Construct a default strategy for validating each SAML 2.0 Assertion and associated
 	 * {@link Authentication} token
 	 * {@link Authentication} token
@@ -429,8 +502,8 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		boolean responseSigned = response.isSigned();
 		boolean responseSigned = response.isSigned();
 		Saml2ResponseValidatorResult result = validateResponse(token, response);
 		Saml2ResponseValidatorResult result = validateResponse(token, response);
 
 
-		Decrypter decrypter = this.decrypterConverter.convert(token);
-		List<Assertion> assertions = decryptAssertions(decrypter, response);
+		ResponseToken responseToken = new ResponseToken(response, token);
+		List<Assertion> assertions = decryptAssertions(responseToken);
 		if (!isSigned(responseSigned, assertions)) {
 		if (!isSigned(responseSigned, assertions)) {
 			String description = "Either the response or one of the assertions is unsigned. "
 			String description = "Either the response or one of the assertions is unsigned. "
 					+ "Please either sign the response or all of the assertions.";
 					+ "Please either sign the response or all of the assertions.";
@@ -439,7 +512,7 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		result = result.concat(validateAssertions(token, response));
 		result = result.concat(validateAssertions(token, response));
 
 
 		Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
 		Assertion firstAssertion = CollectionUtils.firstElement(response.getAssertions());
-		NameID nameId = decryptPrincipal(decrypter, firstAssertion);
+		NameID nameId = decryptPrincipal(responseToken);
 		if (nameId == null || nameId.getValue() == null) {
 		if (nameId == null || nameId.getValue() == null) {
 			Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
 			Saml2Error error = new Saml2Error(Saml2ErrorCodes.SUBJECT_NOT_FOUND,
 					"Assertion [" + firstAssertion.getID() + "] is missing a subject");
 					"Assertion [" + firstAssertion.getID() + "] is missing a subject");
@@ -511,19 +584,9 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return Saml2ResponseValidatorResult.failure(errors);
 		return Saml2ResponseValidatorResult.failure(errors);
 	}
 	}
 
 
-	private List<Assertion> decryptAssertions(Decrypter decrypter, Response response) {
-		List<Assertion> assertions = new ArrayList<>();
-		for (EncryptedAssertion encryptedAssertion : response.getEncryptedAssertions()) {
-			try {
-				Assertion assertion = decrypter.decrypt(encryptedAssertion);
-				assertions.add(assertion);
-			}
-			catch (DecryptionException ex) {
-				throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
-			}
-		}
-		response.getAssertions().addAll(assertions);
-		return response.getAssertions();
+	private List<Assertion> decryptAssertions(ResponseToken response) {
+		this.assertionDecrypter.accept(response);
+		return response.getResponse().getAssertions();
 	}
 	}
 
 
 	private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
 	private Saml2ResponseValidatorResult validateAssertions(Saml2AuthenticationToken token, Response response) {
@@ -567,21 +630,16 @@ public final class OpenSamlAuthenticationProvider implements AuthenticationProvi
 		return true;
 		return true;
 	}
 	}
 
 
-	private NameID decryptPrincipal(Decrypter decrypter, Assertion assertion) {
+	private NameID decryptPrincipal(ResponseToken responseToken) {
+		Assertion assertion = CollectionUtils.firstElement(responseToken.getResponse().getAssertions());
 		if (assertion.getSubject() == null) {
 		if (assertion.getSubject() == null) {
 			return null;
 			return null;
 		}
 		}
 		if (assertion.getSubject().getEncryptedID() == null) {
 		if (assertion.getSubject().getEncryptedID() == null) {
 			return assertion.getSubject().getNameID();
 			return assertion.getSubject().getNameID();
 		}
 		}
-		try {
-			NameID nameId = (NameID) decrypter.decrypt(assertion.getSubject().getEncryptedID());
-			assertion.getSubject().setNameID(nameId);
-			return nameId;
-		}
-		catch (DecryptionException ex) {
-			throw createAuthenticationException(Saml2ErrorCodes.DECRYPTION_ERROR, ex.getMessage(), ex);
-		}
+		this.principalDecrypter.accept(responseToken);
+		return assertion.getSubject().getNameID();
 	}
 	}
 
 
 	private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {
 	private static Map<String, List<Object>> getAssertionAttributes(Assertion assertion) {

+ 54 - 4
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationProviderTests.java

@@ -56,6 +56,7 @@ import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
 import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
 import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
+import org.springframework.security.saml2.provider.service.authentication.OpenSamlAuthenticationProvider.ResponseToken;
 import org.springframework.util.StringUtils;
 import org.springframework.util.StringUtils;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
@@ -446,8 +447,7 @@ public class OpenSamlAuthenticationProviderTests {
 	public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
 	public void createDefaultResponseAuthenticationConverterWhenResponseThenConverts() {
 		Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
 		Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
 		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
-		OpenSamlAuthenticationProvider.ResponseToken responseToken = new OpenSamlAuthenticationProvider.ResponseToken(
-				response, token);
+		ResponseToken responseToken = new ResponseToken(response, token);
 		Saml2Authentication authentication = OpenSamlAuthenticationProvider
 		Saml2Authentication authentication = OpenSamlAuthenticationProvider
 				.createDefaultResponseAuthenticationConverter().convert(responseToken);
 				.createDefaultResponseAuthenticationConverter().convert(responseToken);
 		assertThat(authentication.getName()).isEqualTo("test@saml.user");
 		assertThat(authentication.getName()).isEqualTo("test@saml.user");
@@ -455,8 +455,7 @@ public class OpenSamlAuthenticationProviderTests {
 
 
 	@Test
 	@Test
 	public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
 	public void authenticateWhenResponseAuthenticationConverterConfiguredThenUses() {
-		Converter<OpenSamlAuthenticationProvider.ResponseToken, Saml2Authentication> authenticationConverter = mock(
-				Converter.class);
+		Converter<ResponseToken, Saml2Authentication> authenticationConverter = mock(Converter.class);
 		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
 		OpenSamlAuthenticationProvider provider = new OpenSamlAuthenticationProvider();
 		provider.setResponseAuthenticationConverter(authenticationConverter);
 		provider.setResponseAuthenticationConverter(authenticationConverter);
 		Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
 		Response response = TestOpenSamlObjects.signedResponseWithOneAssertion();
@@ -473,6 +472,57 @@ public class OpenSamlAuthenticationProviderTests {
 		// @formatter:on
 		// @formatter:on
 	}
 	}
 
 
+	@Test
+	public void setAssertionDecrypterWhenNullThenIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setAssertionDecrypter(null));
+	}
+
+	@Test
+	public void setPrincipalDecrypterWhenNullThenIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setPrincipalDecrypter(null));
+	}
+
+	@Test
+	public void setAssertionDecrypterThenChangesAssertion() {
+		Response response = TestOpenSamlObjects.response();
+		Assertion assertion = TestOpenSamlObjects.assertion();
+		assertion.getSubject().getSubjectConfirmations()
+				.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
+		TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
+		response.getAssertions().add(assertion);
+		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
+		this.provider.setAssertionDecrypter(mockAssertionAndPrincipalDecrypter());
+		assertThatExceptionOfType(Saml2AuthenticationException.class)
+				.isThrownBy(() -> this.provider.authenticate(token))
+				.satisfies(errorOf(Saml2ErrorCodes.INVALID_SIGNATURE));
+		assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
+	}
+
+	@Test
+	public void setPrincipalDecrypterThenChangesAssertion() {
+		Response response = TestOpenSamlObjects.response();
+		Assertion assertion = TestOpenSamlObjects.assertion();
+		assertion.getSubject().getSubjectConfirmations()
+				.forEach((sc) -> sc.getSubjectConfirmationData().setAddress("10.10.10.10"));
+		TestOpenSamlObjects.signed(assertion, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				RELYING_PARTY_ENTITY_ID);
+		response.getAssertions().add(assertion);
+		Saml2AuthenticationToken token = token(response, TestSaml2X509Credentials.relyingPartyVerifyingCredential());
+		this.provider.setPrincipalDecrypter(mockAssertionAndPrincipalDecrypter());
+		this.provider.authenticate(token);
+		assertThat(response.getAssertions().get(0).equals(TestOpenSamlObjects.assertion("1", "2", "3", "4")));
+	}
+
+	private Consumer<ResponseToken> mockAssertionAndPrincipalDecrypter() {
+		return (responseToken) -> {
+			responseToken.getResponse().getAssertions().clear();
+			responseToken.getResponse().getAssertions()
+					.add(TestOpenSamlObjects.signed(TestOpenSamlObjects.assertion("1", "2", "3", "4"),
+							TestSaml2X509Credentials.assertingPartySigningCredential(), RELYING_PARTY_ENTITY_ID));
+		};
+	}
+
 	private <T extends XMLObject> T build(QName qName) {
 	private <T extends XMLObject> T build(QName qName) {
 		return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
 		return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
 	}
 	}