Pārlūkot izejas kodu

Add relyingPartyRegistrationId to AbstractSaml2AuthenticationRequest

Closes gh-11195
Ulrich Grave 3 gadi atpakaļ
vecāks
revīzija
9b874bcde2
10 mainītis faili ar 158 papildinājumiem un 29 dzēšanām
  1. 2 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java
  2. 2 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java
  3. 33 1
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java
  4. 10 6
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java
  5. 8 6
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java
  6. 13 10
      saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java
  7. 17 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java
  8. 20 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixinTests.java
  9. 9 4
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java
  10. 44 0
      saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

+ 2 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java

@@ -47,7 +47,8 @@ class Saml2PostAuthenticationRequestMixin {
 	@JsonCreator
 	@JsonCreator
 	Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
 	Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
 			@JsonProperty("relayState") String relayState,
 			@JsonProperty("relayState") String relayState,
-			@JsonProperty("authenticationRequestUri") String authenticationRequestUri) {
+			@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
+			@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
 	}
 	}
 
 
 }
 }

+ 2 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java

@@ -48,7 +48,8 @@ class Saml2RedirectAuthenticationRequestMixin {
 	Saml2RedirectAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
 	Saml2RedirectAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
 			@JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature,
 			@JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature,
 			@JsonProperty("relayState") String relayState,
 			@JsonProperty("relayState") String relayState,
-			@JsonProperty("authenticationRequestUri") String authenticationRequestUri) {
+			@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
+			@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
 	}
 	}
 
 
 }
 }

+ 33 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java

@@ -20,6 +20,7 @@ import java.io.Serializable;
 import java.nio.charset.Charset;
 import java.nio.charset.Charset;
 
 
 import org.springframework.security.core.SpringSecurityCoreVersion;
 import org.springframework.security.core.SpringSecurityCoreVersion;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
 import org.springframework.util.Assert;
 import org.springframework.util.Assert;
 
 
@@ -46,6 +47,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 
 
 	private final String authenticationRequestUri;
 	private final String authenticationRequestUri;
 
 
+	private final String relyingPartyRegistrationId;
+
 	/**
 	/**
 	 * Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
 	 * Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
 	 * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or
 	 * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or
@@ -53,13 +56,17 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 	 * @param relayState - RelayState value that accompanies the request, may be null
 	 * @param relayState - RelayState value that accompanies the request, may be null
 	 * @param authenticationRequestUri - The authenticationRequestUri, a URL, where to
 	 * @param authenticationRequestUri - The authenticationRequestUri, a URL, where to
 	 * send the XML message, cannot be empty or null
 	 * send the XML message, cannot be empty or null
+	 * @param relyingPartyRegistrationId the registration id of the relying party, may be
+	 * null
 	 */
 	 */
-	AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) {
+	AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
+			String relyingPartyRegistrationId) {
 		Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
 		Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
 		Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
 		Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
 		this.authenticationRequestUri = authenticationRequestUri;
 		this.authenticationRequestUri = authenticationRequestUri;
 		this.samlRequest = samlRequest;
 		this.samlRequest = samlRequest;
 		this.relayState = relayState;
 		this.relayState = relayState;
+		this.relyingPartyRegistrationId = relyingPartyRegistrationId;
 	}
 	}
 
 
 	/**
 	/**
@@ -89,6 +96,16 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 		return this.authenticationRequestUri;
 		return this.authenticationRequestUri;
 	}
 	}
 
 
+	/**
+	 * The identifier for the {@link RelyingPartyRegistration} associated with this
+	 * request
+	 * @return the {@link RelyingPartyRegistration} id
+	 * @since 5.8
+	 */
+	public String getRelyingPartyRegistrationId() {
+		return this.relyingPartyRegistrationId;
+	}
+
 	/**
 	/**
 	 * Returns the binding this AuthNRequest will be sent and encoded with. If
 	 * Returns the binding this AuthNRequest will be sent and encoded with. If
 	 * {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be
 	 * {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be
@@ -108,9 +125,24 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 
 
 		String relayState;
 		String relayState;
 
 
+		String relyingPartyRegistrationId;
+
+		/**
+		 * @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead
+		 */
+		@Deprecated
 		protected Builder() {
 		protected Builder() {
 		}
 		}
 
 
