2
0
Эх сурвалжийг харах

Allow Defining Custom SAML Response Validator

Add a setter method into OpenSaml4AuthenticationProvider that allows defining a custom ResponseValidator

Closes gh-9721
Marcus Hert da Coregio 4 жил өмнө
parent
commit
03ded987af

+ 22 - 1
docs/manual/src/docs/asciidoc/_includes/servlet/saml2/saml2-login.adoc

@@ -1271,8 +1271,29 @@ It's not required to call `OpenSaml4AuthenticationProvider` 's default authentic
 It returns a `Saml2AuthenticatedPrincipal` containing the attributes it extracted from `AttributeStatement` s as well as the single `ROLE_USER` authority.
 
 [[servlet-saml2login-opensamlauthenticationprovider-additionalvalidation]]
-==== Performing Additional Validation
+==== Performing Additional Response Validation
 
+`OpenSaml4AuthenticationProvider` validates the `Issuer` and `Destination` values right after decrypting the `Response`.
+You can customize the validation by extending the default validator concatenating with your own response validator, or you can replace it entirely with yours.
+
+For example, you can throw a custom exception with any additional information available in the `Response` object, like so:
+[source,java]
+----
+OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
+provider.setResponseValidator((responseToken) -> {
+	Saml2ResponseValidatorResult result = OpenSamlAuthenticationProvider
+		.createDefaultResponseValidator()
+		.convert(responseToken)
+		.concat(myCustomValidator.convert(responseToken));
+	if (!result.getErrors().isEmpty()) {
+		String inResponseTo = responseToken.getInResponseTo();
+		throw new CustomSaml2AuthenticationException(result, inResponseTo);
+	}
+	return result;
+});
+----
+
+==== Performing Additional Assertion Validation
 `OpenSaml4AuthenticationProvider` performs minimal validation on SAML 2.0 Assertions.
 After verifying the signature, it will:
 

+ 62 - 35
saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

@@ -145,7 +145,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 
 	private Consumer<ResponseToken> responseElementsDecrypter = createDefaultResponseElementsDecrypter();
 
-	private final Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
+	private Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator = createDefaultResponseValidator();
 
 	private final Converter<AssertionToken, Saml2ResponseValidatorResult> assertionSignatureValidator = createDefaultAssertionSignatureValidator();
 
@@ -213,6 +213,28 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		this.responseElementsDecrypter = responseElementsDecrypter;
 	}
 
