Browse Source

Polish InResponseTo support

- Moved methods so methods are listed before the methods they call
- Adjusted exception handling so no exceptions are eaten
- Adjusted so that malformed_request_data is returned with request data is malformed
- Refactored methods to have only immutable method parameters
- Removed usage of Stream API
- Moved AuthnRequestUnmarshaller into static block so that only looked
up once

Issue gh-9174
Josh Cummings 3 years ago
parent
commit
070514b9dd

+ 82 - 80
saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

@@ -65,7 +65,6 @@ import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.OneTimeUse;
 import org.opensaml.saml.saml2.core.Response;
 import org.opensaml.saml.saml2.core.StatusCode;
-import org.opensaml.saml.saml2.core.Subject;
 import org.opensaml.saml.saml2.core.SubjectConfirmation;
 import org.opensaml.saml.saml2.core.SubjectConfirmationData;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
@@ -146,6 +145,13 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 
 	private final ResponseUnmarshaller responseUnmarshaller;
 
+	private static final AuthnRequestUnmarshaller authnRequestUnmarshaller;
+	static {
+		XMLObjectProviderRegistry registry = ConfigurationService.get(XMLObjectProviderRegistry.class);
+		authnRequestUnmarshaller = (AuthnRequestUnmarshaller) registry.getUnmarshallerFactory()
+				.getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
+	}
+
 	private final ParserPool parserPool;
 
 	private final Converter<ResponseToken, Saml2ResponseValidatorResult> responseSignatureValidator = createDefaultResponseSignatureValidator();
@@ -355,37 +361,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		this.responseAuthenticationConverter = responseAuthenticationConverter;
 	}
 
