Browse Source

Add CachingRelyingPartyRegistrationRepository

Closes gh-15341
Josh Cummings 1 year ago
parent
commit
7b39800606

+ 51 - 0
docs/modules/ROOT/pages/servlet/saml2/login/overview.adoc

@@ -588,6 +588,57 @@ class MyCustomSecurityConfiguration {
 A relying party can be multi-tenant by registering more than one relying party in the `RelyingPartyRegistrationRepository`.
 ====
 
+[[servlet-saml2login-relyingpartyregistrationrepository-caching]]
+If you want your metadata to be refreshable on a periodic basis, you can wrap your repository in `CachingRelyingPartyRegistrationRepository` like so:
+
+.Caching Relying Party Registration Repository
+[tabs]
+======
+Java::
++
+[source,java,role="primary"]
+----
+@Configuration
+@EnableWebSecurity
+public class MyCustomSecurityConfiguration {
+    @Bean
+    public RelyingPartyRegistrationRepository registrations(CacheManager cacheManager) {
+		Supplier<IterableRelyingPartyRegistrationRepository> delegate = () ->
+            new InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations
+                .fromMetadataLocation("https://idp.example.org/ap/metadata")
+                .registrationId("ap").build());
+		CachingRelyingPartyRegistrationRepository registrations =
+            new CachingRelyingPartyRegistrationRepository(delegate);
+		registrations.setCache(cacheManager.getCache("my-cache-name"));
+        return registrations;
+    }
+}
+----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Configuration
+@EnableWebSecurity
+class MyCustomSecurityConfiguration  {
+    @Bean
+    fun registrations(cacheManager: CacheManager): RelyingPartyRegistrationRepository {
+        val delegate = Supplier<IterableRelyingPartyRegistrationRepository> {
+             InMemoryRelyingPartyRegistrationRepository(RelyingPartyRegistrations
+                .fromMetadataLocation("https://idp.example.org/ap/metadata")
+                .registrationId("ap").build())
+        }
+        val registrations = CachingRelyingPartyRegistrationRepository(delegate)
+        registrations.setCache(cacheManager.getCache("my-cache-name"))
+        return registrations
+    }
+}
+----
+======
+
+In this way, the set of `RelyingPartyRegistration`s will refresh based on {spring-framework-reference-url}integration/cache/store-configuration.html[the cache's eviction schedule].
+
 [[servlet-saml2login-relyingpartyregistration]]
 == RelyingPartyRegistration
 A {security-api-url}org/springframework/security/saml2/provider/service/registration/RelyingPartyRegistration.html[`RelyingPartyRegistration`]

+ 95 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepository.java

@@ -0,0 +1,95 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.registration;
+
+import java.util.Iterator;
+import java.util.Spliterator;
+import java.util.concurrent.Callable;
+import java.util.function.Consumer;
+
+import org.springframework.cache.Cache;
+import org.springframework.cache.concurrent.ConcurrentMapCache;
+import org.springframework.util.Assert;
+
+/**
+ * An {@link IterableRelyingPartyRegistrationRepository} that lazily queries and caches
+ * metadata from a backing {@link IterableRelyingPartyRegistrationRepository}. Delegates
+ * caching policies to Spring Cache.
+ *
+ * @author Josh Cummings
+ * @since 6.4
+ */
+public final class CachingRelyingPartyRegistrationRepository implements IterableRelyingPartyRegistrationRepository {
+
+	private final Callable<IterableRelyingPartyRegistrationRepository> registrationLoader;
+
+	private Cache cache = new ConcurrentMapCache("registrations");
+
+	public CachingRelyingPartyRegistrationRepository(Callable<IterableRelyingPartyRegistrationRepository> loader) {
+		this.registrationLoader = loader;
+	}
+
+	/**
+	 * {@inheritDoc}
+	 */
+	@Override
+	public Iterator<RelyingPartyRegistration> iterator() {
+		return registrations().iterator();
+	}
+
+	/**
+	 * {@inheritDoc}
+	 */
+	@Override
+	public RelyingPartyRegistration findByRegistrationId(String registrationId) {
+		return registrations().findByRegistrationId(registrationId);
+	}
+
+	@Override
+	public RelyingPartyRegistration findUniqueByAssertingPartyEntityId(String entityId) {
+		return registrations().findUniqueByAssertingPartyEntityId(entityId);
+	}
+
+	@Override
+	public void forEach(Consumer<? super RelyingPartyRegistration> action) {
+		registrations().forEach(action);
+	}
+
+	@Override
+	public Spliterator<RelyingPartyRegistration> spliterator() {
+		return registrations().spliterator();
+	}
+
+	private IterableRelyingPartyRegistrationRepository registrations() {
+		return this.cache.get("registrations", this.registrationLoader);
+	}
+
+	/**
+	 * Use this cache for the completed {@link RelyingPartyRegistration} instances.
+	 *
+	 * <p>
+	 * Defaults to {@link ConcurrentMapCache}, meaning that the registrations are cached
+	 * without expiry. To turn off the cache, use
+	 * {@link org.springframework.cache.support.NoOpCache}.
+	 * @param cache the {@link Cache} to use
+	 */
+	public void setCache(Cache cache) {
+		Assert.notNull(cache, "cache cannot be null");
+		this.cache = cache;
+	}
+
+}

+ 81 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/CachingRelyingPartyRegistrationRepositoryTests.java

@@ -0,0 +1,81 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.registration;
+
+import java.util.concurrent.Callable;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+
+import org.springframework.cache.Cache;
+
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+
+/**
+ * Tests for {@link CachingRelyingPartyRegistrationRepository}
+ */
+@ExtendWith(MockitoExtension.class)
+public class CachingRelyingPartyRegistrationRepositoryTests {
+
+	@Mock
+	Callable<Iterable<RelyingPartyRegistration>> callable;
+
+	@InjectMocks
+	CachingRelyingPartyRegistrationRepository registrations;
+
+	@Test
+	public void iteratorWhenResolvableThenPopulatesCache() throws Exception {
+		given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
+		this.registrations.iterator();
+		verify(this.callable).call();
+		this.registrations.iterator();
+		verifyNoMoreInteractions(this.callable);
+	}
+
+	@Test
+	public void iteratorWhenExceptionThenPropagates() throws Exception {
+		given(this.callable.call()).willThrow(IllegalStateException.class);
+		assertThatExceptionOfType(Cache.ValueRetrievalException.class).isThrownBy(this.registrations::iterator)
+			.withCauseInstanceOf(IllegalStateException.class);
+	}
+
+	@Test
+	public void findByRegistrationIdWhenResolvableThenPopulatesCache() throws Exception {
+		given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
+		this.registrations.findByRegistrationId("id");
+		verify(this.callable).call();
+		this.registrations.findByRegistrationId("id");
+		verifyNoMoreInteractions(this.callable);
+	}
+
+	@Test
+	public void findUniqueByAssertingPartyEntityIdWhenResolvableThenPopulatesCache() throws Exception {
+		given(this.callable.call()).willReturn(mock(IterableRelyingPartyRegistrationRepository.class));
+		this.registrations.findUniqueByAssertingPartyEntityId("id");
+		verify(this.callable).call();
+		this.registrations.findUniqueByAssertingPartyEntityId("id");
+		verifyNoMoreInteractions(this.callable);
+	}
+
+}