Przeglądaj źródła

Add support for AuthnRequestsSigned setting

closes gh-12604
Liviu Gheorghe 2 lat temu
rodzic
commit
21d919169a

+ 29 - 3
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java

@@ -86,6 +86,8 @@ public class RelyingPartyRegistration {
 
 	private final String nameIdFormat;
 
+	private final boolean authnRequestsSigned;
+
 	private final AssertingPartyDetails assertingPartyDetails;
 
 	private final Collection<Saml2X509Credential> decryptionX509Credentials;
@@ -95,7 +97,7 @@ public class RelyingPartyRegistration {
 	protected RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
 			Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
 			String singleLogoutServiceResponseLocation, Collection<Saml2MessageBinding> singleLogoutServiceBindings,
-			AssertingPartyDetails assertingPartyDetails, String nameIdFormat,
+			AssertingPartyDetails assertingPartyDetails, String nameIdFormat, boolean authnRequestsSigned,
 			Collection<Saml2X509Credential> decryptionX509Credentials,
 			Collection<Saml2X509Credential> signingX509Credentials) {
 		Assert.hasText(registrationId, "registrationId cannot be empty");
@@ -124,6 +126,7 @@ public class RelyingPartyRegistration {
 		this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
 		this.singleLogoutServiceBindings = Collections.unmodifiableList(new LinkedList<>(singleLogoutServiceBindings));
 		this.nameIdFormat = nameIdFormat;
+		this.authnRequestsSigned = authnRequestsSigned;
 		this.assertingPartyDetails = assertingPartyDetails;
 		this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials));
 		this.signingX509Credentials = Collections.unmodifiableList(new LinkedList<>(signingX509Credentials));
@@ -281,6 +284,15 @@ public class RelyingPartyRegistration {
 		return this.nameIdFormat;
 	}
 
+	/**
+	 * Get the WantAuthnRequestsSigned setting
+	 * @return the WantAuthnRequestsSigned setting
+	 * @since 6.0
+	 */
+	public boolean isAuthnRequestsSigned() {
+		return authnRequestsSigned;
+	}
+
 	/**
 	 * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated
 	 * with this relying party
@@ -357,6 +369,7 @@ public class RelyingPartyRegistration {
 				.singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation())
 				.singleLogoutServiceBindings((c) -> c.addAll(registration.getSingleLogoutServiceBindings()))
 				.nameIdFormat(registration.getNameIdFormat())
+				.authnRequestsSigned(registration.isAuthnRequestsSigned())
 				.assertingPartyDetails((assertingParty) -> assertingParty
 						.entityId(registration.getAssertingPartyDetails().getEntityId())
 						.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@@ -788,6 +801,8 @@ public class RelyingPartyRegistration {
 
 		private String nameIdFormat = null;
 
+		private boolean authnRequestsSigned = false;
+
 		private AssertingPartyDetails.Builder assertingPartyDetailsBuilder;
 
 		protected Builder(String registrationId, AssertingPartyDetails.Builder assertingPartyDetailsBuilder) {
@@ -974,6 +989,17 @@ public class RelyingPartyRegistration {
 			return this;
 		}
 
+		/**
+		 * Set the AuthnRequestsSigned setting
+		 * @param authnRequestsSigned
+		 * @return the {@link Builder} for further configuration
+		 * @since 6.0
+		 */
+		public Builder authnRequestsSigned(Boolean authnRequestsSigned) {
+			this.authnRequestsSigned = authnRequestsSigned;
+			return this;
+		}
+
 		/**
 		 * Apply this {@link Consumer} to further configure the Asserting Party details
 		 * @param assertingPartyDetails The {@link Consumer} to apply
@@ -1003,8 +1029,8 @@ public class RelyingPartyRegistration {
 			return new RelyingPartyRegistration(this.registrationId, this.entityId,
 					this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
 					this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
-					this.singleLogoutServiceBindings, party, this.nameIdFormat, this.decryptionX509Credentials,
-					this.signingX509Credentials);
+					this.singleLogoutServiceBindings, party, this.nameIdFormat, this.authnRequestsSigned,
+					this.decryptionX509Credentials, this.signingX509Credentials);
 		}
 
 	}

+ 3 - 3
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -142,7 +142,7 @@ class OpenSamlAuthenticationRequestResolver {
 		String relayState = this.relayStateResolver.convert(request);
 		Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding();
 		if (binding == Saml2MessageBinding.POST) {
-			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
+			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned() || registration.isAuthnRequestsSigned()) {
 				OpenSamlSigningUtils.sign(authnRequest, registration);
 			}
 			String xml = serialize(authnRequest);
@@ -156,7 +156,7 @@ class OpenSamlAuthenticationRequestResolver {
 			Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
 					.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState)
 					.id(authnRequest.getID());
-			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
+			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned() || registration.isAuthnRequestsSigned()) {
 				Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
 						.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)
 						.param(Saml2ParameterNames.RELAY_STATE, relayState).parameters();

+ 3 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@ public class RelyingPartyRegistrationTests {
 	public void withRelyingPartyRegistrationWorks() {
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
 				.nameIdFormat("format")
+				.authnRequestsSigned(true)
 				.assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST))
 				.assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false))
 				.assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg")))
@@ -82,6 +83,7 @@ public class RelyingPartyRegistrationTests {
 		assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms())
 				.isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms());
 		assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat());
+		assertThat(copy.isAuthnRequestsSigned()).isEqualTo(registration.isAuthnRequestsSigned());
 	}
 
 	@Test

+ 31 - 9
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,9 @@ package org.springframework.security.saml2.provider.service.web.authentication;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 import org.opensaml.xmlsec.signature.support.SignatureConstants;
 
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -32,6 +35,8 @@ import org.springframework.security.saml2.provider.service.registration.TestRely
 import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers;
 import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
 
+import java.util.stream.Stream;
+
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 
@@ -47,11 +52,15 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		this.relyingPartyRegistrationBuilder = TestRelyingPartyRegistrations.relyingPartyRegistration();
 	}
 
-	@Test
-	public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects() {
+	@ParameterizedTest
+	@MethodSource("provideSignRequestFlags")
+	public void resolveAuthenticationRequestWhenSignedRedirectThenSignsAndRedirects(boolean wantAuthRequestsSigned, boolean authnRequestsSigned) {
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		request.setPathInfo("/saml2/authenticate/registration-id");
-		RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.build();
+		RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
+				.authnRequestsSigned(authnRequestsSigned)
+				.assertingPartyDetails(party -> party.wantAuthnRequestsSigned(wantAuthRequestsSigned))
+				.build();
 		OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
 		Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
 			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
@@ -113,8 +122,9 @@ public class OpenSamlAuthenticationRequestResolverTests {
 	public void resolveAuthenticationRequestWhenUnsignedPostThenOnlyPosts() {
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		request.setPathInfo("/saml2/authenticate/registration-id");
-		RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder.assertingPartyDetails(
-				(party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false))
+		RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
+				.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST).wantAuthnRequestsSigned(false))
+				.authnRequestsSigned(false)
 				.build();
 		OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
 		Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
@@ -134,12 +144,16 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		assertThat(result.getId()).isNotEmpty();
 	}
 
-	@Test
-	public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts() {
+	@ParameterizedTest
+	@MethodSource("provideSignRequestFlags")
+	public void resolveAuthenticationRequestWhenSignedPostThenSignsAndPosts(boolean wantAuthRequestsSigned, boolean authnRequestsSigned) {
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		request.setPathInfo("/saml2/authenticate/registration-id");
 		RelyingPartyRegistration registration = this.relyingPartyRegistrationBuilder
-				.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)).build();
+				.authnRequestsSigned(authnRequestsSigned)
+				.assertingPartyDetails((party) -> party.singleSignOnServiceBinding(Saml2MessageBinding.POST)
+						.wantAuthnRequestsSigned(wantAuthRequestsSigned))
+				.build();
 		OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
 		Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
 			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
@@ -180,4 +194,12 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		return new OpenSamlAuthenticationRequestResolver((request, id) -> registration);
 	}
 
+	private static Stream<Arguments> provideSignRequestFlags() {
+		return Stream.of(
+				Arguments.of(true, true),
+				Arguments.of(true, false),
+				Arguments.of(false, true)
+		);
+	}
+
 }