+		/**
+		 * Creates a new Builder with relying party registration
+		 * @param registration the registration of the relying party.
+		 * @sine 5.8
+		 */
+		protected Builder(RelyingPartyRegistration registration) {
+			this.relyingPartyRegistrationId = registration.getRegistrationId();
+		}
+
 		/**
 		/**
 		 * Casting the return as the generic subtype, when returning itself
 		 * Casting the return as the generic subtype, when returning itself
 		 * @return this object
 		 * @return this object

+ 10 - 6
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java

@@ -30,8 +30,9 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
  */
  */
 public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest {
 public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest {
 
 
-	Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri) {
-		super(samlRequest, relayState, authenticationRequestUri);
+	Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
+			String relyingPartyRegistrationId) {
+		super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
 	}
 	}
 
 
 	/**
 	/**
@@ -52,7 +53,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
 	 * @return a modifiable builder object
 	 * @return a modifiable builder object
 	 */
 	 */
 	public static Builder withAuthenticationRequestContext(Saml2AuthenticationRequestContext context) {
 	public static Builder withAuthenticationRequestContext(Saml2AuthenticationRequestContext context) {
-		return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState());
+		return new Builder(context.getRelyingPartyRegistration()).authenticationRequestUri(context.getDestination())
+				.relayState(context.getRelayState());
 	}
 	}
 
 
 	/**
 	/**
@@ -63,7 +65,7 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
 	 */
 	 */
 	public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
 	public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
 		String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
 		String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
-		return new Builder().authenticationRequestUri(location);
+		return new Builder(registration).authenticationRequestUri(location);
 	}
 	}
 
 
 	/**
 	/**
@@ -71,7 +73,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
 	 */
 	 */
 	public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder<Builder> {
 	public static final class Builder extends AbstractSaml2AuthenticationRequest.Builder<Builder> {
 
 
-		private Builder() {
+		private Builder(RelyingPartyRegistration registration) {
+			super(registration);
 		}
 		}
 
 
 		/**
 		/**
@@ -79,7 +82,8 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
 		 * @return an immutable {@link Saml2PostAuthenticationRequest} object.
 		 * @return an immutable {@link Saml2PostAuthenticationRequest} object.
 		 */
 		 */
 		public Saml2PostAuthenticationRequest build() {
 		public Saml2PostAuthenticationRequest build() {
-			return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri);
+			return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri,
+					this.relyingPartyRegistrationId);
 		}
 		}
 
 
 	}
 	}

+ 8 - 6
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java

@@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
 	private final String signature;
 	private final String signature;
 
 
 	private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
 	private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
-			String authenticationRequestUri) {
-		super(samlRequest, relayState, authenticationRequestUri);
+			String authenticationRequestUri, String relyingPartyRegistrationId) {
+		super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
 		this.sigAlg = sigAlg;
 		this.sigAlg = sigAlg;
 		this.signature = signature;
 		this.signature = signature;
 	}
 	}
@@ -75,7 +75,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
 	 * @return a modifiable builder object
 	 * @return a modifiable builder object
 	 */
 	 */
 	public static Builder withAuthenticationRequestContext(Saml2AuthenticationRequestContext context) {
 	public static Builder withAuthenticationRequestContext(Saml2AuthenticationRequestContext context) {
-		return new Builder().authenticationRequestUri(context.getDestination()).relayState(context.getRelayState());
+		return new Builder(context.getRelyingPartyRegistration()).authenticationRequestUri(context.getDestination())
+				.relayState(context.getRelayState());
 	}
 	}
 
 
 	/**
 	/**
@@ -87,7 +88,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
 	 */
 	 */
 	public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
 	public static Builder withRelyingPartyRegistration(RelyingPartyRegistration registration) {
 		String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
 		String location = registration.getAssertingPartyDetails().getSingleSignOnServiceLocation();
-		return new Builder().authenticationRequestUri(location);
+		return new Builder(registration).authenticationRequestUri(location);
 	}
 	}
 
 
 	/**
 	/**
@@ -99,7 +100,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
 
 
 		private String signature;
 		private String signature;
 
 
-		private Builder() {
+		private Builder(RelyingPartyRegistration registration) {
+			super(registration);
 		}
 		}
 
 
 		/**
 		/**
@@ -128,7 +130,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
 		 */
 		 */
 		public Saml2RedirectAuthenticationRequest build() {
 		public Saml2RedirectAuthenticationRequest build() {
 			return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
 			return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
-					this.relayState, this.authenticationRequestUri);
+					this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId);
 		}
 		}
 
 
 	}
 	}

