Browse Source

Allow custom relay state

Closes gh-11065
sebastiano 3 years ago
parent
commit
f7a43e4989

+ 8 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java

@@ -36,6 +36,7 @@ import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
 import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
 import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
 import org.w3c.dom.Element;
 import org.w3c.dom.Element;
 
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.Saml2ParameterNames;
 import org.springframework.security.saml2.core.Saml2ParameterNames;
@@ -71,6 +72,8 @@ class OpenSamlAuthenticationRequestResolver {
 
 
 	private final NameIDBuilder nameIdBuilder;
 	private final NameIDBuilder nameIdBuilder;
 
 
+	private Converter<HttpServletRequest, String> relayStateResolver = (request) -> UUID.randomUUID().toString();
+
 	/**
 	/**
 	 * Construct a {@link OpenSamlAuthenticationRequestResolver} using the provided
 	 * Construct a {@link OpenSamlAuthenticationRequestResolver} using the provided
 	 * parameters
 	 * parameters
@@ -93,6 +96,10 @@ class OpenSamlAuthenticationRequestResolver {
 		Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML");
 		Assert.notNull(this.nameIdBuilder, "nameIdBuilder must be configured in OpenSAML");
 	}
 	}
 
 
+	void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
+		this.relayStateResolver = relayStateResolver;
+	}
+
 	<T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest request) {
 	<T extends AbstractSaml2AuthenticationRequest> T resolve(HttpServletRequest request) {
 		return resolve(request, (registration, logoutRequest) -> {
 		return resolve(request, (registration, logoutRequest) -> {
 		});
 		});
@@ -122,7 +129,7 @@ class OpenSamlAuthenticationRequestResolver {
 		if (authnRequest.getID() == null) {
 		if (authnRequest.getID() == null) {
 			authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
 			authnRequest.setID("ARQ" + UUID.randomUUID().toString().substring(1));
 		}
 		}
-		String relayState = UUID.randomUUID().toString();
+		String relayState = this.relayStateResolver.convert(request);
 		Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding();
 		Saml2MessageBinding binding = registration.getAssertingPartyDetails().getSingleSignOnServiceBinding();
 		if (binding == Saml2MessageBinding.POST) {
 		if (binding == Saml2MessageBinding.POST) {
 			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
 			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {

+ 11 - 0
saml2/saml2-service-provider/src/opensaml4Main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSaml4AuthenticationRequestResolver.java

@@ -23,6 +23,7 @@ import java.util.function.Consumer;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletRequest;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 import org.opensaml.saml.saml2.core.AuthnRequest;
 
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
 import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationResolver;
@@ -77,6 +78,16 @@ public final class OpenSaml4AuthenticationRequestResolver implements Saml2Authen
 		this.clock = clock;
 		this.clock = clock;
 	}
 	}
 
 
+	/**
+	 * Use this {@link Converter} to compute the RelayState
+	 * @param relayStateResolver the {@link Converter} to use
+	 * @since 5.7
+	 */
+	public void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
+		Assert.notNull(relayStateResolver, "relayStateResolver cannot be null");
+		this.authnRequestResolver.setRelayStateResolver(relayStateResolver);
+	}
+
 	public static final class AuthnRequestContext {
 	public static final class AuthnRequestContext {
 
 
 		private final HttpServletRequest request;
 		private final HttpServletRequest request;