Переглянути джерело

Merge branch '6.1.x'

Closes gh-14039
Marcus Da Coregio 1 рік тому
батько
коміт
10c85ccd29

+ 6 - 4
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java

@@ -176,10 +176,12 @@ class OpenSamlAuthenticationRequestResolver {
 				.id(authnRequest.getID());
 			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()
 					|| registration.isAuthnRequestsSigned()) {
-				Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
-					.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)
-					.param(Saml2ParameterNames.RELAY_STATE, relayState)
-					.parameters();
+				OpenSamlSigningUtils.QueryParametersPartial parametersPartial = OpenSamlSigningUtils.sign(registration)
+					.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded);
+				if (relayState != null) {
+					parametersPartial = parametersPartial.param(Saml2ParameterNames.RELAY_STATE, relayState);
+				}
+				Map<String, String> parameters = parametersPartial.parameters();
 				builder.sigAlg(parameters.get(Saml2ParameterNames.SIG_ALG))
 					.signature(parameters.get(Saml2ParameterNames.SIGNATURE));
 			}

+ 61 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java

@@ -23,10 +23,13 @@ import org.junit.jupiter.api.Test;
 import org.junit.jupiter.params.ParameterizedTest;
 import org.junit.jupiter.params.provider.Arguments;
 import org.junit.jupiter.params.provider.MethodSource;
+import org.mockito.Answers;
+import org.mockito.MockedStatic;
 import org.opensaml.xmlsec.signature.support.SignatureConstants;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.saml2.Saml2Exception;
+import org.springframework.security.saml2.core.Saml2ParameterNames;
 import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.security.saml2.core.TestSaml2X509Credentials;
 import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
@@ -39,6 +42,12 @@ import org.springframework.security.saml2.provider.service.web.RelyingPartyRegis
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mockStatic;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 
 /**
  * Tests for {@link OpenSamlAuthenticationRequestResolver}
@@ -198,6 +207,58 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		assertThat(result.getId()).isNotEmpty();
 	}
 
+	@Test
+	public void resolveAuthenticationRequestWhenSignedAndRelayStateIsNullThenSignsWithoutRelayState() {
+		try (MockedStatic<OpenSamlSigningUtils> openSamlSigningUtilsMockedStatic = mockStatic(
+				OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) {
+			MockHttpServletRequest request = new MockHttpServletRequest();
+			request.setPathInfo("/saml2/authenticate/registration-id");
+			RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
+				.assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true))
+				.build();
+			OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy(
+					new OpenSamlSigningUtils.QueryParametersPartial(registration));
+			openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any()))
+				.thenReturn(queryParametersPartialSpy);
+			OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+			resolver.setRelayStateResolver((source) -> null);
+			Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
+			});
+			assertThat(result.getSamlRequest()).isNotEmpty();
+			assertThat(result.getRelayState()).isNull();
+			assertThat(result.getSigAlg()).isNotNull();
+			assertThat(result.getSignature()).isNotNull();
+			assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
+			verify(queryParametersPartialSpy, never()).param(eq(Saml2ParameterNames.RELAY_STATE), any());
+		}
+	}
+
+	@Test
+	public void resolveAuthenticationRequestWhenSignedAndRelayStateIsEmptyThenSignsWithEmptyRelayState() {
+		try (MockedStatic<OpenSamlSigningUtils> openSamlSigningUtilsMockedStatic = mockStatic(
+				OpenSamlSigningUtils.class, Answers.CALLS_REAL_METHODS)) {
+			MockHttpServletRequest request = new MockHttpServletRequest();
+			request.setPathInfo("/saml2/authenticate/registration-id");
+			RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
+				.assertingPartyDetails((party) -> party.wantAuthnRequestsSigned(true))
+				.build();
+			OpenSamlSigningUtils.QueryParametersPartial queryParametersPartialSpy = spy(
+					new OpenSamlSigningUtils.QueryParametersPartial(registration));
+			openSamlSigningUtilsMockedStatic.when(() -> OpenSamlSigningUtils.sign(any()))
+				.thenReturn(queryParametersPartialSpy);
+			OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
+			resolver.setRelayStateResolver((source) -> "");
+			Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
+			});
+			assertThat(result.getSamlRequest()).isNotEmpty();
+			assertThat(result.getRelayState()).isEmpty();
+			assertThat(result.getSigAlg()).isNotNull();
+			assertThat(result.getSignature()).isNotNull();
+			assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
+			verify(queryParametersPartialSpy).param(eq(Saml2ParameterNames.RELAY_STATE), eq(""));
+		}
+	}
+
 	private OpenSamlAuthenticationRequestResolver authenticationRequestResolver(RelyingPartyRegistration registration) {
 		return new OpenSamlAuthenticationRequestResolver((request, id) -> registration);
 	}