Răsfoiți Sursa

Merge branch '6.0.x'

Closes gh-12937
Josh Cummings 2 ani în urmă
părinte
comite
46a40e7b38

+ 12 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java

@@ -30,10 +30,12 @@ import org.opensaml.core.xml.io.MarshallingException;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.Issuer;
 import org.opensaml.saml.saml2.core.NameID;
+import org.opensaml.saml.saml2.core.NameIDPolicy;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
 import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
 import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
 import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
+import org.opensaml.saml.saml2.core.impl.NameIDPolicyBuilder;
 import org.w3c.dom.Element;
 
 import org.springframework.core.convert.converter.Converter;
@@ -71,6 +73,8 @@ class OpenSamlAuthenticationRequestResolver {
 
 	private final NameIDBuilder nameIdBuilder;
 
+	private final NameIDPolicyBuilder nameIdPolicyBuilder;
+
 	private RequestMatcher requestMatcher = new AntPathRequestMatcher(
 			Saml2AuthenticationRequestResolver.DEFAULT_AUTHENTICATION_REQUEST_URI);
 
@@ -96,6 +100,9 @@ class OpenSamlAuthenticationRequestResolver {
 		Assert.notNull(this.issuerBuilder, "issuerBuilder must be configured in OpenSAML");
 		this.nameIdBuilder = (NameIDBuilder) registry.getBuilderFactory().getBuilder(NameID.DEFAULT_ELEMENT_NAME);
 		Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML");
+		this.nameIdPolicyBuilder = (NameIDPolicyBuilder) registry.getBuilderFactory()
+				.getBuilder(NameIDPolicy.DEFAULT_ELEMENT_NAME);
+		Assert.notNull(this.nameIdPolicyBuilder, "nameIdPolicyBuilder must be configured in OpenSAML");
 	}
 
 	void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
@@ -135,6 +142,11 @@ class OpenSamlAuthenticationRequestResolver {
 		authnRequest.setIssuer(iss);
 		authnRequest.setDestination(registration.getAssertingPartyDetails().getSingleSignOnServiceLocation());
 		authnRequest.setAssertionConsumerServiceURL(assertionConsumerServiceLocation);
+		if (registration.getNameIdFormat() != null) {
+			NameIDPolicy nameIdPolicy = this.nameIdPolicyBuilder.buildObject();
+			nameIdPolicy.setFormat(registration.getNameIdFormat());
+			authnRequest.setNameIDPolicy(nameIdPolicy);
+		}
 		authnRequestConsumer.accept(registration, authnRequest);
 		if (authnRequest.getID() == null) {
 			authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));

+ 2 - 2
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/TestRelyingPartyRegistrations.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ public final class TestRelyingPartyRegistrations {
 		Saml2X509Credential verificationCertificate = TestSaml2X509Credentials.relyingPartyVerifyingCredential();
 		String singleSignOnServiceLocation = "https://simplesaml-for-spring-saml.apps.pcfone.io/saml2/idp/SSOService.php";
 		String singleLogoutServiceLocation = "{baseUrl}/logout/saml2/slo";
-		return RelyingPartyRegistration.withRegistrationId(registrationId).entityId(rpEntityId)
+		return RelyingPartyRegistration.withRegistrationId(registrationId).entityId(rpEntityId).nameIdFormat("format")
 				.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
 				.singleLogoutServiceLocation(singleLogoutServiceLocation)
 				.signingX509Credentials((c) -> c.add(signingCredential)).assertingPartyDetails(

+ 4 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolverTests.java

@@ -64,6 +64,7 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
 		Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
 			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
+			assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat());
 			assertThat(authnRequest.getAssertionConsumerServiceURL())
 					.isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation()));
 			assertThat(authnRequest.getProtocolBinding())
@@ -89,6 +90,7 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
 		Saml2RedirectAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
 			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
+			assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat());
 			assertThat(authnRequest.getAssertionConsumerServiceURL())
 					.isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation()));
 			assertThat(authnRequest.getProtocolBinding())
@@ -128,6 +130,7 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
 		Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
 			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
+			assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat());
 			assertThat(authnRequest.getAssertionConsumerServiceURL())
 					.isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation()));
 			assertThat(authnRequest.getProtocolBinding())
@@ -157,6 +160,7 @@ public class OpenSamlAuthenticationRequestResolverTests {
 		OpenSamlAuthenticationRequestResolver resolver = authenticationRequestResolver(registration);
 		Saml2PostAuthenticationRequest result = resolver.resolve(request, (r, authnRequest) -> {
 			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
+			assertThat(authnRequest.getNameIDPolicy().getFormat()).isEqualTo(registration.getNameIdFormat());
 			assertThat(authnRequest.getAssertionConsumerServiceURL())
 					.isEqualTo(uriResolver.resolve(registration.getAssertionConsumerServiceLocation()));
 			assertThat(authnRequest.getProtocolBinding())