Browse Source

Add AssertionConsumerServiceBinding

Closes gh-8776
Josh Cummings 5 years ago
parent
commit
44ec061f05

+ 19 - 5
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

@@ -29,6 +29,7 @@ import org.opensaml.saml.common.xml.SAMLConstants;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.Issuer;
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.credentials.Saml2X509Credential;
 import org.springframework.security.saml2.provider.service.authentication.Saml2RedirectAuthenticationRequest.Builder;
 import org.springframework.util.Assert;
@@ -43,7 +44,14 @@ import static org.springframework.security.saml2.provider.service.authentication
 public class OpenSamlAuthenticationRequestFactory implements Saml2AuthenticationRequestFactory {
 	private Clock clock = Clock.systemUTC();
 	private final OpenSamlImplementation saml = OpenSamlImplementation.getInstance();
-	private String protocolBinding = SAMLConstants.SAML2_POST_BINDING_URI;
+
+	private Converter<Saml2AuthenticationRequestContext, String> protocolBindingResolver =
+			context -> {
+				if (context == null) {
+					return SAMLConstants.SAML2_POST_BINDING_URI;
+				}
+				return context.getRelyingPartyRegistration().getAssertionConsumerServiceBinding().getUrn();
+			};
 
 	private Function<Saml2AuthenticationRequestContext, Consumer<AuthnRequest>> authnRequestConsumerResolver
 			= context -> authnRequest -> {};
@@ -52,7 +60,8 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	@Deprecated
 	public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
 		AuthnRequest authnRequest = createAuthnRequest(request.getIssuer(),
-				request.getDestination(), request.getAssertionConsumerServiceUrl());
+				request.getDestination(), request.getAssertionConsumerServiceUrl(),
+				this.protocolBindingResolver.convert(null));
 		return this.saml.serialize(authnRequest, request.getCredentials());
 	}
 
@@ -101,12 +110,14 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 
 	private AuthnRequest createAuthnRequest(Saml2AuthenticationRequestContext context) {
 		AuthnRequest authnRequest = createAuthnRequest(context.getIssuer(),
-				context.getDestination(), context.getAssertionConsumerServiceUrl());
+				context.getDestination(), context.getAssertionConsumerServiceUrl(),
+				this.protocolBindingResolver.convert(context));
 		this.authnRequestConsumerResolver.apply(context).accept(authnRequest);
 		return authnRequest;
 	}
 
