Browse Source

In-memory ClientRegistration Repo Duplicate Check

Fixes gh-7338
Josh Cummings 6 years ago
parent
commit
5e98b92273

+ 19 - 13
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepository.java

@@ -15,16 +15,17 @@
  */
 package org.springframework.security.oauth2.client.registration;
 
+import java.util.Arrays;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentHashMap;
-
-import org.springframework.util.Assert;
 
 import reactor.core.publisher.Mono;
 
+import org.springframework.util.Assert;
+
 /**
  * A Reactive {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) in-memory.
  *
@@ -45,12 +46,12 @@ public final class InMemoryReactiveClientRegistrationRepository
 	 * @param registrations the client registration(s)
 	 */
 	public InMemoryReactiveClientRegistrationRepository(ClientRegistration... registrations) {
-		Assert.notEmpty(registrations, "registrations cannot be empty");
-		this.clientIdToClientRegistration = new ConcurrentHashMap<>();
-		for (ClientRegistration registration : registrations) {
-			Assert.notNull(registration, "registrations cannot contain null values");
-			this.clientIdToClientRegistration.put(registration.getRegistrationId(), registration);
-		}
+		this(toList(registrations));
+	}
+
+	private static List<ClientRegistration> toList(ClientRegistration... registrations) {
+		Assert.notEmpty(registrations, "registrations cannot be null or empty");
+		return Arrays.asList(registrations);
 	}
 
 	/**
@@ -59,8 +60,7 @@ public final class InMemoryReactiveClientRegistrationRepository
 	 * @param registrations the client registration(s)
 	 */
 	public InMemoryReactiveClientRegistrationRepository(List<ClientRegistration> registrations) {
-		Assert.notEmpty(registrations, "registrations cannot be null or empty");
-		this.clientIdToClientRegistration = toConcurrentMap(registrations);
+		this.clientIdToClientRegistration = toUnmodifiableConcurrentMap(registrations);
 	}
 
 	@Override
@@ -78,11 +78,17 @@ public final class InMemoryReactiveClientRegistrationRepository
 		return this.clientIdToClientRegistration.values().iterator();
 	}
 
-	private ConcurrentHashMap<String, ClientRegistration> toConcurrentMap(List<ClientRegistration> registrations) {
+	private static Map<String, ClientRegistration> toUnmodifiableConcurrentMap(List<ClientRegistration> registrations) {
+		Assert.notEmpty(registrations, "registrations cannot be null or empty");
 		ConcurrentHashMap<String, ClientRegistration> result = new ConcurrentHashMap<>();
 		for (ClientRegistration registration : registrations) {
+			Assert.notNull(registration, "no registration can be null");
+			if (result.containsKey(registration.getRegistrationId())) {
+				throw new IllegalStateException(String.format("Duplicate key %s",
+						registration.getRegistrationId()));
+			}
 			result.put(registration.getRegistrationId(), registration);
 		}
-		return result;
+		return Collections.unmodifiableMap(result);
 	}
 }

+ 10 - 4
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryReactiveClientRegistrationRepositoryTests.java

@@ -16,16 +16,16 @@
 
 package org.springframework.security.oauth2.client.registration;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.Assertions.assertThatThrownBy;
-
+import java.util.Arrays;
 import java.util.List;
 
 import org.junit.Before;
 import org.junit.Test;
-
 import reactor.test.StepVerifier;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+
 /**
  * @author Rob Winch
  * @since 5.1
@@ -61,6 +61,12 @@ public class InMemoryReactiveClientRegistrationRepositoryTests {
 				.isInstanceOf(IllegalArgumentException.class);
 	}
 
+	@Test(expected = IllegalStateException.class)
+	public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() {
+		List<ClientRegistration> registrations = Arrays.asList(this.registration, this.registration);
+		new InMemoryReactiveClientRegistrationRepository(registrations);
+	}
+
 	@Test
 	public void constructorWhenClientRegistrationIsNullThenIllegalArgumentException() {
 		ClientRegistration registration = null;