Przeglądaj źródła

Add CacheSaml2AuthenticationRequestRepository

Closes gh-14793
Josh Cummings 4 miesięcy temu
rodzic
commit
a283700ef8

+ 40 - 0
docs/modules/ROOT/pages/servlet/saml2/login/authentication-requests.adoc

@@ -83,6 +83,46 @@ open fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository
 ----
 ======
 
+=== Caching the `<saml2:AuthnRequest>` by the Relay State
+
+If you don't want to use the session to store the `<saml2:AuthnRequest>`, you can also store it in a distributed cache.
+This can be helpful if you are trying to use `SameSite=Strict` and are losing the authentication request in the redirect from the Identity Provider.
+
+[NOTE]
+=====
+It's important to remember that there are security benefits to storing it in the session.
+One such benefit is the natural login fixation defense it provides.
+For example, if an application looks the authentication request up from the session, then even if an attacker provides their own SAML response to a victim, the login will fail.
+
+On the other hand, if we trust the InResponseTo or RelayState to retrieve the authentication request, then there's no way to know if the SAML response was requested by that handshake.
+=====
+
+To help with this, Spring Security has `CacheSaml2AuthenticationRequestRepository`, which you can publish as a bean for the filter chain to pick up:
+
+[tabs]
+======
+Java::
++
+[source,java,role="primary"]
+----
+@Bean
+Saml2AuthenticationRequestRepository<?> authenticationRequestRepository() {
+	return new CacheSaml2AuthenticationRequestRepository();
+}
+----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Bean
+fun authenticationRequestRepository(): Saml2AuthenticationRequestRepository<*> {
+    return CacheSaml2AuthenticationRequestRepository()
+}
+----
+======
+
+
 [[servlet-saml2login-sp-initiated-factory-signing]]
 == Changing How the `<saml2:AuthnRequest>` Gets Sent
 

+ 84 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepository.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+
+import org.springframework.cache.Cache;
+import org.springframework.cache.concurrent.ConcurrentMapCache;
+import org.springframework.security.saml2.core.Saml2ParameterNames;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+import org.springframework.util.Assert;
+
+/**
+ * A cache-based {@link Saml2AuthenticationRequestRepository}. This can be handy when you
+ * are dropping requests due to using SameSite=Strict and the previous session is lost.
+ *
+ * <p>
+ * On the other hand, this presents a tradeoff where the application can only tell that
+ * the given authentication request was created by this application, but cannot guarantee
+ * that it was for the user trying to log in. Please see the reference for details.
+ *
+ * @author Josh Cummings
+ * @since 6.5
+ */
+public final class CacheSaml2AuthenticationRequestRepository
+		implements Saml2AuthenticationRequestRepository<AbstractSaml2AuthenticationRequest> {
+
+	private Cache cache = new ConcurrentMapCache("authentication-requests");
+
+	@Override
+	public AbstractSaml2AuthenticationRequest loadAuthenticationRequest(HttpServletRequest request) {
+		String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
+		Assert.notNull(relayState, "relayState must not be null");
+		return this.cache.get(relayState, AbstractSaml2AuthenticationRequest.class);
+	}
+
+	@Override
+	public void saveAuthenticationRequest(AbstractSaml2AuthenticationRequest authenticationRequest,
+			HttpServletRequest request, HttpServletResponse response) {
+		String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
+		Assert.notNull(relayState, "relayState must not be null");
+		this.cache.put(relayState, authenticationRequest);
+	}
+
+	@Override
+	public AbstractSaml2AuthenticationRequest removeAuthenticationRequest(HttpServletRequest request,
+			HttpServletResponse response) {
+		String relayState = request.getParameter(Saml2ParameterNames.RELAY_STATE);
+		Assert.notNull(relayState, "relayState must not be null");
+		AbstractSaml2AuthenticationRequest authenticationRequest = this.cache.get(relayState,
+				AbstractSaml2AuthenticationRequest.class);
+		if (authenticationRequest == null) {
+			return null;
+		}
+		this.cache.evict(relayState);
+		return authenticationRequest;
+	}
+
+	/**
+	 * Use this {@link Cache} instance. The default is an in-memory cache, which means it
+	 * won't work in a clustered environment. Instead, replace it here with a distributed
+	 * cache.
+	 * @param cache the {@link Cache} instance to use
+	 */
+	public void setCache(Cache cache) {
+		this.cache = cache;
+	}
+
+}

+ 91 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/CacheSaml2AuthenticationRequestRepositoryTests.java

@@ -0,0 +1,91 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.cache.Cache;
+import org.springframework.cache.concurrent.ConcurrentMapCache;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.saml2.core.Saml2ParameterNames;
+import org.springframework.security.saml2.provider.service.authentication.AbstractSaml2AuthenticationRequest;
+import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
+import org.springframework.security.saml2.provider.service.authentication.TestSaml2PostAuthenticationRequests;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+
+/**
+ * Tests for {@link CacheSaml2AuthenticationRequestRepository}
+ */
+class CacheSaml2AuthenticationRequestRepositoryTests {
+
+	CacheSaml2AuthenticationRequestRepository repository = new CacheSaml2AuthenticationRequestRepository();
+
+	@Test
+	void loadAuthenticationRequestWhenCachedThenReturns() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter(Saml2ParameterNames.RELAY_STATE, "test");
+		Saml2PostAuthenticationRequest authenticationRequest = TestSaml2PostAuthenticationRequests.create();
+		this.repository.saveAuthenticationRequest(authenticationRequest, request, null);
+		assertThat(this.repository.loadAuthenticationRequest(request)).isEqualTo(authenticationRequest);
+		this.repository.removeAuthenticationRequest(request, null);
+		assertThat(this.repository.loadAuthenticationRequest(request)).isNull();
+	}
+
+	@Test
+	void loadAuthenticationRequestWhenNoRelayStateThenException() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		assertThatExceptionOfType(IllegalArgumentException.class)
+			.isThrownBy(() -> this.repository.loadAuthenticationRequest(request));
+	}
+
+	@Test
+	void saveAuthenticationRequestWhenNoRelayStateThenException() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		assertThatExceptionOfType(IllegalArgumentException.class)
+			.isThrownBy(() -> this.repository.saveAuthenticationRequest(null, request, null));
+	}
+
+	@Test
+	void removeAuthenticationRequestWhenNoRelayStateThenException() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		assertThatExceptionOfType(IllegalArgumentException.class)
+			.isThrownBy(() -> this.repository.removeAuthenticationRequest(request, null));
+	}
+
+	@Test
+	void repositoryWhenCustomCacheThenUses() {
+		CacheSaml2AuthenticationRequestRepository repository = new CacheSaml2AuthenticationRequestRepository();
+		Cache cache = spy(new ConcurrentMapCache("requests"));
+		repository.setCache(cache);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setParameter(Saml2ParameterNames.RELAY_STATE, "test");
+		Saml2PostAuthenticationRequest authenticationRequest = TestSaml2PostAuthenticationRequests.create();
+		repository.saveAuthenticationRequest(authenticationRequest, request, null);
+		verify(cache).put(eq("test"), any());
+		repository.loadAuthenticationRequest(request);
+		verify(cache).get("test", AbstractSaml2AuthenticationRequest.class);
+		repository.removeAuthenticationRequest(request, null);
+		verify(cache).evict("test");
+	}
+
+}