浏览代码

Add SP NameIDFormat Support

closes gh-9115
Arnaud Mergey 4 年之前
父节点
当前提交
dbe4d704f8

+ 11 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolver.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@ import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.metadata.AssertionConsumerService;
 import org.opensaml.saml.saml2.metadata.EntityDescriptor;
 import org.opensaml.saml.saml2.metadata.KeyDescriptor;
+import org.opensaml.saml.saml2.metadata.NameIDFormat;
 import org.opensaml.saml.saml2.metadata.SPSSODescriptor;
 import org.opensaml.saml.saml2.metadata.SingleLogoutService;
 import org.opensaml.saml.saml2.metadata.impl.EntityDescriptorMarshaller;
@@ -87,6 +88,9 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 				.addAll(buildKeys(registration.getDecryptionX509Credentials(), UsageType.ENCRYPTION));
 		spSsoDescriptor.getAssertionConsumerServices().add(buildAssertionConsumerService(registration));
 		spSsoDescriptor.getSingleLogoutServices().add(buildSingleLogoutService(registration));
+		if (registration.getNameIdFormat() != null) {
+			spSsoDescriptor.getNameIDFormats().add(buildNameIDFormat(registration));
+		}
 		return spSsoDescriptor;
 	}
 
@@ -133,6 +137,12 @@ public final class OpenSamlMetadataResolver implements Saml2MetadataResolver {
 		return singleLogoutService;
 	}
 
