浏览代码

Allow InMemoryOAuth2AuthorizedClientService to be constructed with a Map

Fixes gh-5994
Vedran Pavic 6 年之前
父节点
当前提交
9432670f1d

+ 17 - 11
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientService.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,7 +20,6 @@ import org.springframework.security.oauth2.client.registration.ClientRegistratio
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
 import org.springframework.util.Assert;
 
-import java.util.Base64;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -29,6 +28,7 @@ import java.util.concurrent.ConcurrentHashMap;
  * {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
  *
  * @author Joe Grandja
+ * @author Vedran Pavic
  * @since 5.0
  * @see OAuth2AuthorizedClientService
  * @see OAuth2AuthorizedClient
@@ -36,8 +36,8 @@ import java.util.concurrent.ConcurrentHashMap;
  * @see Authentication
  */
 public final class InMemoryOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {
-	private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
 	private final ClientRegistrationRepository clientRegistrationRepository;
+	private Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
 
 	/**
 	 * Constructs an {@code InMemoryOAuth2AuthorizedClientService} using the provided parameters.
@@ -49,7 +49,17 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
 		this.clientRegistrationRepository = clientRegistrationRepository;
 	}
 
+	/**
+	 * Sets the map of authorized clients to use.
+	 * @param authorizedClients the map of authorized clients
+	 */
+	public void setAuthorizedClients(Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients) {
+		Assert.notNull(authorizedClients, "authorizedClients cannot be null");
+		this.authorizedClients = authorizedClients;
+	}
+
 	@Override
+	@SuppressWarnings("unchecked")
 	public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) {
 		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
 		Assert.hasText(principalName, "principalName cannot be empty");
@@ -57,15 +67,15 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
 		if (registration == null) {
 			return null;
 		}
-		return (T) this.authorizedClients.get(this.getIdentifier(registration, principalName));
+		return (T) this.authorizedClients.get(OAuth2AuthorizedClientId.create(registration, principalName));
 	}
 
 	@Override
 	public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
 		Assert.notNull(authorizedClient, "authorizedClient cannot be null");
 		Assert.notNull(principal, "principal cannot be null");
-		this.authorizedClients.put(this.getIdentifier(
-			authorizedClient.getClientRegistration(), principal.getName()), authorizedClient);
+		this.authorizedClients.put(OAuth2AuthorizedClientId.create(authorizedClient.getClientRegistration(),
+				principal.getName()), authorizedClient);
 	}
 
 	@Override
@@ -74,12 +84,8 @@ public final class InMemoryOAuth2AuthorizedClientService implements OAuth2Author
 		Assert.hasText(principalName, "principalName cannot be empty");
 		ClientRegistration registration = this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId);
 		if (registration != null) {
-			this.authorizedClients.remove(this.getIdentifier(registration, principalName));
+			this.authorizedClients.remove(OAuth2AuthorizedClientId.create(registration, principalName));
 		}
 	}
 
-	private String getIdentifier(ClientRegistration registration, String principalName) {
-		String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
-		return Base64.getEncoder().encodeToString(identifier.getBytes());
-	}
 }

+ 10 - 15
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/InMemoryReactiveOAuth2AuthorizedClientService.java

@@ -15,7 +15,6 @@
  */
 package org.springframework.security.oauth2.client;
 
-import java.util.Base64;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -31,6 +30,7 @@ import reactor.core.publisher.Mono;
  * {@link OAuth2AuthorizedClient Authorized Client(s)} in-memory.
  *
  * @author Rob Winch
+ * @author Vedran Pavic
  * @since 5.1
  * @see OAuth2AuthorizedClientService
  * @see OAuth2AuthorizedClient
