Procházet zdrojové kódy

Don't use raw xml saml authentication request for response validation

closes gh-12961
Liviu Gheorghe před 2 roky
rodič
revize
7e305dd003

+ 5 - 38
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProvider.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -37,7 +37,6 @@ import org.apache.commons.logging.LogFactory;
 import org.opensaml.core.config.ConfigurationService;
 import org.opensaml.core.xml.XMLObject;
 import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
-import org.opensaml.core.xml.config.XMLObjectProviderRegistrySupport;
 import org.opensaml.core.xml.schema.XSAny;
 import org.opensaml.core.xml.schema.XSBoolean;
 import org.opensaml.core.xml.schema.XSBooleanValue;
@@ -89,7 +88,6 @@ import org.springframework.security.saml2.core.Saml2Error;
 import org.springframework.security.saml2.core.Saml2ErrorCodes;
 import org.springframework.security.saml2.core.Saml2ResponseValidatorResult;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
-import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
 import org.springframework.util.LinkedMultiValueMap;
@@ -410,16 +408,15 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 		if (!StringUtils.hasText(inResponseTo)) {
 			return Saml2ResponseValidatorResult.success();
 		}
-		AuthnRequest request = parseRequest(storedRequest);
-		if (request == null) {
+		if (storedRequest == null) {
 			String message = "The response contained an InResponseTo attribute [" + inResponseTo + "]"
 					+ " but no saved authentication request was found";
 			return Saml2ResponseValidatorResult
 					.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
 		}
-		if (!inResponseTo.equals(request.getID())) {
+		if (!inResponseTo.equals(storedRequest.getId())) {
 			String message = "The InResponseTo attribute [" + inResponseTo + "] does not match the ID of the "
-					+ "authentication request [" + request.getID() + "]";
+					+ "authentication request [" + storedRequest.getId() + "]";
 			return Saml2ResponseValidatorResult
 					.failure(new Saml2Error(Saml2ErrorCodes.INVALID_IN_RESPONSE_TO, message));
 		}
@@ -776,37 +773,7 @@ public final class OpenSaml4AuthenticationProvider implements AuthenticationProv
 	}
 
 	private static String getAuthnRequestId(AbstractSaml2AuthenticationRequest serialized) {
-		AuthnRequest request = parseRequest(serialized);
-		if (request == null) {
-			return null;
-		}
-		return request.getID();
-	}
-
-	private static AuthnRequest parseRequest(AbstractSaml2AuthenticationRequest request) {
-		if (request == null) {
-			return null;
-		}
-		String samlRequest = request.getSamlRequest();
-		if (!StringUtils.hasText(samlRequest)) {
-			return null;
-		}
-		if (request.getBinding() == Saml2MessageBinding.REDIRECT) {
-			samlRequest = Saml2Utils.samlInflate(Saml2Utils.samlDecode(samlRequest));
-		}
-		else {
-			samlRequest = new String(Saml2Utils.samlDecode(samlRequest), StandardCharsets.UTF_8);
-		}
-		try {
-			Document document = XMLObjectProviderRegistrySupport.getParserPool()
-					.parse(new ByteArrayInputStream(samlRequest.getBytes(StandardCharsets.UTF_8)));
-			Element element = document.getDocumentElement();
-			return (AuthnRequest) authnRequestUnmarshaller.unmarshall(element);
-		}
-		catch (Exception ex) {
-			String message = "Failed to deserialize associated authentication request [" + ex.getMessage() + "]";
-			throw createAuthenticationException(Saml2ErrorCodes.MALFORMED_REQUEST_DATA, message, ex);
-		}
+		return (serialized != null) ? serialized.getId() : null;
 	}
 
 	private static class SAML20AssertionValidators {

+ 9 - 64
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationProviderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,7 +19,6 @@ package org.springframework.security.saml2.provider.service.authentication;
 import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.ObjectOutputStream;
-import java.nio.charset.StandardCharsets;
 import java.time.Duration;
 import java.time.Instant;
 import java.util.Arrays;
@@ -48,7 +47,6 @@ import org.opensaml.saml.saml2.core.Assertion;
 import org.opensaml.saml.saml2.core.Attribute;
 import org.opensaml.saml.saml2.core.AttributeStatement;
 import org.opensaml.saml.saml2.core.AttributeValue;
-import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.Conditions;
 import org.opensaml.saml.saml2.core.EncryptedAssertion;
 import org.opensaml.saml.saml2.core.EncryptedAttribute;
@@ -78,7 +76,6 @@ import org.springframework.security.saml2.core.TestSaml2X509Credentials;
 import org.springframework.security.saml2.provider.service.authentication.OpenSaml4AuthenticationProvider.ResponseToken;
 import org.springframework.security.saml2.provider.service.authentication.TestCustomOpenSamlObjects.CustomOpenSamlObject;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
-import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
 import org.springframework.util.StringUtils;
 
@@ -228,8 +225,7 @@ public class OpenSaml4AuthenticationProviderTests {
 		response.setInResponseTo("SAML2");
 		response.getAssertions().add(signed(assertion("SAML2")));
 		response.getAssertions().add(signed(assertion("SAML2")));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, false);
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
 		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
 		this.provider.authenticate(token);
 	}
@@ -239,32 +235,18 @@ public class OpenSaml4AuthenticationProviderTests {
 		Response response = response();
 		response.getAssertions().add(signed(assertion()));
 		response.getAssertions().add(signed(assertion("SAML2")));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, false);
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
 		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
 		this.provider.authenticate(token);
 	}
 
