Explorar el Código

Add RelayStateResolver

Co-authored-by: ghaege <ghaege@qaepps.de>

Closes gh-12538
Josh Cummings hace 2 años
padre
commit
c1c28375d6

+ 13 - 1
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolver.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import java.util.function.Consumer;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletRequest;
 import org.opensaml.saml.saml2.core.LogoutRequest;
 import org.opensaml.saml.saml2.core.LogoutRequest;
 
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
 import org.springframework.security.saml2.provider.service.authentication.logout.Saml2LogoutRequest;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
@@ -34,6 +35,7 @@ import org.springframework.util.Assert;
  * OpenSAML 4
  * OpenSAML 4
  *
  *
  * @author Josh Cummings
  * @author Josh Cummings
+ * @author Gerhard Haege
  * @since 5.6
  * @since 5.6
  */
  */
 public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestResolver {
 public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestResolver {
@@ -83,6 +85,16 @@ public final class OpenSaml4LogoutRequestResolver implements Saml2LogoutRequestR
 		this.clock = clock;
 		this.clock = clock;
 	}
 	}
 
 
+	/**
+	 * Use this {@link Converter} to compute the RelayState
+	 * @param relayStateResolver the {@link Converter} to use
+	 * @since 6.1
+	 */
+	public void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
+		Assert.notNull(relayStateResolver, "relayStateResolver cannot be null");
+		this.logoutRequestResolver.setRelayStateResolver(relayStateResolver);
+	}
+
 	public static final class LogoutRequestParameters {
 	public static final class LogoutRequestParameters {
 
 
 		private final HttpServletRequest request;
 		private final HttpServletRequest request;

+ 9 - 2
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSamlLogoutRequestResolver.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -38,6 +38,7 @@ import org.opensaml.saml.saml2.core.impl.NameIDBuilder;
 import org.opensaml.saml.saml2.core.impl.SessionIndexBuilder;
 import org.opensaml.saml.saml2.core.impl.SessionIndexBuilder;
 import org.w3c.dom.Element;
 import org.w3c.dom.Element;
 
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.Saml2Exception;
 import org.springframework.security.saml2.core.OpenSamlInitializationService;
 import org.springframework.security.saml2.core.OpenSamlInitializationService;
@@ -74,6 +75,8 @@ final class OpenSamlLogoutRequestResolver {
 
 
 	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
 	private final RelyingPartyRegistrationResolver relyingPartyRegistrationResolver;
 
 
+	private Converter<HttpServletRequest, String> relayStateResolver = (request) -> UUID.randomUUID().toString();
+
 	/**
 	/**
 	 * Construct a {@link OpenSamlLogoutRequestResolver}
 	 * Construct a {@link OpenSamlLogoutRequestResolver}
 	 */
 	 */
@@ -95,6 +98,10 @@ final class OpenSamlLogoutRequestResolver {
 		Assert.notNull(this.sessionIndexBuilder, "sessionIndexBuilder must be configured in OpenSAML");
 		Assert.notNull(this.sessionIndexBuilder, "sessionIndexBuilder must be configured in OpenSAML");
 	}
 	}
 
 
+	void setRelayStateResolver(Converter<HttpServletRequest, String> relayStateResolver) {
+		this.relayStateResolver = relayStateResolver;
+	}
+
 	/**
 	/**
 	 * Prepare to create, sign, and serialize a SAML 2.0 Logout Request.
 	 * Prepare to create, sign, and serialize a SAML 2.0 Logout Request.
 	 *
 	 *
@@ -140,7 +147,7 @@ final class OpenSamlLogoutRequestResolver {
 		if (logoutRequest.getID() == null) {
 		if (logoutRequest.getID() == null) {
 			logoutRequest.setID("LR" + UUID.randomUUID());
 			logoutRequest.setID("LR" + UUID.randomUUID());
 		}
 		}
-		String relayState = UUID.randomUUID().toString();
+		String relayState = this.relayStateResolver.convert(request);
 		Saml2LogoutRequest.Builder result = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
 		Saml2LogoutRequest.Builder result = Saml2LogoutRequest.withRelyingPartyRegistration(registration)
 				.id(logoutRequest.getID());
 				.id(logoutRequest.getID());
 		if (registration.getAssertingPartyDetails().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) {
 		if (registration.getAssertingPartyDetails().getSingleLogoutServiceBinding() == Saml2MessageBinding.POST) {

+ 17 - 1
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/authentication/logout/OpenSaml4LogoutRequestResolverTests.java

@@ -1,5 +1,5 @@
 /*
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * you may not use this file except in compliance with the License.
@@ -16,9 +16,11 @@
 
 
 package org.springframework.security.saml2.provider.service.web.authentication.logout;
 package org.springframework.security.saml2.provider.service.web.authentication.logout;
 
 
+import jakarta.servlet.http.HttpServletRequest;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
 
 
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
@@ -32,6 +34,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
 
 /**
 /**
  * Tests for {@link OpenSaml4LogoutRequestResolver}
  * Tests for {@link OpenSaml4LogoutRequestResolver}
@@ -67,6 +70,19 @@ public class OpenSaml4LogoutRequestResolverTests {
 				.isThrownBy(() -> this.logoutRequestResolver.setParametersConsumer(null));
 				.isThrownBy(() -> this.logoutRequestResolver.setParametersConsumer(null));
 	}
 	}
 
 
+	@Test
+	public void resolveWhenCustomRelayStateThenUses() {
+		given(this.registrationResolver.resolve(any(), any())).willReturn(this.registration);
+		Converter<HttpServletRequest, String> relayState = mock(Converter.class);
+		given(relayState.convert(any())).willReturn("any-state");
+		this.logoutRequestResolver.setRelayStateResolver(relayState);
+
+		Saml2LogoutRequest logoutRequest = this.logoutRequestResolver.resolve(givenRequest(), givenAuthentication());
+
+		assertThat(logoutRequest.getRelayState()).isEqualTo("any-state");
+		verify(relayState).convert(any());
+	}
+
 	private static Authentication givenAuthentication() {
 	private static Authentication givenAuthentication() {
 		return new TestingAuthenticationToken("user", "password");
 		return new TestingAuthenticationToken("user", "password");
 	}
 	}