@@ -38,7 +38,7 @@ import reactor.core.publisher.Mono;
  * @see Authentication
  */
 public final class InMemoryReactiveOAuth2AuthorizedClientService implements ReactiveOAuth2AuthorizedClientService {
-	private final Map<String, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();
+	private final Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = new ConcurrentHashMap<>();;
 	private final ReactiveClientRegistrationRepository clientRegistrationRepository;
 
 	/**
@@ -52,10 +52,12 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
 	}
 
 	@Override
+	@SuppressWarnings("unchecked")
 	public <T extends OAuth2AuthorizedClient> Mono<T> loadAuthorizedClient(String clientRegistrationId, String principalName) {
 		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
 		Assert.hasText(principalName, "principalName cannot be empty");
-		return (Mono<T>) getIdentifier(clientRegistrationId, principalName)
+		return (Mono<T>) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
+				.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
 				.flatMap(identifier -> Mono.justOrEmpty(this.authorizedClients.get(identifier)));
 	}
 
@@ -64,7 +66,8 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
 		Assert.notNull(authorizedClient, "authorizedClient cannot be null");
 		Assert.notNull(principal, "principal cannot be null");
 		return Mono.fromRunnable(() -> {
-			String identifier = this.getIdentifier(authorizedClient.getClientRegistration(), principal.getName());
+			OAuth2AuthorizedClientId identifier = OAuth2AuthorizedClientId.create(
+					authorizedClient.getClientRegistration(), principal.getName());
 			this.authorizedClients.put(identifier, authorizedClient);
 		});
 	}
@@ -73,18 +76,10 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
 	public Mono<Void> removeAuthorizedClient(String clientRegistrationId, String principalName) {
 		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
 		Assert.hasText(principalName, "principalName cannot be empty");
-		return this.getIdentifier(clientRegistrationId, principalName)
-				.doOnNext(identifier -> this.authorizedClients.remove(identifier))
-				.then(Mono.empty());
-	}
-
-	private Mono<String> getIdentifier(String clientRegistrationId, String principalName) {
 		return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
-				.map(registration -> getIdentifier(registration, principalName));
+				.map(clientRegistration -> OAuth2AuthorizedClientId.create(clientRegistration, principalName))
+				.doOnNext(this.authorizedClients::remove)
+				.then(Mono.empty());
 	}
 
-	private String getIdentifier(ClientRegistration registration, String principalName) {
-		String identifier = "[" + registration.getRegistrationId() + "][" + principalName + "]";
-		return Base64.getEncoder().encodeToString(identifier.getBytes());
-	}
 }

+ 77 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientId.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client;
+
+import java.io.Serializable;
+import java.util.Objects;
+
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.util.Assert;
+
+/**
+ * The identifier for {@link OAuth2AuthorizedClient}.
+ *
+ * @author Vedran Pavic
+ * @since 5.2
+ * @see OAuth2AuthorizedClient
+ * @see OAuth2AuthorizedClientService
+ */
+public final class OAuth2AuthorizedClientId implements Serializable {
+
+	private final String clientRegistrationId;
+
+	private final String principalName;
+
+	private OAuth2AuthorizedClientId(String clientRegistrationId, String principalName) {
+		Assert.notNull(clientRegistrationId, "clientRegistrationId cannot be null");
+		Assert.notNull(principalName, "principalName cannot be null");
+		this.clientRegistrationId = clientRegistrationId;
+		this.principalName = principalName;
+	}
+
+	/**
+	 * Factory method for creating new {@link OAuth2AuthorizedClientId} using
+	 * {@link ClientRegistration} and principal name.
+	 * @param clientRegistration the client registration
+	 * @param principalName the principal name
+	 * @return the new authorized client id
+	 */
+	public static OAuth2AuthorizedClientId create(ClientRegistration clientRegistration,
+			String principalName) {
+		return new OAuth2AuthorizedClientId(clientRegistration.getRegistrationId(),
+				principalName);
+	}
+
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj) {
+			return true;
+		}
+		if (obj == null || getClass() != obj.getClass()) {
+			return false;
+		}
+		OAuth2AuthorizedClientId that = (OAuth2AuthorizedClientId) obj;
+		return Objects.equals(this.clientRegistrationId, that.clientRegistrationId)
+				&& Objects.equals(this.principalName, that.principalName);
+	}
+
+	@Override
+	public int hashCode() {
+		return Objects.hash(this.clientRegistrationId, this.principalName);
+	}
+
+}

+ 33 - 1
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/InMemoryOAuth2AuthorizedClientServiceTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2018 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,7 +15,11 @@
  */
 package org.springframework.security.oauth2.client;
 
+import java.util.Collections;
+import java.util.Map;
+
 import org.junit.Test;
+
 import org.springframework.security.core.Authentication;
 import org.springframework.security.oauth2.client.registration.ClientRegistration;
 import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
@@ -24,6 +28,9 @@ import org.springframework.security.oauth2.client.registration.TestClientRegistr
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
 
