瀏覽代碼

Add RelyingPartyRegstration#mutate

Closes gh-12841
Josh Cummings 2 年之前
父節點
當前提交
538db29bfe

+ 43 - 16
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java

@@ -130,6 +130,35 @@ public final class RelyingPartyRegistration {
 		this.signingX509Credentials = Collections.unmodifiableList(new LinkedList<>(signingX509Credentials));
 	}
 
+	/**
+	 * Copy the properties in this {@link RelyingPartyRegistration} into a {@link Builder}
+	 * @return a {@link Builder} based off of the properties in this
+	 * {@link RelyingPartyRegistration}
+	 * @since 6.1
+	 */
+	public Builder mutate() {
+		AssertingPartyDetails party = this.assertingPartyDetails;
+		return withRegistrationId(this.registrationId).entityId(this.entityId)
+				.signingX509Credentials((c) -> c.addAll(this.signingX509Credentials))
+				.decryptionX509Credentials((c) -> c.addAll(this.decryptionX509Credentials))
+				.assertionConsumerServiceLocation(this.assertionConsumerServiceLocation)
+				.assertionConsumerServiceBinding(this.assertionConsumerServiceBinding)
+				.singleLogoutServiceLocation(this.singleLogoutServiceLocation)
+				.singleLogoutServiceResponseLocation(this.singleLogoutServiceResponseLocation)
+				.singleLogoutServiceBindings((c) -> c.addAll(this.singleLogoutServiceBindings))
+				.nameIdFormat(this.nameIdFormat)
+				.assertingPartyDetails((assertingParty) -> assertingParty.entityId(party.getEntityId())
+						.wantAuthnRequestsSigned(party.getWantAuthnRequestsSigned())
+						.signingAlgorithms((algorithms) -> algorithms.addAll(party.getSigningAlgorithms()))
+						.verificationX509Credentials((c) -> c.addAll(party.getVerificationX509Credentials()))
+						.encryptionX509Credentials((c) -> c.addAll(party.getEncryptionX509Credentials()))
+						.singleSignOnServiceLocation(party.getSingleSignOnServiceLocation())
+						.singleSignOnServiceBinding(party.getSingleSignOnServiceBinding())
+						.singleLogoutServiceLocation(party.getSingleLogoutServiceLocation())
+						.singleLogoutServiceResponseLocation(party.getSingleLogoutServiceResponseLocation())
+						.singleLogoutServiceBinding(party.getSingleLogoutServiceBinding()));
+	}
+
 	/**
 	 * Get the unique registration id for this RP/AP pair
 	 * @return the unique registration id for this RP/AP pair
@@ -292,7 +321,7 @@ public final class RelyingPartyRegistration {
 	 */
 	public static Builder withRegistrationId(String registrationId) {
 		Assert.hasText(registrationId, "registrationId cannot be empty");
-		return new Builder(registrationId);
+		return new Builder(registrationId, new AssertingPartyDetails.Builder());
 	}
 
 	public static Builder withAssertingPartyDetails(AssertingPartyDetails assertingPartyDetails) {
@@ -315,7 +344,9 @@ public final class RelyingPartyRegistration {
 	 * object
 	 * @param registration the {@code RelyingPartyRegistration}
 	 * @return {@code Builder} to create a {@code RelyingPartyRegistration} object
+	 * @deprecated Use {@link #mutate()} instead
 	 */
+	@Deprecated(forRemoval = true, since = "6.1")
 	public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
 		Assert.notNull(registration, "registration cannot be null");
 		return withRegistrationId(registration.getRegistrationId()).entityId(registration.getEntityId())
@@ -736,9 +767,9 @@ public final class RelyingPartyRegistration {
 
 	}
 
-	public static final class Builder {
+	public static class Builder {
 
-		private Converter<AssertingPartyDetails, String> registrationId = AssertingPartyDetails::getEntityId;
+		private String registrationId;
 
 		private String entityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}";
 
@@ -760,13 +791,9 @@ public final class RelyingPartyRegistration {
 
 		private AssertingPartyDetails.Builder assertingPartyDetailsBuilder;
 
-		private Builder(String registrationId) {
-			this.registrationId = (party) -> registrationId;
-			this.assertingPartyDetailsBuilder = new AssertingPartyDetails.Builder();
-		}
-
-		Builder(AssertingPartyDetails.Builder builder) {
-			this.assertingPartyDetailsBuilder = builder;
+		protected Builder(String registrationId, AssertingPartyDetails.Builder assertingPartyDetailsBuilder) {
+			this.registrationId = registrationId;
+			this.assertingPartyDetailsBuilder = assertingPartyDetailsBuilder;
 		}
 
 		/**
@@ -775,7 +802,7 @@ public final class RelyingPartyRegistration {
 		 * @return this object
 		 */
 		public Builder registrationId(String id) {
-			this.registrationId = (party) -> id;
+			this.registrationId = id;
 			return this;
 		}
 
@@ -974,11 +1001,11 @@ public final class RelyingPartyRegistration {
 			}
 
 			AssertingPartyDetails party = this.assertingPartyDetailsBuilder.build();
-			String registrationId = this.registrationId.convert(party);
-			return new RelyingPartyRegistration(registrationId, this.entityId, this.assertionConsumerServiceLocation,
-					this.assertionConsumerServiceBinding, this.singleLogoutServiceLocation,
-					this.singleLogoutServiceResponseLocation, this.singleLogoutServiceBindings, party,
-					this.nameIdFormat, this.decryptionX509Credentials, this.signingX509Credentials);
+			return new RelyingPartyRegistration(this.registrationId, this.entityId,
+					this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
+					this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
+					this.singleLogoutServiceBindings, party, this.nameIdFormat, this.decryptionX509Credentials,
+					this.signingX509Credentials);
 		}
 
 	}

+ 2 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java

@@ -101,8 +101,8 @@ public final class DefaultRelyingPartyRegistrationResolver
 				.apply(relyingPartyRegistration.getSingleLogoutServiceLocation());
 		String singleLogoutServiceResponseLocation = templateResolver
 				.apply(relyingPartyRegistration.getSingleLogoutServiceResponseLocation());
-		return RelyingPartyRegistration.withRelyingPartyRegistration(relyingPartyRegistration)
-				.entityId(relyingPartyEntityId).assertionConsumerServiceLocation(assertionConsumerServiceLocation)
+		return relyingPartyRegistration.mutate().entityId(relyingPartyEntityId)
+				.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
 				.singleLogoutServiceLocation(singleLogoutServiceLocation)
 				.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation).build();
 	}

+ 12 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java

@@ -38,6 +38,18 @@ public class RelyingPartyRegistrationTests {
 		compareRegistrations(registration, copy);
 	}
 
+	@Test
+	void mutateWhenInvokedThenCreatesCopy() {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
+				.nameIdFormat("format")
+				.assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST))
+				.assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false))
+				.assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg")))
+				.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT).build();
+		RelyingPartyRegistration copy = registration.mutate().build();
+		compareRegistrations(registration, copy);
+	}
+
 	private void compareRegistrations(RelyingPartyRegistration registration, RelyingPartyRegistration copy) {
 		assertThat(copy.getRegistrationId()).isEqualTo(registration.getRegistrationId()).isEqualTo("simplesamlphp");
 		assertThat(copy.getAssertingPartyDetails().getEntityId())