-	private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
-			String inResponseTo) {
-		if (!StringUtils.hasText(inResponseTo)) {
-			return Saml2ResponseValidatorResult.success();
-		}
-		AuthnRequest request;
-		try {
-			request = parseRequest(storedRequest);
-		}
-		catch (Exception ex) {
-			String message = "The stored AuthNRequest could not be properly deserialized [" + ex.getMessage() + "]";
-			return Saml2ResponseValidatorResult
-					.failure(new Saml2Error(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message));
-		}
-		if (request == null) {
-			String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
-					+ " but no saved AuthNRequest request was found";
-			return Saml2ResponseValidatorResult
-					.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
-		}
-		else if (!request.getID().equals(inResponseTo)) {
-			String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
-					+ "AuthNRequest [" + request.getID() + "]";
-			return Saml2ResponseValidatorResult
-					.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
-		}
-		else {
-			return Saml2ResponseValidatorResult.success();
-		}
-	}
-
 	/**
 	 * Construct a default strategy for validating the SAML 2.0 Response
 	 * @return the default response validator strategy
@@ -428,6 +403,27 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		};
 	}
 
+	private static Saml2ResponseValidatorResult validateInResponseTo(AbstractSaml2AuthenticationRequest storedRequest,
+			String inResponseTo) {
+		if (!StringUtils.hasText(inResponseTo)) {
+			return Saml2ResponseValidatorResult.success();
+		}
+		AuthnRequest request = parseRequest(storedRequest);
+		if (request == null) {
+			String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
+					+ " but no saved authentication request was found";
+			return Saml2ResponseValidatorResult
+					.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
+		}
+		if (!inResponseTo.equals(request.getID())) {
+			String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
+					+ "authentication request [" + request.getID() + "]";
+			return Saml2ResponseValidatorResult
+					.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
+		}
+		return Saml2ResponseValidatorResult.success();
+	}
+
 	/**
 	 * Construct a default strategy for validating each SAML 2.0 Assertion and associated
 	 * {@link Authentication} token
@@ -522,28 +518,6 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		}
 	}
 
-	private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) throws Exception {
-		if (request == null) {
-			return null;
-		}
-		String samlRequest = request.getSamlRequest();
-		if (!StringUtils.hasText(samlRequest)) {
-			return null;
-		}
-		if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
-			samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
-		}
-		else {
-			samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
-		}
-		Document document = XMLObjectProviderRegistrySupport.getParserPool()
-				.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
-		Element element = document.getDocumentElement();
-		AuthnRequestUnmarshaller unmarshaller = (AuthnRequestUnmarshaller) XMLObjectProviderRegistrySupport
-				.getUnmarshallerFactory().getUnmarshaller(AuthnRequest.DEFAULT_ELEMENT_NAME);
-		return (AuthnRequest) unmarshaller.unmarshall(element);
-	}
-
 	private void process(Saml2AuthenticationToken token, Response response) {
 		String issuer = response.getIssuer().getValue();
 		this.logger.debug(LogMessage.format("Processing SAML response from %s", issuer));
@@ -748,40 +722,18 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		};
 	}
 
-	private static boolean assertionContainsInResponseTo(Assertion assertion) {
-		Subject subject = (assertion != null) ? assertion.getSubject() : null;
-		List<SubjectConfirmation> confirmations = (subject != null) ? subject.getSubjectConfirmations()
-				: new ArrayList<>();
-		return confirmations.stream().filter((confirmation) -> {
-			SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData();
-			return confirmationData != null && StringUtils.hasText(confirmationData.getInResponseTo());
-		}).findFirst().orElse(null) != null;
-	}
-
-	private static void addRequestIdToValidationContext(AbstractSaml2AuthenticationRequest storedRequest,
-			Map<String, Object> context) {
-		String requestId = null;
-		try {
-			AuthnRequest request = parseRequest(storedRequest);
-			requestId = (request != null) ? request.getID() : null;
-		}
-		catch (Exception ex) {
-		}
-		if (StringUtils.hasText(requestId)) {
-			context.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId);
-		}
-	}
-
 	private static ValidationContext createValidationContext(AssertionToken assertionToken,
 			Consumer<Map<String, Object>> paramsConsumer) {
-		RelyingPartyRegistration relyingPartyRegistration = assertionToken.token.getRelyingPartyRegistration();
+		Saml2AuthenticationToken token = assertionToken.token;
+		RelyingPartyRegistration relyingPartyRegistration = token.getRelyingPartyRegistration();
 		String audience = relyingPartyRegistration.getEntityId();
 		String recipient = relyingPartyRegistration.getAssertionConsumerServiceLocation();
 		String assertingPartyEntityId = relyingPartyRegistration.getAssertingPartyDetails().getEntityId();
 		Map<String, Object> params = new HashMap<>();
 		Assertion assertion = assertionToken.getAssertion();
 		if (assertionContainsInResponseTo(assertion)) {
-			addRequestIdToValidationContext(assertionToken.token.getAuthenticationRequest(), params);
+			String requestId = getAuthnRequestId(token.getAuthenticationRequest());
+			params.put(SAML2AssertionValidationParameters.SC_VALID_IN_RESPONSE_TO, requestId);
 		}
 		params.put(SAML2AssertionValidationParameters.COND_VALID_AUDIENCES, Collections.singleton(audience));
 		params.put(SAML2AssertionValidationParameters.SC_VALID_RECIPIENTS, Collections.singleton(recipient));
@@ -790,6 +742,56 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		return new ValidationContext(params);
 	}
 
+	private static boolean assertionContainsInResponseTo(Assertion assertion) {
+		if (assertion.getSubject() == null) {
+			return false;
+		}
+		for (SubjectConfirmation confirmation : assertion.getSubject().getSubjectConfirmations()) {
+			SubjectConfirmationData confirmationData = confirmation.getSubjectConfirmationData();
+			if (confirmationData == null) {
+				continue;
+			}
+			if (StringUtils.hasText(confirmationData.getInResponseTo())) {
+				return true;
+			}
+		}
+		return false;
+	}
+
+	private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
+		AuthnRequest request = parseRequest(serialized);
+		if (request == null) {
+			return null;
+		}
+		return request.getID();
+	}
+
+	private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) {
+		if (request == null) {
+			return null;
+		}
+		String samlRequest = request.getSamlRequest();
+		if (!StringUtils.hasText(samlRequest)) {
+			return null;
+		}
+		if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
+			samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
+		}
+		else {
+			samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
+		}
+		try {
+			Document document = XMLObjectProviderRegistrySupport.getParserPool()
+					.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
+			Element element = document.getDocumentElement();
+			return (AuthnRequest) authnRequestUnmarshaller.unmarshall(element);
+		}
+		catch (Exception ex) {
+			String message = "Failed to deserialize associated authentication request [" + ex.getMessage() + "]";
+			throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message, ex);
+		}
+	}
+
 	private static class SAML20AssertionValidators {
 
 		private static final Collection<ConditionValidator> conditions = new ArrayList<>();

+ 1 - 1
saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java

@@ -252,7 +252,7 @@ public class OpenSaml4AuthenticationProviderTests {
 				Saml2MessageBinding.POST, true);
 		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
 		assertThatExceptionOfType(Saml2AuthenticationException.class)
-				.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
+				.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
 	}
 
 	@Test