+ 13 - 10
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverter.java

@@ -51,7 +51,7 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 
 
 	private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }, false, CodecPolicy.STRICT);
 	private static Base64 BASE64 = new Base64(0, new byte[] { '\n' }, false, CodecPolicy.STRICT);
 
 
-	private final Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver;
+	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
 
 
 	private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
 	private Function<HttpServletRequest, AbstractSaml2AuthenticationRequest> loader;
 
 
@@ -67,24 +67,28 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 	@Deprecated
 	@Deprecated
 	public Saml2AuthenticationTokenConverter(
 	public Saml2AuthenticationTokenConverter(
 			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
 			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
-		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
-		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
-		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
+		this(adaptToResolver(relyingPartyRegistrationResolver));
 	}
 	}
 
 
 	public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
 	public Saml2AuthenticationTokenConverter(RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
-		this(adaptToConverter(relyingPartyRegistrationResolver));
+		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
+		this.relyingPartyRegistrationResolver = relyingPartyRegistrationResolver;
+		this.loader = new HttpSessionSaml2AuthenticationRequestRepository()::loadAuthenticationRequest;
 	}
 	}
 
 
-	private static Converter<HttpServletRequest, RelyingPartyRegistration> adaptToConverter(
-			RelyingPartyRegistrationResolver relyingPartyRegistrationResolver) {
+	private static RelyingPartyRegistrationResolver adaptToResolver(
+			Converter<HttpServletRequest, RelyingPartyRegistration> relyingPartyRegistrationResolver) {
 		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
 		Assert.notNull(relyingPartyRegistrationResolver, "relyingPartyRegistrationResolver cannot be null");
-		return (request) -> relyingPartyRegistrationResolver.resolve(request, null);
+		return (request, relyingPartyRegistrationId) -> relyingPartyRegistrationResolver.convert(request);
 	}
 	}
 
 
 	@Override
 	@Override
 	public Saml2AuthenticationToken convert(HttpServletRequest request) {
 	public Saml2AuthenticationToken convert(HttpServletRequest request) {
-		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.convert(request);
+		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
+		String relyingPartyRegistrationId = (authenticationRequest != null)
+				? authenticationRequest.getRelyingPartyRegistrationId() : null;
+		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationResolver.resolve(request,
+				relyingPartyRegistrationId);
 		if (relyingPartyRegistration == null) {
 		if (relyingPartyRegistration == null) {
 			return null;
 			return null;
 		}
 		}
@@ -94,7 +98,6 @@ public final class Saml2AuthenticationTokenConverter implements AuthenticationCo
 		}
 		}
 		byte[] b = samlDecode(saml2Response);
 		byte[] b = samlDecode(saml2Response);
 		saml2Response = inflateIfRequired(request, b);
 		saml2Response = inflateIfRequired(request, b);
-		AbstractSaml2AuthenticationRequest authenticationRequest = loadAuthenticationRequest(request);
 		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
 		return new Saml2AuthenticationToken(relyingPartyRegistration, saml2Response, authenticationRequest);
 	}
 	}
 
 

+ 17 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java

@@ -56,6 +56,23 @@ class Saml2PostAuthenticationRequestMixinTests {
 		assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
 		assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
 		assertThat(authRequest.getAuthenticationRequestUri())
 		assertThat(authRequest.getAuthenticationRequestUri())
 				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
 				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
+		assertThat(authRequest.getRelyingPartyRegistrationId())
+				.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
+	}
+
+	@Test
+	void shouldDeserializeWithNoRegistrationId() throws Exception {
+		String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON.replace(
+				"\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", "");
+
+		Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class);
+
+		assertThat(authRequest).isNotNull();
+		assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
+		assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
+		assertThat(authRequest.getAuthenticationRequestUri())
+				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
+		assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
 	}
 	}
 
 
 }
 }

+ 20 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixinTests.java