@@ -31,6 +38,7 @@ import static org.mockito.Mockito.when;
  * Tests for {@link InMemoryOAuth2AuthorizedClientService}.
  *
  * @author Joe Grandja
+ * @author Vedran Pavic
  */
 public class InMemoryOAuth2AuthorizedClientServiceTests {
 	private String principalName1 = "principal-1";
@@ -57,6 +65,30 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
 		new InMemoryOAuth2AuthorizedClientService(null);
 	}
 
+	@Test
+	public void constructorWhenAuthorizedClientsIsNullThenIllegalArgumentException() {
+		assertThatExceptionOfType(IllegalArgumentException.class)
+				.isThrownBy(() -> this.authorizedClientService.setAuthorizedClients(null))
+				.withMessage("authorizedClients cannot be null");
+	}
+
+	@Test
+	public void constructorWhenAuthorizedClientsIsEmptyMapThenRepositoryUsingSuppliedAuthorizedClients() {
+		String registrationId = this.registration3.getRegistrationId();
+
+		Map<OAuth2AuthorizedClientId, OAuth2AuthorizedClient> authorizedClients = Collections.singletonMap(
+				OAuth2AuthorizedClientId.create(this.registration3, this.principalName1),
+				mock(OAuth2AuthorizedClient.class));
+		ClientRegistrationRepository clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+		given(clientRegistrationRepository.findByRegistrationId(eq(registrationId))).willReturn(this.registration3);
+
+		InMemoryOAuth2AuthorizedClientService authorizedClientService = new InMemoryOAuth2AuthorizedClientService(
+				this.clientRegistrationRepository);
+		authorizedClientService.setAuthorizedClients(authorizedClients);
+		assertThat((OAuth2AuthorizedClient) authorizedClientService.loadAuthorizedClient(
+				registrationId, this.principalName1)).isNotNull();
+	}
+
 	@Test(expected = IllegalArgumentException.class)
 	public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
 		this.authorizedClientService.loadAuthorizedClient(null, this.principalName1);

+ 94 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/OAuth2AuthorizedClientIdTests.java

@@ -0,0 +1,94 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.oauth2.client;
+
+import org.junit.Test;
+
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link OAuth2AuthorizedClientId}.
+ *
+ * @author Vedran Pavic
+ */
+public class OAuth2AuthorizedClientIdTests {
+
+	@Test
+	public void equalsWhenSameRegistrationIdAndPrincipalThenShouldReturnTrue() {
+		OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal");
+		OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal");
+		assertThat(id1.equals(id2)).isTrue();
+	}
+
+	@Test
+	public void equalsWhenDifferentRegistrationIdAndSamePrincipalThenShouldReturnFalse() {
+		OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
+				"test-principal");
+		OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
+				"test-principal");
+		assertThat(id1.equals(id2)).isFalse();
+	}
+
+	@Test
+	public void equalsWhenSameRegistrationIdAndDifferentPrincipalThenShouldReturnFalse() {
+		OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal1");
+		OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal2");
+		assertThat(id1.equals(id2)).isFalse();
+	}
+
+	@Test
+	public void hashCodeWhenSameRegistrationIdAndPrincipalThenShouldReturnSame() {
+		OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal");
+		OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal");
+		assertThat(id1.hashCode()).isEqualTo(id2.hashCode());
+	}
+
+	@Test
+	public void hashCodeWhenDifferentRegistrationIdAndSamePrincipalThenShouldNotReturnSame() {
+		OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client1"),
+				"test-principal");
+		OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client2"),
+				"test-principal");
+		assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
+	}
+
+	@Test
+	public void hashCodeWhenSameRegistrationIdAndDifferentPrincipalThenShouldNotReturnSame() {
+		OAuth2AuthorizedClientId id1 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal1");
+		OAuth2AuthorizedClientId id2 = OAuth2AuthorizedClientId.create(testClientRegistration("test-client"),
+				"test-principal2");
+		assertThat(id1.hashCode()).isNotEqualTo(id2.hashCode());
+	}
+
+	private static ClientRegistration testClientRegistration(String registrationId) {
+		return ClientRegistration.withRegistrationId(registrationId).clientId("id").clientSecret("secret")
+				.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
+				.redirectUriTemplate("{baseUrl}/{action}/oauth2/code/{registrationId}")
+				.authorizationUri("http://example.com/authorize").tokenUri("http://example.com/token").build();
+	}
+
+}