-	@Test
-	public void evaluateInResponseToFailsWhenInResponseToInAssertionOnlyAndCorruptedStoredRequest() {
-		Response response = response();
-		response.getAssertions().add(signed(assertion()));
-		response.getAssertions().add(signed(assertion("SAML2")));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, true);
-		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
-		assertThatExceptionOfType(Saml2AuthenticationException.class)
-				.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
-	}
-
 	@Test
 	public void evaluateInResponseToFailsWhenInResponseToInAssertionMismatchWithRequestID() {
 		Response response = response();
 		response.setInResponseTo("SAML2");
 		response.getAssertions().add(signed(assertion("SAML2")));
 		response.getAssertions().add(signed(assertion("BAD")));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, false);
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
 		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
 		assertThatExceptionOfType(Saml2AuthenticationException.class)
 				.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
@@ -275,8 +257,7 @@ public class OpenSaml4AuthenticationProviderTests {
 		Response response = response();
 		response.getAssertions().add(signed(assertion()));
 		response.getAssertions().add(signed(assertion("BAD")));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, false);
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
 		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
 		assertThatExceptionOfType(Saml2AuthenticationException.class)
 				.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_assertion");
@@ -288,26 +269,12 @@ public class OpenSaml4AuthenticationProviderTests {
 		response.setInResponseTo("BAD");
 		response.getAssertions().add(signed(assertion("SAML2")));
 		response.getAssertions().add(signed(assertion("SAML2")));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, false);
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
 		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
 		assertThatExceptionOfType(Saml2AuthenticationException.class)
 				.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("invalid_in_response_to");
 	}
 
-	@Test
-	public void evaluateInResponseToFailsWhenInResponseInToResponseAndCorruptedStoredRequest() {
-		Response response = response();
-		response.setInResponseTo("SAML2");
-		response.getAssertions().add(signed(assertion()));
-		response.getAssertions().add(signed(assertion()));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, true);
-		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
-		assertThatExceptionOfType(Saml2AuthenticationException.class)
-				.isThrownBy(() -> this.provider.authenticate(token)).withStackTraceContaining("malformed_request_data");
-	}
-
 	@Test
 	public void evaluateInResponseToFailsWhenInResponseToInResponseButNoSavedRequest() {
 		Response response = response();
@@ -321,8 +288,7 @@ public class OpenSaml4AuthenticationProviderTests {
 	public void evaluateInResponseToSucceedsWhenNoInResponseToInResponseOrAssertions() {
 		Response response = response();
 		response.getAssertions().add(signed(assertion()));
-		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2",
-				Saml2MessageBinding.POST, false);
+		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mockedStoredAuthenticationRequest("SAML2");
 		Saml2AuthenticationToken token = token(response, verifying(registration()), mockAuthenticationRequest);
 		this.provider.authenticate(token);
 	}
@@ -805,17 +771,6 @@ public class OpenSaml4AuthenticationProviderTests {
 		return response;
 	}
 
-	private AuthnRequest request() {
-		AuthnRequest request = TestOpenSamlObjects.authnRequest();
-		return request;
-	}
-
-	private String serializedRequest(AuthnRequest request, Saml2MessageBinding binding) {
-		String xml = serialize(request);
-		return (binding == Saml2MessageBinding.POST) ? Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8))
-				: Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
-	}
-
 	private Assertion assertion(String inResponseTo) {
 		Assertion assertion = TestOpenSamlObjects.assertion();
 		assertion.setIssueInstant(Instant.now());
@@ -871,19 +826,9 @@ public class OpenSaml4AuthenticationProviderTests {
 		return new Saml2AuthenticationToken(registration.build(), serialize(response), authenticationRequest);
 	}
 
-	private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId,
-			Saml2MessageBinding binding, boolean corruptRequestString) {
-		AuthnRequest request = request();
-		if (requestId != null) {
-			request.setID(requestId);
-		}
-		String serializedRequest = serializedRequest(request, binding);
-		if (corruptRequestString) {
-			serializedRequest = serializedRequest.substring(2, serializedRequest.length() - 2);
-		}
+	private AbstractSaml2AuthenticationRequest mockedStoredAuthenticationRequest(String requestId) {
 		AbstractSaml2AuthenticationRequest mockAuthenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
-		given(mockAuthenticationRequest.getSamlRequest()).willReturn(serializedRequest);
-		given(mockAuthenticationRequest.getBinding()).willReturn(binding);
+		given(mockAuthenticationRequest.getId()).willReturn(requestId);
 		return mockAuthenticationRequest;
 	}