Explorar o código

Add ID to Saml2 Post and Redirect Requests

Closes gh-11468
Scott Shidlovsky %!s(int64=3) %!d(string=hai) anos
pai
achega
947445fcc5

+ 2 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixin.java

@@ -48,7 +48,8 @@ class Saml2PostAuthenticationRequestMixin {
 	Saml2PostAuthenticationRequestMixin(@JsonProperty("samlRequest") String samlRequest,
 			@JsonProperty("relayState") String relayState,
 			@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
-			@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
+			@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId,
+			@JsonProperty("id") String id) {
 	}
 
 }

+ 2 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/jackson2/Saml2RedirectAuthenticationRequestMixin.java

@@ -49,7 +49,8 @@ class Saml2RedirectAuthenticationRequestMixin {
 			@JsonProperty("sigAlg") String sigAlg, @JsonProperty("signature") String signature,
 			@JsonProperty("relayState") String relayState,
 			@JsonProperty("authenticationRequestUri") String authenticationRequestUri,
-			@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId) {
+			@JsonProperty("relyingPartyRegistrationId") String relyingPartyRegistrationId,
+			@JsonProperty("id") String id) {
 	}
 
 }

+ 30 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/AbstractSaml2AuthenticationRequest.java

@@ -49,6 +49,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 
 	private final String relyingPartyRegistrationId;
 
+	private final String id;
+
 	/**
 	 * Mandatory constructor for the {@link AbstractSaml2AuthenticationRequest}
 	 * @param samlRequest - the SAMLRequest XML data, SAML encoded, cannot be empty or
@@ -58,15 +60,18 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 	 * send the XML message, cannot be empty or null
 	 * @param relyingPartyRegistrationId the registration id of the relying party, may be
 	 * null
+	 * @param id This is the unique id used in the {@link #samlRequest}, cannot be empty
+	 * or null
 	 */
 	AbstractSaml2AuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
-			String relyingPartyRegistrationId) {
+			String relyingPartyRegistrationId, String id) {
 		Assert.hasText(samlRequest, "samlRequest cannot be null or empty");
 		Assert.hasText(authenticationRequestUri, "authenticationRequestUri cannot be null or empty");
 		this.authenticationRequestUri = authenticationRequestUri;
 		this.samlRequest = samlRequest;
 		this.relayState = relayState;
 		this.relyingPartyRegistrationId = relyingPartyRegistrationId;
+		this.id = id;
 	}
 
 	/**
@@ -106,6 +111,15 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 		return this.relyingPartyRegistrationId;
 	}
 
+	/**
+	 * The unique identifier for this Authentication Request
+	 * @return the Authentication Request identifier
+	 * @since 5.8
+	 */
+	public String getId() {
+		return this.id;
+	}
+
 	/**
 	 * Returns the binding this AuthNRequest will be sent and encoded with. If
 	 * {@link Saml2MessageBinding#REDIRECT} is used, the DEFLATE encoding will be
@@ -127,6 +141,8 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 
 		String relyingPartyRegistrationId;
 
+		String id;
+
 		/**
 		 * @deprecated Use {@link #Builder(RelyingPartyRegistration)} instead
 		 */
@@ -184,6 +200,19 @@ public abstract class AbstractSaml2AuthenticationRequest implements Serializable
 			return _this();
 		}
 
+		/**
+		 * This is the unique id used in the {@link #samlRequest}
+		 * @param id the SAML2 request id
+		 * @return the {@link AbstractSaml2AuthenticationRequest.Builder} for further
+		 * configurations
+		 * @since 5.8
+		 */
+		public T id(String id) {
+			Assert.notNull(id, "id cannot be null");
+			this.id = id;
+			return _this();
+		}
+
 	}
 
 }

+ 3 - 3
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2PostAuthenticationRequest.java

@@ -31,8 +31,8 @@ import org.springframework.security.saml2.provider.service.registration.Saml2Mes
 public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationRequest {
 
 	Saml2PostAuthenticationRequest(String samlRequest, String relayState, String authenticationRequestUri,
-			String relyingPartyRegistrationId) {
-		super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
+			String relyingPartyRegistrationId, String id) {
+		super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
 	}
 
 	/**
@@ -69,7 +69,7 @@ public class Saml2PostAuthenticationRequest extends AbstractSaml2AuthenticationR
 		 */
 		public Saml2PostAuthenticationRequest build() {
 			return new Saml2PostAuthenticationRequest(this.samlRequest, this.relayState, this.authenticationRequestUri,
-					this.relyingPartyRegistrationId);
+					this.relyingPartyRegistrationId, this.id);
 		}
 
 	}

+ 3 - 3
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/Saml2RedirectAuthenticationRequest.java

@@ -35,8 +35,8 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
 	private final String signature;
 
 	private Saml2RedirectAuthenticationRequest(String samlRequest, String sigAlg, String signature, String relayState,
-			String authenticationRequestUri, String relyingPartyRegistrationId) {
-		super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId);
+			String authenticationRequestUri, String relyingPartyRegistrationId, String id) {
+		super(samlRequest, relayState, authenticationRequestUri, relyingPartyRegistrationId, id);
 		this.sigAlg = sigAlg;
 		this.signature = signature;
 	}
@@ -116,7 +116,7 @@ public final class Saml2RedirectAuthenticationRequest extends AbstractSaml2Authe
 		 */
 		public Saml2RedirectAuthenticationRequest build() {
 			return new Saml2RedirectAuthenticationRequest(this.samlRequest, this.sigAlg, this.signature,
-					this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId);
+					this.relayState, this.authenticationRequestUri, this.relyingPartyRegistrationId, this.id);
 		}
 
 	}