+	private NameIDFormat buildNameIDFormat(RelyingPartyRegistration registration) {
+		NameIDFormat nameIdFormat = build(NameIDFormat.DEFAULT_ELEMENT_NAME);
+		nameIdFormat.setFormat(registration.getNameIdFormat());
+		return nameIdFormat;
+	}
+
 	@SuppressWarnings("unchecked")
 	private <T> T build(QName elementName) {
 		XMLObjectBuilder<?> builder = XMLObjectProviderRegistrySupport.getBuilderFactory().getBuilder(elementName);

+ 28 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java

@@ -87,6 +87,8 @@ public final class RelyingPartyRegistration {
 
 	private final Saml2MessageBinding singleLogoutServiceBinding;
 
+	private final String nameIdFormat;
+
 	private final ProviderDetails providerDetails;
 
 	private final List<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials;
@@ -98,7 +100,7 @@ public final class RelyingPartyRegistration {
 	private RelyingPartyRegistration(String registrationId, String entityId, String assertionConsumerServiceLocation,
 			Saml2MessageBinding assertionConsumerServiceBinding, String singleLogoutServiceLocation,
 			String singleLogoutServiceResponseLocation, Saml2MessageBinding singleLogoutServiceBinding,
-			ProviderDetails providerDetails,
+			ProviderDetails providerDetails, String nameIdFormat,
 			Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials,
 			Collection<Saml2X509Credential> decryptionX509Credentials,
 			Collection<Saml2X509Credential> signingX509Credentials) {
@@ -129,6 +131,7 @@ public final class RelyingPartyRegistration {
 		this.singleLogoutServiceLocation = singleLogoutServiceLocation;
 		this.singleLogoutServiceResponseLocation = singleLogoutServiceResponseLocation;
 		this.singleLogoutServiceBinding = singleLogoutServiceBinding;
+		this.nameIdFormat = nameIdFormat;
 		this.providerDetails = providerDetails;
 		this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials));
 		this.decryptionX509Credentials = Collections.unmodifiableList(new LinkedList<>(decryptionX509Credentials));
@@ -234,6 +237,15 @@ public final class RelyingPartyRegistration {
 		return this.singleLogoutServiceResponseLocation;
 	}
 
+	/**
+	 * Get the NameID format.
+	 * @return the NameID format
+	 * @since 5.7
+	 */
+	public String getNameIdFormat() {
+		return this.nameIdFormat;
+	}
+
 	/**
 	 * Get the {@link Collection} of decryption {@link Saml2X509Credential}s associated
 	 * with this relying party
@@ -424,6 +436,7 @@ public final class RelyingPartyRegistration {
 				.singleLogoutServiceLocation(registration.getSingleLogoutServiceLocation())
 				.singleLogoutServiceResponseLocation(registration.getSingleLogoutServiceResponseLocation())
 				.singleLogoutServiceBinding(registration.getSingleLogoutServiceBinding())
+				.nameIdFormat(registration.getNameIdFormat())
 				.assertingPartyDetails((assertingParty) -> assertingParty
 						.entityId(registration.getAssertingPartyDetails().getEntityId())
 						.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@@ -1018,6 +1031,8 @@ public final class RelyingPartyRegistration {
 
 		private Saml2MessageBinding singleLogoutServiceBinding = Saml2MessageBinding.POST;
 
+		private String nameIdFormat = null;
+
 		private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder();
 
 		private Collection<org.springframework.security.saml2.credentials.Saml2X509Credential> credentials = new HashSet<>();
@@ -1173,6 +1188,17 @@ public final class RelyingPartyRegistration {
 			return this;
 		}
 
+		/**
+		 * Set the NameID format
+		 * @param nameIdFormat
+		 * @return the {@link Builder} for further configuration
+		 * @since 5.7
+		 */
+		public Builder nameIdFormat(String nameIdFormat) {
+			this.nameIdFormat = nameIdFormat;
+			return this;
+		}
+
 		/**
 		 * Apply this {@link Consumer} to further configure the Asserting Party details
 		 * @param assertingPartyDetails The {@link Consumer} to apply
@@ -1321,7 +1347,7 @@ public final class RelyingPartyRegistration {
 			return new RelyingPartyRegistration(this.registrationId, this.entityId,
 					this.assertionConsumerServiceLocation, this.assertionConsumerServiceBinding,
 					this.singleLogoutServiceLocation, this.singleLogoutServiceResponseLocation,
-					this.singleLogoutServiceBinding, this.providerDetails.build(), this.credentials,
+					this.singleLogoutServiceBinding, this.providerDetails.build(), this.nameIdFormat, this.credentials,
 					this.decryptionX509Credentials, this.signingX509Credentials);
 		}
 

+ 19 - 1
saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactory.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -27,8 +27,10 @@ import org.opensaml.core.xml.config.XMLObjectProviderRegistry;
 import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.Issuer;
+import org.opensaml.saml.saml2.core.NameIDPolicy;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
 import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
+import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder;
 
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.core.OpenSamlInitializationService;
@@ -56,6 +58,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
 
 	private final IssuerBuilder issuerBuilder;
 
+	private final NameIDPolicyBuilder nameIdPolicyBuilder;
+
 	private Clock clock = Clock.systemUTC();
 
 	private Converter<Saml2AuthenticationRequestContext, AuthnRequest> authenticationRequestContextConverter;
@@ -69,6 +73,8 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
 		this.authnRequestBuilder = (AuthnRequestBuilder) registry.getBuilderFactory()
 				.getBuilder(AuthnRequest.DEFAULT_ELEMENT_NAME);
 		this.issuerBuilder = (IssuerBuilder) registry.getBuilderFactory().getBuilder(Issuer.DEFAULT_ELEMENT_NAME);
+		this.nameIdPolicyBuilder = (NameIDPolicyBuilder) registry.getBuilderFactory()
+				.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME);
 	}
 
 	/**
@@ -152,6 +158,9 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
 			auth.setProtocolBinding(SAMLConstants.SAML2_POST_BINDING_URI);
 		}
 		auth.setProtocolBinding(protocolBinding);
+		if (auth.getNameIDPolicy() == null) {
+			setNameIdPolicy(auth, context.getRelyingPartyRegistration());
+		}
 		Issuer iss = this.issuerBuilder.buildObject();
 		iss.setValue(issuer);
 		auth.setIssuer(iss);
@@ -160,6 +169,15 @@ public final class OpenSaml4AuthenticationRequestFactory implements Saml2Authent
 		return auth;
 	}
 
+	private void setNameIdPolicy(AuthnRequest authnRequest, RelyingPartyRegistration registration) {
+		if (!StringUtils.hasText(registration.getNameIdFormat())) {
+			return;
+		}
+		NameIDPolicy nameIdPolicy = this.nameIdPolicyBuilder.buildObject();
+		nameIdPolicy.setFormat(registration.getNameIdFormat());
+		authnRequest.setNameIDPolicy(nameIdPolicy);
+	}
+
 	/**
 	 * Set the strategy for building an {@link AuthnRequest} from a given context
 	 * @param authenticationRequestContextConverter the conversion strategy to use

+ 12 - 0
saml2/saml2-service-provider/src/opensaml4Test/java/org/springframework/security/saml2/provider/service/authentication/OpenSaml4AuthenticationRequestFactoryTests.java

@@ -242,6 +242,18 @@ public class OpenSaml4AuthenticationRequestFactoryTests {
 		assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
 	}
 
+	@Test
+	public void createAuthenticationRequestWhenSetNameIDPolicyThenReturnsCorrectNameIDPolicy() {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().nameIdFormat("format").build();
+		this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
+				.build();
+		AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);
+		assertThat(authn.getNameIDPolicy()).isNotNull();
+		assertThat(authn.getNameIDPolicy().getAllowCreate()).isFalse();
+		assertThat(authn.getNameIDPolicy().getFormat()).isEqualTo("format");
+		assertThat(authn.getNameIDPolicy().getSPNameQualifier()).isNull();
+	}
+
 	private AuthnRequest authnRequest() {
 		AuthnRequest authnRequest = TestOpenSamlObjects.authnRequest();
 		authnRequest.setIssueInstant(Instant.now());

+ 10 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/OpenSamlMetadataResolverTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -61,4 +61,13 @@ public class OpenSamlMetadataResolverTests {
 				.contains("ResponseLocation=\"https://rp.example.org/logout/saml2/response\"");
 	}
 
+	@Test
+	public void resolveWhenRelyingPartyNameIDFormatThenMetadataMatches() {
+		RelyingPartyRegistration relyingPartyRegistration = TestRelyingPartyRegistrations.full().nameIdFormat("format")
+				.build();
+		OpenSamlMetadataResolver openSamlMetadataResolver = new OpenSamlMetadataResolver();
+		String metadata = openSamlMetadataResolver.resolve(relyingPartyRegistration);
+		assertThat(metadata).contains("<md:NameIDFormat>format</md:NameIDFormat>");
+	}
+
 }

+ 2 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java

@@ -28,6 +28,7 @@ public class RelyingPartyRegistrationTests {
 	@Test
 	public void withRelyingPartyRegistrationWorks() {
 		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
+				.nameIdFormat("format")
 				.assertingPartyDetails((a) -> a.singleSignOnServiceBinding(Saml2MessageBinding.POST))
 				.assertingPartyDetails((a) -> a.wantAuthnRequestsSigned(false))
 				.assertingPartyDetails((a) -> a.signingAlgorithms((algs) -> algs.add("alg")))
@@ -74,6 +75,7 @@ public class RelyingPartyRegistrationTests {
 				.isEqualTo(registration.getAssertingPartyDetails().getVerificationX509Credentials());
 		assertThat(copy.getAssertingPartyDetails().getSigningAlgorithms())
 				.isEqualTo(registration.getAssertingPartyDetails().getSigningAlgorithms());
+		assertThat(copy.getNameIdFormat()).isEqualTo(registration.getNameIdFormat());
 	}
 
 	@Test