-	private AuthnRequest createAuthnRequest(String issuer, String destination, String assertionConsumerServiceUrl) {
+	private AuthnRequest createAuthnRequest
+			(String issuer, String destination, String assertionConsumerServiceUrl, String protocolBinding) {
 		AuthnRequest auth = this.saml.buildSamlObject(AuthnRequest.DEFAULT_ELEMENT_NAME);
 		auth.setID("ARQ" + UUID.randomUUID().toString().substring(1));
 		auth.setIssueInstant(new DateTime(this.clock.millis()));
@@ -155,13 +166,16 @@ public class OpenSamlAuthenticationRequestFactory implements Saml2Authentication
 	 * @param protocolBinding either {@link SAMLConstants#SAML2_POST_BINDING_URI} or
 	 * {@link SAMLConstants#SAML2_REDIRECT_BINDING_URI}
 	 * @throws IllegalArgumentException if the protocolBinding is not valid
+	 * @deprecated Use {@link org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.Builder#assertionConsumerServiceBinding}
+	 * instead
 	 */
+	@Deprecated
 	public void setProtocolBinding(String protocolBinding) {
 		boolean isAllowedBinding = SAMLConstants.SAML2_POST_BINDING_URI.equals(protocolBinding) ||
 				SAMLConstants.SAML2_REDIRECT_BINDING_URI.equals(protocolBinding);
 		if (!isAllowedBinding) {
 			throw new IllegalArgumentException("Invalid protocol binding: " + protocolBinding);
 		}
-		this.protocolBinding = protocolBinding;
+		this.protocolBindingResolver = context -> protocolBinding;
 	}
 }

+ 36 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.java

@@ -68,6 +68,7 @@ public class RelyingPartyRegistration {
 	private final String registrationId;
 	private final String entityId;
 	private final String assertionConsumerServiceLocation;
+	private final Saml2MessageBinding assertionConsumerServiceBinding;
 	private final ProviderDetails providerDetails;
 	private final List<Saml2X509Credential> credentials;
 
@@ -75,12 +76,14 @@ public class RelyingPartyRegistration {
 			String registrationId,
 			String entityId,
 			String assertionConsumerServiceLocation,
+			Saml2MessageBinding assertionConsumerServiceBinding,
 			ProviderDetails providerDetails,
 			List<Saml2X509Credential> credentials) {
 
 		Assert.hasText(registrationId, "registrationId cannot be empty");
 		Assert.hasText(entityId, "entityId cannot be empty");
 		Assert.hasText(assertionConsumerServiceLocation, "assertionConsumerServiceLocation cannot be empty");
+		Assert.notNull(assertionConsumerServiceBinding, "assertionConsumerServiceBinding cannot be null");
 		Assert.notNull(providerDetails, "providerDetails cannot be null");
 		Assert.notEmpty(credentials, "credentials cannot be empty");
 		for (Saml2X509Credential c : credentials) {
@@ -89,6 +92,7 @@ public class RelyingPartyRegistration {
 		this.registrationId = registrationId;
 		this.entityId = entityId;
 		this.assertionConsumerServiceLocation = assertionConsumerServiceLocation;
+		this.assertionConsumerServiceBinding = assertionConsumerServiceBinding;
 		this.providerDetails = providerDetails;
 		this.credentials = Collections.unmodifiableList(new LinkedList<>(credentials));
 	}
@@ -138,6 +142,18 @@ public class RelyingPartyRegistration {
 		return this.assertionConsumerServiceLocation;
 	}
 
+	/**
+	 * Get the AssertionConsumerService Binding.
+	 * Equivalent to the value found in &lt;AssertionConsumerService Binding="..."/&gt;
+	 * in the relying party's &lt;SPSSODescriptor&gt;.
+	 *
+	 * @return the AssertionConsumerService Binding
+	 * @since 5.4
+	 */
+	public Saml2MessageBinding getAssertionConsumerServiceBinding() {
+		return this.assertionConsumerServiceBinding;
+	}
+
 	/**
 	 * Get the configuration details for the Asserting Party
 	 *
@@ -280,6 +296,7 @@ public class RelyingPartyRegistration {
 		return withRegistrationId(registration.getRegistrationId())
 				.entityId(registration.getEntityId())
 				.assertionConsumerServiceLocation(registration.getAssertionConsumerServiceLocation())
+				.assertionConsumerServiceBinding(registration.getAssertionConsumerServiceBinding())
 				.assertingPartyDetails(c -> c
 					.entityId(registration.getAssertingPartyDetails().getEntityId())
 					.wantAuthnRequestsSigned(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
@@ -575,6 +592,7 @@ public class RelyingPartyRegistration {
 		private String registrationId;
 		private String entityId = "{baseUrl}/saml2/service-provider-metadata/{registrationId}";
 		private String assertionConsumerServiceLocation;
+		private Saml2MessageBinding assertionConsumerServiceBinding = Saml2MessageBinding.POST;
 		private ProviderDetails.Builder providerDetails = new ProviderDetails.Builder();
 		private List<Saml2X509Credential> credentials = new LinkedList<>();
 
@@ -633,6 +651,23 @@ public class RelyingPartyRegistration {
 			return this;
 		}
 
+		/**
+		 * Set the <a href="https://wiki.shibboleth.net/confluence/display/CONCEPT/AssertionConsumerService">AssertionConsumerService</a>
+		 * Binding.
+		 *
+		 * <p>
+		 * Equivalent to the value found in &lt;AssertionConsumerService Binding="..."/&gt;
+		 * in the relying party's &lt;SPSSODescriptor&gt;
+		 *
+		 * @param assertionConsumerServiceBinding
+		 * @return the {@link Builder} for further configuration
+		 * @since 5.4
+		 */
+		public Builder assertionConsumerServiceBinding(Saml2MessageBinding assertionConsumerServiceBinding) {
+			this.assertionConsumerServiceBinding = assertionConsumerServiceBinding;
+			return this;
+		}
+
 		/**
 		 * Apply this {@link Consumer} to further configure the Asserting Party details
 		 *
@@ -738,6 +773,7 @@ public class RelyingPartyRegistration {
 					this.registrationId,
 					this.entityId,
 					this.assertionConsumerServiceLocation,
+					this.assertionConsumerServiceBinding,
 					this.providerDetails.build(),
 					this.credentials
 			);

+ 22 - 5
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

@@ -39,6 +39,7 @@ import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartySigningCredential;
 import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlDecode;
+import static org.springframework.security.saml2.provider.service.authentication.Saml2Utils.samlInflate;
 import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.REDIRECT;
@@ -52,19 +53,21 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 	private Saml2AuthenticationRequestContext.Builder contextBuilder;
 	private Saml2AuthenticationRequestContext context;
 
+	private RelyingPartyRegistration.Builder relyingPartyRegistrationBuilder;
+	private RelyingPartyRegistration relyingPartyRegistration;
+
 	@Rule
 	public ExpectedException exception = ExpectedException.none();
-	private RelyingPartyRegistration relyingPartyRegistration;
 
 	@Before
 	public void setUp() {
-		relyingPartyRegistration = RelyingPartyRegistration.withRegistrationId("id")
+		this.relyingPartyRegistrationBuilder = RelyingPartyRegistration.withRegistrationId("id")
 				.assertionConsumerServiceLocation("template")
 				.providerDetails(c -> c.webSsoUrl("https://destination/sso"))
 				.providerDetails(c -> c.entityId("remote-entity-id"))
 				.localEntityIdTemplate("local-entity-id")
-				.credentials(c -> c.add(relyingPartySigningCredential()))
-				.build();
+				.credentials(c -> c.add(relyingPartySigningCredential()));
+		this.relyingPartyRegistration = this.relyingPartyRegistrationBuilder.build();
 		contextBuilder = Saml2AuthenticationRequestContext.builder()
 				.issuer("https://issuer")
 				.relyingPartyRegistration(relyingPartyRegistration)
@@ -195,6 +198,20 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void createPostAuthenticationRequestWhenAssertionConsumerServiceBindingThenUses() {
+		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationBuilder
+				.assertionConsumerServiceBinding(REDIRECT)
+				.build();
+		Saml2AuthenticationRequestContext context = this.contextBuilder
+				.relyingPartyRegistration(relyingPartyRegistration)
+				.build();
+		Saml2PostAuthenticationRequest request = this.factory.createPostAuthenticationRequest(context);
+		String samlRequest = request.getSamlRequest();
+		String inflated = new String(samlDecode(samlRequest));
+		assertThat(inflated).contains("ProtocolBinding=\"" + SAMLConstants.SAML2_REDIRECT_BINDING_URI + "\"");
+	}
+
 	private AuthnRequest getAuthNRequest(Saml2MessageBinding binding) {
 		AbstractSaml2AuthenticationRequest result = (binding == REDIRECT) ?
 				factory.createRedirectAuthenticationRequest(context) :
@@ -202,7 +219,7 @@ public class OpenSamlAuthenticationRequestFactoryTests {
 		String samlRequest = result.getSamlRequest();
 		assertThat(samlRequest).isNotEmpty();
 		if (result.getBinding() == REDIRECT) {
-			samlRequest = Saml2Utils.samlInflate(samlDecode(samlRequest));
+			samlRequest = samlInflate(samlDecode(samlRequest));
 		}
 		else {
 			samlRequest = new String(samlDecode(samlRequest), UTF_8);

+ 20 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistrationTests.java

@@ -21,6 +21,8 @@ import org.junit.Test;
 import org.springframework.security.saml2.provider.service.servlet.filter.Saml2WebSsoAuthenticationFilter;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.security.saml2.credentials.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
+import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRegistrationId;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
 import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
 
@@ -31,6 +33,7 @@ public class RelyingPartyRegistrationTests {
 		RelyingPartyRegistration registration = relyingPartyRegistration()
 				.providerDetails(p -> p.binding(POST))
 				.providerDetails(p -> p.signAuthNRequest(false))
+				.assertionConsumerServiceBinding(Saml2MessageBinding.REDIRECT)
 				.build();
 		RelyingPartyRegistration copy = RelyingPartyRegistration.withRelyingPartyRegistration(registration).build();
 		compareRegistrations(registration, copy);
@@ -76,5 +79,22 @@ public class RelyingPartyRegistrationTests {
 				.isEqualTo(copy.getAssertingPartyDetails().getWantAuthnRequestsSigned())
 				.isEqualTo(registration.getAssertingPartyDetails().getWantAuthnRequestsSigned())
 				.isFalse();
+		assertThat(copy.getAssertionConsumerServiceBinding())
+				.isEqualTo(registration.getAssertionConsumerServiceBinding());
+	}
+
+	@Test
+	public void buildWhenUsingDefaultsThenAssertionConsumerServiceBindingDefaultsToPost() {
+		RelyingPartyRegistration relyingPartyRegistration = withRegistrationId("id")
+				.entityId("entity-id")
+				.assertionConsumerServiceLocation("location")
+				.assertingPartyDetails(assertingParty -> assertingParty
+					.entityId("entity-id")
+					.singleSignOnServiceLocation("location"))
+					.credentials(c -> c.add(relyingPartyVerifyingCredential()))
+				.build();
+
+		assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding())
+				.isEqualTo(POST);
 	}
 }