+ 3 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/OpenSamlAuthenticationRequestResolver.java

@@ -142,13 +142,14 @@ class OpenSamlAuthenticationRequestResolver {
 			String xml = serialize(authnRequest);
 			String encoded = Saml2Utils.samlEncode(xml.getBytes(StandardCharsets.UTF_8));
 			return (T) Saml2PostAuthenticationRequest.withRelyingPartyRegistration(registration).samlRequest(encoded)
-					.relayState(relayState).build();
+					.relayState(relayState).id(authnRequest.getID()).build();
 		}
 		else {
 			String xml = serialize(authnRequest);
 			String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
 			Saml2RedirectAuthenticationRequest.Builder builder = Saml2RedirectAuthenticationRequest
-					.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState);
+					.withRelyingPartyRegistration(registration).samlRequest(deflatedAndEncoded).relayState(relayState)
+					.id(authnRequest.getID());
 			if (registration.getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
 				Map<String, String> parameters = OpenSamlSigningUtils.sign(registration)
 						.param(Saml2ParameterNames.SAML_REQUEST, deflatedAndEncoded)

+ 19 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/Saml2PostAuthenticationRequestMixinTests.java

@@ -58,6 +58,7 @@ class Saml2PostAuthenticationRequestMixinTests {
 				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
 		assertThat(authRequest.getRelyingPartyRegistrationId())
 				.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
+		assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID);
 	}
 
 	@Test
@@ -73,6 +74,24 @@ class Saml2PostAuthenticationRequestMixinTests {
 		assertThat(authRequest.getAuthenticationRequestUri())
 				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
 		assertThat(authRequest.getRelyingPartyRegistrationId()).isNull();
+		assertThat(authRequest.getId()).isEqualTo(TestSaml2JsonPayloads.ID);
+	}
+
+	@Test
+	void shouldDeserializeWithNoId() throws Exception {
+		String json = TestSaml2JsonPayloads.DEFAULT_POST_AUTH_REQUEST_JSON
+				.replace(", \"id\": \"" + TestSaml2JsonPayloads.ID + "\"", "");
+
+		Saml2PostAuthenticationRequest authRequest = this.mapper.readValue(json, Saml2PostAuthenticationRequest.class);
+
+		assertThat(authRequest).isNotNull();
+		assertThat(authRequest.getSamlRequest()).isEqualTo(TestSaml2JsonPayloads.SAML_REQUEST);
+		assertThat(authRequest.getRelayState()).isEqualTo(TestSaml2JsonPayloads.RELAY_STATE);
+		assertThat(authRequest.getAuthenticationRequestUri())
+				.isEqualTo(TestSaml2JsonPayloads.AUTHENTICATION_REQUEST_URI);
+		assertThat(authRequest.getRelyingPartyRegistrationId())
+				.isEqualTo(TestSaml2JsonPayloads.RELYINGPARTY_REGISTRATION_ID);
+		assertThat(authRequest.getId()).isNull();
 	}
 
 }

+ 7 - 5
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/jackson2/TestSaml2JsonPayloads.java

@@ -97,6 +97,7 @@ final class TestSaml2JsonPayloads {
 	static final String RELYINGPARTY_REGISTRATION_ID = "registrationIdValue";
 	static final String SIG_ALG = "sigAlgValue";
 	static final String SIGNATURE = "signatureValue";
+	static final String ID = "idValue";
 
 	// @formatter:off
 	static final String DEFAULT_REDIRECT_AUTH_REQUEST_JSON = "{"
@@ -106,7 +107,8 @@ final class TestSaml2JsonPayloads {
 			+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
 			+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
 			+ " \"sigAlg\": \"" + SIG_ALG + "\","
-			+ " \"signature\": \"" + SIGNATURE + "\""
+			+ " \"signature\": \"" + SIGNATURE + "\","
+			+ " \"id\": \"" + ID + "\""
 			+ "}";
 	// @formatter:on
 
@@ -116,11 +118,11 @@ final class TestSaml2JsonPayloads {
 			+ " \"samlRequest\": \"" + SAML_REQUEST + "\","
 			+ " \"relayState\": \"" + RELAY_STATE + "\","
 			+ " \"relyingPartyRegistrationId\": \"" + RELYINGPARTY_REGISTRATION_ID + "\","
-			+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\""
+			+ " \"authenticationRequestUri\": \"" + AUTHENTICATION_REQUEST_URI + "\","
+			+ " \"id\": \"" + ID + "\""
 			+ "}";
 	// @formatter:on
 
-	static final String ID = "idValue";
 	static final String LOCATION = "locationValue";
 	static final String BINDNG = "REDIRECT";
 	static final String ADDITIONAL_PARAM = "additionalParamValue";
@@ -146,7 +148,7 @@ final class TestSaml2JsonPayloads {
 				TestRelyingPartyRegistrations.full().registrationId(RELYINGPARTY_REGISTRATION_ID)
 						.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
 						.build())
-				.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).build();
+				.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).id(ID).build();
 	}
 
 	static Saml2RedirectAuthenticationRequest createDefaultSaml2RedirectAuthenticationRequest() {
@@ -155,7 +157,7 @@ final class TestSaml2JsonPayloads {
 						.registrationId(RELYINGPARTY_REGISTRATION_ID)
 						.assertingPartyDetails((party) -> party.singleSignOnServiceLocation(AUTHENTICATION_REQUEST_URI))
 						.build())
-				.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).build();
+				.samlRequest(SAML_REQUEST).relayState(RELAY_STATE).sigAlg(SIG_ALG).signature(SIGNATURE).id(ID).build();
 	}
 
 	static Saml2LogoutRequest createDefaultSaml2LogoutRequest() {