Răsfoiți Sursa

Make InMemory*ClientRegistrationRepository Consistent

The previous builders with the list argument were inconsistent with their 
respective builders of var args.
dperezcabrera 6 ani în urmă
părinte
comite
8014114225

+ 12 - 10
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java

@@ -18,16 +18,13 @@ package org.springframework.security.oauth2.client.registration;
 import org.springframework.util.Assert;
 
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.ConcurrentMap;
 import java.util.function.Function;
-import java.util.stream.Collector;
-
-import static java.util.stream.Collectors.collectingAndThen;
-import static java.util.stream.Collectors.toConcurrentMap;
+import java.util.stream.Collectors;
 
 /**
  * A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory.
@@ -39,6 +36,7 @@ import static java.util.stream.Collectors.toConcurrentMap;
  * @see ClientRegistration
  */
 public final class InMemoryClientRegistrationRepository implements ClientRegistrationRepository, Iterable<ClientRegistration> {
+
 	private final Map<String, ClientRegistration> registrations;
 
 	/**
@@ -47,7 +45,8 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr
 	 * @param registrations the client registration(s)
 	 */
 	public InMemoryClientRegistrationRepository(ClientRegistration... registrations) {
-		this(Arrays.asList(registrations));
+		Assert.notEmpty(registrations, "registrations cannot be empty");
+		this.registrations = createClientRegistrationIdToClientRegistration(Arrays.asList(registrations));
 	}
 
 	/**
@@ -57,10 +56,13 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr
 	 */
 	public InMemoryClientRegistrationRepository(List<ClientRegistration> registrations) {
 		Assert.notEmpty(registrations, "registrations cannot be empty");
-		Collector<ClientRegistration, ?, ConcurrentMap<String, ClientRegistration>> collector =
-			toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity());
-		this.registrations = registrations.stream()
-			.collect(collectingAndThen(collector, Collections::unmodifiableMap));
+		this.registrations = createClientRegistrationIdToClientRegistration(registrations);
+	}
+
+	private static Map<String, ClientRegistration> createClientRegistrationIdToClientRegistration(Collection<ClientRegistration> registrations) {
+		return Collections.unmodifiableMap(registrations.stream()
+				.peek(registration -> Assert.notNull(registration, "registrations cannot contain null values"))
+				.collect(Collectors.toMap(ClientRegistration::getRegistrationId, Function.identity())));
 	}
 
 	@Override

+ 5 - 19
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java

@@ -17,12 +17,6 @@ package org.springframework.security.oauth2.client.registration;
 
 import java.util.Iterator;
 import java.util.List;
-import java.util.Map;
-import java.util.function.Function;
-import java.util.stream.Collectors;
-
-import org.springframework.util.Assert;
-import org.springframework.util.ConcurrentReferenceHashMap;
 
 import reactor.core.publisher.Mono;
 
@@ -37,7 +31,7 @@ import reactor.core.publisher.Mono;
 public final class InMemoryReactiveClientRegistrationRepository
 		implements ReactiveClientRegistrationRepository, Iterable<ClientRegistration> {
 
-	private final Map<String, ClientRegistration> clientIdToClientRegistration;
+	private final InMemoryClientRegistrationRepository delegate;
 
 	/**
 	 * Constructs an {@code InMemoryReactiveClientRegistrationRepository} using the provided parameters.
@@ -45,12 +39,7 @@ public final class InMemoryReactiveClientRegistrationRepository
 	 * @param registrations the client registration(s)
 	 */
 	public InMemoryReactiveClientRegistrationRepository(ClientRegistration... registrations) {
-		Assert.notEmpty(registrations, "registrations cannot be empty");
-		this.clientIdToClientRegistration = new ConcurrentReferenceHashMap<>();
-		for (ClientRegistration registration : registrations) {
-			Assert.notNull(registration, "registrations cannot contain null values");
-			this.clientIdToClientRegistration.put(registration.getRegistrationId(), registration);
-		}
+		this.delegate = new InMemoryClientRegistrationRepository(registrations);
 	}
 
 	/**
@@ -59,15 +48,12 @@ public final class InMemoryReactiveClientRegistrationRepository
 	 * @param registrations the client registration(s)
 	 */
 	public InMemoryReactiveClientRegistrationRepository(List<ClientRegistration> registrations) {
-		Assert.notEmpty(registrations, "registrations cannot be null or empty");
-		this.clientIdToClientRegistration = registrations.stream()
-				.collect(Collectors.toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity()));
+		this.delegate = new InMemoryClientRegistrationRepository(registrations);
 	}
 
-
 	@Override
 	public Mono<ClientRegistration> findByRegistrationId(String registrationId) {
-		return Mono.justOrEmpty(this.clientIdToClientRegistration.get(registrationId));
+		return Mono.justOrEmpty(this.delegate.findByRegistrationId(registrationId));
 	}
 
 	/**
@@ -77,6 +63,6 @@ public final class InMemoryReactiveClientRegistrationRepository
 	 */
 	@Override
 	public Iterator<ClientRegistration> iterator() {
-		return this.clientIdToClientRegistration.values().iterator();
+		return delegate.iterator();
 	}
 }

+ 6 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java

@@ -35,6 +35,12 @@ public class InMemoryClientRegistrationRepositoryTests {
 
 	private InMemoryClientRegistrationRepository clients = new InMemoryClientRegistrationRepository(this.registration);
 
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorVarArgsListClientRegistrationWhenNullThenIllegalArgumentException() {
+		ClientRegistration nullRegistration = null;
+		new InMemoryClientRegistrationRepository(nullRegistration);
+	}
+
 	@Test(expected = IllegalArgumentException.class)
 	public void constructorListClientRegistrationWhenNullThenIllegalArgumentException() {
 		List<ClientRegistration> registrations = null;

+ 9 - 2
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java

@@ -19,6 +19,7 @@ package org.springframework.security.oauth2.client.registration;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
+import java.util.Arrays;
 import java.util.List;
 
 import org.junit.Before;
@@ -61,10 +62,16 @@ public class InMemoryReactiveClientRegistrationRepositoryTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test
+	public void constructorWhenClientRegistrationListHasNullThenIllegalArgumentException() {
+		List<ClientRegistration> registrations = Arrays.asList(null, registration);
+		assertThatThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registrations))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
 	@Test
 	public void constructorWhenClientRegistrationIsNullThenIllegalArgumentException() {
-		ClientRegistration registration = null;
-		assertThatThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registration))
+		assertThatThrownBy(() -> new InMemoryReactiveClientRegistrationRepository(registration, null))
 				.isInstanceOf(IllegalArgumentException.class);
 	}