+	/**
+	 * Set the {@link Converter} to use for validating the SAML 2.0 Response.
+	 *
+	 * You can still invoke the default validator by delegating to
+	 * {@link #createDefaultResponseValidator()}, like so:
+	 *
+	 * <pre>
+	 * OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
+	 * provider.setResponseValidator(responseToken -&gt; {
+	 * 		Saml2ResponseValidatorResult result = createDefaultResponseValidator()
+	 * 			.convert(responseToken)
+	 * 		return result.concat(myCustomValidator.convert(responseToken));
+	 * });
+	 * </pre>
+	 * @param responseValidator the {@link Converter} to use
+	 * @since 5.6
+	 */
+	public void setResponseValidator(Converter<ResponseToken, Saml2ResponseValidatorResult> responseValidator) {
+		Assert.notNull(responseValidator, "responseValidator cannot be null");
+		this.responseValidator = responseValidator;
+	}
+
 	/**
 	 * Set the {@link Converter} to use for validating each {@link Assertion} in the SAML
 	 * 2.0 Response.
@@ -326,6 +348,44 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		this.responseAuthenticationConverter = responseAuthenticationConverter;
 	}
 
+	/**
+	 * Construct a default strategy for validating the SAML 2.0 Response
+	 * @return the default response validator strategy
+	 * @since 5.6
+	 */
+	public static Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
+		return (responseToken) -> {
+			Response response = responseToken.getResponse();
+			Saml2AuthenticationToken token = responseToken.getToken();
+			Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
+			String statusCode = getStatusCode(response);
+			if (!StatusCode.SUCCESS.equals(statusCode)) {
+				String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
+						response.getID());
+				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
+			}
+			String issuer = response.getIssuer().getValue();
+			String destination = response.getDestination();
+			String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
+			if (StringUtils.hasText(destination) && !destination.equals(location)) {
+				String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
+						+ "]";
+				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
+			}
+			String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails()
+					.getEntityId();
+			if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
+				String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
+				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
+			}
+			if (response.getAssertions().isEmpty()) {
+				throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
+						"No assertions found in response.", null);
+			}
+			return result;
+		};
+	}
+
 	/**
 	 * Construct a default strategy for validating each SAML 2.0 Assertion and associated
 	 * {@link Authentication} token
@@ -487,40 +547,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		};
 	}
 
-	private Converter<ResponseToken, Saml2ResponseValidatorResult> createDefaultResponseValidator() {
-		return (responseToken) -> {
-			Response response = responseToken.getResponse();
-			Saml2AuthenticationToken token = responseToken.getToken();
-			Saml2ResponseValidatorResult result = Saml2ResponseValidatorResult.success();
-			String statusCode = getStatusCode(response);
-			if (!StatusCode.SUCCESS.equals(statusCode)) {
-				String message = String.format("Invalid status [%s] for SAML response [%s]", statusCode,
-						response.getID());
-				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_RESPONSE, message));
-			}
-			String issuer = response.getIssuer().getValue();
-			String destination = response.getDestination();
-			String location = token.getRelyingPartyRegistration().getAssertionConsumerServiceLocation();
-			if (StringUtils.hasText(destination) && !destination.equals(location)) {
-				String message = "Invalid destination [" + destination + "] for SAML response [" + response.getID()
-						+ "]";
-				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_DESTINATION, message));
-			}
-			String assertingPartyEntityId = token.getRelyingPartyRegistration().getAssertingPartyDetails()
-					.getEntityId();
-			if (!StringUtils.hasText(issuer) || !issuer.equals(assertingPartyEntityId)) {
-				String message = String.format("Invalid issuer [%s] for SAML response [%s]", issuer, response.getID());
-				result = result.concat(new Saml2Error(Saml2ErrorCodes.INVALID_ISSUER, message));
-			}
-			if (response.getAssertions().isEmpty()) {
-				throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_RESPONSE_DATA,
-						"No assertions found in response.", null);
-			}
-			return result;
-		};
-	}
-
-	private String getStatusCode(Response response) {
+	private static String getStatusCode(Response response) {
 		if (response.getStatus() == null) {
 			return StatusCode.SUCCESS;
 		}

+ 28 - 0
saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java

@@ -585,6 +585,34 @@ public class OpenSaml4AuthenticationProviderTests {
 		assertThat(authentication.getName()).isEqualTo("test@saml.user");
 	}
 
+	@Test
+	public void setResponseValidatorWhenNullThenIllegalArgument() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.provider.setResponseValidator(null));
+	}
+
+	@Test
+	public void authenticateWhenCustomResponseValidatorThenUses() {
+		Converter<OpenSaml4AuthenticationProvider.ResponseToken, Saml2ResponseValidatorResult> validator = mock(
+				Converter.class);
+		OpenSaml4AuthenticationProvider provider = new OpenSaml4AuthenticationProvider();
+		// @formatter:off
+		provider.setResponseValidator((responseToken) -> OpenSaml4AuthenticationProvider.createDefaultResponseValidator()
+				.convert(responseToken)
+				.concat(validator.convert(responseToken))
+		);
+		// @formatter:on
+		Response response = response();
+		Assertion assertion = assertion();
+		response.getAssertions().add(assertion);
+		TestOpenSamlObjects.signed(response, TestSaml2X509Credentials.assertingPartySigningCredential(),
+				ASSERTING_PARTY_ENTITY_ID);
+		Saml2AuthenticationToken token = token(response, verifying(registration()));
+		given(validator.convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class)))
+				.willReturn(Saml2ResponseValidatorResult.success());
+		provider.authenticate(token);
+		verify(validator).convert(any(OpenSaml4AuthenticationProvider.ResponseToken.class));
+	}
+
 	private <T extends XMLObject> T build(QName qName) {
 		return (T) XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(qName).buildObject(qName);
 	}