浏览代码

Polish InMemoryClientRegistrationRepository

- use Map.get
- Construct with stream()
- Add tests
- Remove unnecessary unmodifiableCollection (already unmodifiable)

Fixes gh-4745
Rob Winch 7 年之前
父节点
当前提交
8032baa296

+ 15 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java

@@ -82,6 +82,21 @@ public final class ClientRegistration {
 		return this.clientName;
 	}
 
+	@Override
+	public String toString() {
+		return "ClientRegistration{"
+			+ "registrationId='" + this.registrationId + '\''
+			+ ", clientId='" + this.clientId + '\''
+			+ ", clientSecret='" + this.clientSecret + '\''
+			+ ", clientAuthenticationMethod=" + this.clientAuthenticationMethod
+			+ ", authorizationGrantType=" + this.authorizationGrantType
+			+ ", redirectUri='" + this.redirectUri + '\''
+			+ ", scopes=" + this.scopes
+			+ ", providerDetails=" + this.providerDetails
+			+ ", clientName='" + this.clientName
+			+ '\'' + '}';
+	}
+
 	public class ProviderDetails {
 		private String authorizationUri;
 		private String tokenUri;

+ 13 - 15
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepository.java

@@ -21,7 +21,13 @@ import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.function.Function;
+import java.util.stream.Collector;
+
+import static java.util.stream.Collectors.collectingAndThen;
+import static java.util.stream.Collectors.toConcurrentMap;
+import static java.util.stream.Collectors.toMap;
 
 /**
  * A {@link ClientRegistrationRepository} that stores {@link ClientRegistration}(s) <i>in-memory</i>.
@@ -36,28 +42,20 @@ public final class InMemoryClientRegistrationRepository implements ClientRegistr
 
 	public InMemoryClientRegistrationRepository(List<ClientRegistration> registrations) {
 		Assert.notEmpty(registrations, "registrations cannot be empty");
-		Map<String, ClientRegistration> registrationsMap = new ConcurrentHashMap<>();
-		registrations.forEach(registration -> {
-			if (registrationsMap.containsKey(registration.getRegistrationId())) {
-				throw new IllegalArgumentException("ClientRegistration must be unique. Found duplicate registrationId: " +
-					registration.getRegistrationId());
-			}
-			registrationsMap.put(registration.getRegistrationId(), registration);
-		});
-		this.registrations = Collections.unmodifiableMap(registrationsMap);
+		Collector<ClientRegistration, ?, ConcurrentMap<String, ClientRegistration>> collector =
+			toConcurrentMap(ClientRegistration::getRegistrationId, Function.identity());
+		this.registrations = registrations.stream()
+			.collect(collectingAndThen(collector, Collections::unmodifiableMap));
 	}
 
 	@Override
 	public ClientRegistration findByRegistrationId(String registrationId) {
 		Assert.hasText(registrationId, "registrationId cannot be empty");
-		return this.registrations.values().stream()
-			.filter(registration -> registration.getRegistrationId().equals(registrationId))
-			.findFirst()
-			.orElse(null);
+		return this.registrations.get(registrationId);
 	}
 
 	@Override
 	public Iterator<ClientRegistration> iterator() {
-		return Collections.unmodifiableCollection(this.registrations.values()).iterator();
+		return this.registrations.values().iterator();
 	}
 }

+ 94 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/InMemoryClientRegistrationRepositoryTests.java

@@ -0,0 +1,94 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client.registration;
+
+import org.junit.Test;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class InMemoryClientRegistrationRepositoryTests {
+	private ClientRegistration registration = ClientRegistration.withRegistrationId("id")
+		.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
+		.authorizationUri("https://example.com/oauth2/authorize")
+		.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+		.clientId("client-id")
+		.clientName("client-name")
+		.clientSecret("client-secret")
+		.redirectUri("{scheme}://{serverName}:{serverPort}{contextPath}/login/oauth2/code/{registrationId}")
+		.scope("user")
+		.tokenUri("https://example.com/oauth/access_token")
+		.build();
+
+	private InMemoryClientRegistrationRepository clients = new InMemoryClientRegistrationRepository(
+		Arrays.asList(this.registration));
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorListClientRegistrationWhenNullThenIllegalArgumentException() {
+		List<ClientRegistration> registrations = null;
+		new InMemoryClientRegistrationRepository(registrations);
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void constructorListClientRegistrationWhenEmptyThenIllegalArgumentException() {
+		List<ClientRegistration> registrations = Collections.emptyList();
+		new InMemoryClientRegistrationRepository(registrations);
+	}
+
+	@Test(expected = IllegalStateException.class)
+	public void constructorListClientRegistrationWhenDuplicateIdThenIllegalArgumentException() {
+		List<ClientRegistration> registrations = Arrays.asList(this.registration, this.registration);
+		new InMemoryClientRegistrationRepository(registrations);
+	}
+
+	@Test
+	public void findByRegistrationIdWhenFoundThenFound() {
+		String id = this.registration.getRegistrationId();
+		assertThat(this.clients.findByRegistrationId(id)).isEqualTo(this.registration);
+	}
+
+	@Test
+	public void findByRegistrationIdWhenNotFoundThenNull() {
+		String id = this.registration.getRegistrationId() + "MISSING";
+		assertThat(this.clients.findByRegistrationId(id)).isNull();
+	}
+
+	@Test(expected = IllegalArgumentException.class)
+	public void findByRegistrationIdWhenNullIdThenIllegalArgumentException() {
+		String id = null;
+		assertThat(this.clients.findByRegistrationId(id));
+	}
+
+	@Test(expected = UnsupportedOperationException.class)
+	public void iteratorWhenRemoveThenThrowsUnsupportedOperationException() {
+		this.clients.iterator().remove();
+	}
+
+	@Test
+	public void iteratorWhenGetThenContainsAll() {
+		assertThat(this.clients.iterator()).containsOnly(this.registration);
+	}
+}