@@ -59,6 +59,26 @@ class Saml2RedirectAuthenticationRequestMixinTests {
 				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
 				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
 		assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG);
 		assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG);
 		assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE);
 		assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE);
+		assertThat(authRequest.getRelyingPartyRegistrationId())
+				.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
+	}
+
+	@Test
+	void shouldDeserializeWithNoRegistrationId() throws Exception {
+		String json = TestSaml2JsonPayloads.DEFAULT_REDIRECT_AUTH_REQUEST_JSON.replace(
+				"\"relyingPartyRegistrationId\": \"" + TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID + "\",", "");
+
+		Saml2RedirectAuthenticationRequest authRequest = this.mapper.readValue(json,
+				Saml2RedirectAuthenticationRequest.class);
+
+		assertThat(authRequest).isNotNull();
+		assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
+		assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
+		assertThat(authRequest.getAuthenticationRequestUri())
+				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
+		assertThat(authRequest.getSigAlg()).isEqualTo(TestSaml2JsonPayloads.SIG_ALG);
+		assertThat(authRequest.getSignature()).isEqualTo(TestSaml2JsonPayloads.SIGNATURE);
+		assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
 	}
 	}
 
 
 }
 }

+ 9 - 4
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java

@@ -94,6 +94,7 @@ final class TestSaml2JsonPayloads {
 	static final String SAML_REQUEST = "samlRequestValue";
 	static final String SAML_REQUEST = "samlRequestValue";
 	static final String RELAY_STATE = "relayStateValue";
 	static final String RELAY_STATE = "relayStateValue";
 	static final String AUTHENTICATION_REQUEST_URI = "authenticationRequestUriValue";
 	static final String AUTHENTICATION_REQUEST_URI = "authenticationRequestUriValue";
+	static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
 	static final String SIG_ALG = "sigAlgValue";
 	static final String SIG_ALG = "sigAlgValue";
 	static final String SIGNATURE = "signatureValue";
 	static final String SIGNATURE = "signatureValue";
 
 
@@ -103,6 +104,7 @@ final class TestSaml2JsonPayloads {
 			+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
 			+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
 			+ " \"relayState\": \"" + RELAY_STATE + "\","
 			+ " \"relayState\": \"" + RELAY_STATE + "\","
 			+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
 			+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
+			+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
 			+ " \"sigAlg\": \"" + SIG_ALG + "\","
 			+ " \"sigAlg\": \"" + SIG_ALG + "\","
 			+ " \"signature\": \"" + SIGNATURE + "\""
 			+ " \"signature\": \"" + SIGNATURE + "\""
 			+ "}";
 			+ "}";
@@ -113,6 +115,7 @@ final class TestSaml2JsonPayloads {
 			+ " \"@class\": \"org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest\","
 			+ " \"@class\": \"org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest\","
 			+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
 			+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
 			+ " \"relayState\": \"" + RELAY_STATE + "\","
 			+ " \"relayState\": \"" + RELAY_STATE + "\","
+			+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
 			+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\""
 			+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\""
 			+ "}";
 			+ "}";
 	// @formatter:on
 	// @formatter:on
@@ -120,7 +123,6 @@ final class TestSaml2JsonPayloads {
 	static final String ID = "idValue";
 	static final String ID = "idValue";
 	static final String LOCATION = "locationValue";
 	static final String LOCATION = "locationValue";
 	static final String BINDNG = "REDIRECT";
 	static final String BINDNG = "REDIRECT";
-	static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
 	static final String ADDITIONAL_PARAM = "additionalParamValue";
 	static final String ADDITIONAL_PARAM = "additionalParamValue";
 
 
 	// @formatter:off
 	// @formatter:off
@@ -140,14 +142,17 @@ final class TestSaml2JsonPayloads {
 	// @formatter:on
 	// @formatter:on
 
 
 	static Saml2PostAuthenticationRequest createDefaultSaml2PostAuthenticationRequest() {
 	static Saml2PostAuthenticationRequest createDefaultSaml2PostAuthenticationRequest() {
-		return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full()
-				.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
-				.build()).samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
+		return Saml2PostAuthenticationRequest.withRelyingPartyRegistration(
+				TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID)
+						.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
+						.build())
+				.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
 	}
 	}
 
 
 	static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() {
 	static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() {
 		return Saml2RedirectAuthenticationRequest
 		return Saml2RedirectAuthenticationRequest
 				.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full()
 				.withRelyingPartyRegistration(TestRelyingPartyRegistrations.full()
+						.registrationId(RELYINGPARTY_REGISTRATION_ID)
 						.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
 						.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
 						.build())
 						.build())
 				.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build();
 				.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build();

+ 44 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2AuthenticationTokenConverterTests.java

@@ -44,8 +44,11 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
 
 @ExtendWith(MockitoExtension.class)
 @ExtendWith(MockitoExtension.class)
 public class Saml2AuthenticationTokenConverterTests {
 public class Saml2AuthenticationTokenConverterTests {
@@ -71,6 +74,21 @@ public class Saml2AuthenticationTokenConverterTests {
 				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
 				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
 	}
 	}
 
 
+	@Test
+	public void convertWhenSamlResponseWithRelyingPartyRegistrationResolver(
+			@Mock RelyingPartyRegistrationResolver resolver) {
+		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
+		given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
+				Saml2Utils.samlEncodeNotRfc2045("response".getBytes(StandardCharsets.UTF_8)));
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo("response");
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
+		verify(resolver).resolve(any(), isNull());
+	}
+
 	@Test
 	@Test
 	public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
 	public void convertWhenSamlResponseInvalidBase64ThenSaml2AuthenticationException() {
 		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
 		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
@@ -159,6 +177,8 @@ public class Saml2AuthenticationTokenConverterTests {
 		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
 		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
 				Saml2AuthenticationRequestRepository.class);
 				Saml2AuthenticationRequestRepository.class);
 		AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
 		AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		given(authenticationRequest.getRelyingPartyRegistrationId())
+				.willReturn(this.relyingPartyRegistration.getRegistrationId());
 		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
 		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(
 				this.relyingPartyRegistrationResolver);
 				this.relyingPartyRegistrationResolver);
 		converter.setAuthenticationRequestRepository(authenticationRequestRepository);
 		converter.setAuthenticationRequestRepository(authenticationRequestRepository);
@@ -176,6 +196,30 @@ public class Saml2AuthenticationTokenConverterTests {
 		assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
 		assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
 	}
 	}
 
 
+	@Test
+	public void convertWhenSavedAuthenticationRequestThenTokenWithRelyingPartyRegistrationResolver(
+			@Mock RelyingPartyRegistrationResolver resolver) {
+		Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> authenticationRequestRepository = mock(
+				Saml2AuthenticationRequestRepository.class);
+		AbstractSaml2AuthenticationRequest authenticationRequest = mock(AbstractSaml2AuthenticationRequest.class);
+		given(authenticationRequest.getRelyingPartyRegistrationId())
+				.willReturn(this.relyingPartyRegistration.getRegistrationId());
+		Saml2AuthenticationTokenConverter converter = new Saml2AuthenticationTokenConverter(resolver);
+		converter.setAuthenticationRequestRepository(authenticationRequestRepository);
+		given(resolver.resolve(any(HttpServletRequest.class), any())).willReturn(this.relyingPartyRegistration);
+		given(authenticationRequestRepository.loadAuthenticationRequest(any(HttpServletRequest.class)))
+				.willReturn(authenticationRequest);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter(Saml2ParameterNames.SAML_RESPONSE,
+				Saml2Utils.samlEncode("response".getBytes(StandardCharsets.UTF_8)));
+		Saml2AuthenticationToken token = converter.convert(request);
+		assertThat(token.getSaml2Response()).isEqualTo("response");
+		assertThat(token.getRelyingPartyRegistration().getRegistrationId())
+				.isEqualTo(this.relyingPartyRegistration.getRegistrationId());
+		assertThat(token.getAuthenticationRequest()).isEqualTo(authenticationRequest);
+		verify(resolver).resolve(any(), eq(this.relyingPartyRegistration.getRegistrationId()));
+	}
+
 	@Test
 	@Test
 	public void constructorWhenResolverIsNullThenIllegalArgument() {
 	public void constructorWhenResolverIsNullThenIllegalArgument() {
 		assertThatIllegalArgumentException()
 		assertThatIllegalArgumentException()