Browse Source

Improve RegisteredClient model

Closes gh-221
Joe Grandja 4 years ago
parent
commit
afd5491ced

+ 6 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/InMemoryRegisteredClientRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,13 +15,14 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
-import org.springframework.util.Assert;
-
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
+import org.springframework.lang.Nullable;
+import org.springframework.util.Assert;
+
 /**
  * A {@link RegisteredClientRepository} that stores {@link RegisteredClient}(s) in-memory.
  *
@@ -74,12 +75,14 @@ public final class InMemoryRegisteredClientRepository implements RegisteredClien
 		this.clientIdRegistrationMap = clientIdRegistrationMapResult;
 	}
 
+	@Nullable
 	@Override
 	public RegisteredClient findById(String id) {
 		Assert.hasText(id, "id cannot be empty");
 		return this.idRegistrationMap.get(id);
 	}
 
+	@Nullable
 	@Override
 	public RegisteredClient findByClientId(String clientId) {
 		Assert.hasText(clientId, "clientId cannot be empty");

+ 60 - 28
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClient.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,22 +15,23 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
-import org.springframework.security.oauth2.server.authorization.Version;
-import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
-import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
-import org.springframework.util.Assert;
-import org.springframework.util.CollectionUtils;
-
 import java.io.Serializable;
 import java.net.URI;
 import java.net.URISyntaxException;
 import java.util.Collections;
-import java.util.LinkedHashSet;
+import java.util.HashSet;
+import java.util.Objects;
 import java.util.Set;
 import java.util.function.Consumer;
 
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.security.oauth2.server.authorization.Version;
+import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
+import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+
 /**
  * A representation of a client registration with an OAuth 2.0 Authorization Server.
  *
@@ -82,8 +83,7 @@ public class RegisteredClient implements Serializable {
 	}
 
 	/**
-	 * Returns the {@link ClientAuthenticationMethod authentication method(s)} used
-	 * when authenticating the client with the authorization server.
+	 * Returns the {@link ClientAuthenticationMethod authentication method(s)} that the client may use.
 	 *
 	 * @return the {@code Set} of {@link ClientAuthenticationMethod authentication method(s)}
 	 */
@@ -110,7 +110,7 @@ public class RegisteredClient implements Serializable {
 	}
 
 	/**
-	 * Returns the scope(s) used by the client.
+	 * Returns the scope(s) that the client may use.
 	 *
 	 * @return the {@code Set} of scope(s)
 	 */
@@ -136,15 +136,44 @@ public class RegisteredClient implements Serializable {
 		return this.tokenSettings;
 	}
 
+	@Override
+	public boolean equals(Object obj) {
+		if (this == obj) {
+			return true;
+		}
+		if (obj == null || getClass() != obj.getClass()) {
+			return false;
+		}
+		RegisteredClient that = (RegisteredClient) obj;
+		return Objects.equals(this.id, that.id) &&
+				Objects.equals(this.clientId, that.clientId) &&
+				Objects.equals(this.clientSecret, that.clientSecret) &&
+				Objects.equals(this.clientAuthenticationMethods, that.clientAuthenticationMethods) &&
+				Objects.equals(this.authorizationGrantTypes, that.authorizationGrantTypes) &&
+				Objects.equals(this.redirectUris, that.redirectUris) &&
+				Objects.equals(this.scopes, that.scopes) &&
+				Objects.equals(this.clientSettings.settings(), that.getClientSettings().settings()) &&
+				Objects.equals(this.tokenSettings.settings(), that.tokenSettings.settings());
+	}
+
+	@Override
+	public int hashCode() {
+		return Objects.hash(this.id, this.clientId, this.clientSecret,
+				this.clientAuthenticationMethods, this.authorizationGrantTypes, this.redirectUris,
+				this.scopes, this.clientSettings.settings(), this.tokenSettings.settings());
+	}
+
 	@Override
 	public String toString() {
-		return "RegisteredClient{" +
+		return "RegisteredClient {" +
 				"id='" + this.id + '\'' +
 				", clientId='" + this.clientId + '\'' +
 				", clientAuthenticationMethods=" + this.clientAuthenticationMethods +
 				", authorizationGrantTypes=" + this.authorizationGrantTypes +
 				", redirectUris=" + this.redirectUris +
 				", scopes=" + this.scopes +
+				", clientSettings=" + this.clientSettings.settings() +
+				", tokenSettings=" + this.tokenSettings.settings() +
 				'}';
 	}
 
@@ -160,12 +189,12 @@ public class RegisteredClient implements Serializable {
 	}
 
 	/**
-	 * Returns a new {@link Builder}, initialized with the provided {@link RegisteredClient}.
+	 * Returns a new {@link Builder}, initialized with the values from the provided {@link RegisteredClient}.
 	 *
-	 * @param registeredClient the {@link RegisteredClient} to copy from
+	 * @param registeredClient the {@link RegisteredClient} used for initializing the {@link Builder}
 	 * @return the {@link Builder}
 	 */
-	public static Builder withRegisteredClient(RegisteredClient registeredClient) {
+	public static Builder from(RegisteredClient registeredClient) {
 		Assert.notNull(registeredClient, "registeredClient cannot be null");
 		return new Builder(registeredClient);
 	}
@@ -178,10 +207,10 @@ public class RegisteredClient implements Serializable {
 		private String id;
 		private String clientId;
 		private String clientSecret;
-		private Set<ClientAuthenticationMethod> clientAuthenticationMethods = new LinkedHashSet<>();
-		private Set<AuthorizationGrantType> authorizationGrantTypes = new LinkedHashSet<>();
-		private Set<String> redirectUris = new LinkedHashSet<>();
-		private Set<String> scopes = new LinkedHashSet<>();
+		private Set<ClientAuthenticationMethod> clientAuthenticationMethods = new HashSet<>();
+		private Set<AuthorizationGrantType> authorizationGrantTypes = new HashSet<>();
+		private Set<String> redirectUris = new HashSet<>();
+		private Set<String> scopes = new HashSet<>();
 		private ClientSettings clientSettings = new ClientSettings();
 		private TokenSettings tokenSettings = new TokenSettings();
 
@@ -385,13 +414,16 @@ public class RegisteredClient implements Serializable {
 			registeredClient.id = this.id;
 			registeredClient.clientId = this.clientId;
 			registeredClient.clientSecret = this.clientSecret;
-			registeredClient.clientAuthenticationMethods =
-					Collections.unmodifiableSet(this.clientAuthenticationMethods);
-			registeredClient.authorizationGrantTypes = Collections.unmodifiableSet(this.authorizationGrantTypes);
-			registeredClient.redirectUris = Collections.unmodifiableSet(this.redirectUris);
-			registeredClient.scopes = Collections.unmodifiableSet(this.scopes);
-			registeredClient.clientSettings = this.clientSettings;
-			registeredClient.tokenSettings = this.tokenSettings;
+			registeredClient.clientAuthenticationMethods = Collections.unmodifiableSet(
+					new HashSet<>(this.clientAuthenticationMethods));
+			registeredClient.authorizationGrantTypes = Collections.unmodifiableSet(
+					new HashSet<>(this.authorizationGrantTypes));
+			registeredClient.redirectUris = Collections.unmodifiableSet(
+					new HashSet<>(this.redirectUris));
+			registeredClient.scopes = Collections.unmodifiableSet(
+					new HashSet<>(this.scopes));
+			registeredClient.clientSettings = new ClientSettings(this.clientSettings.settings());
+			registeredClient.tokenSettings = new TokenSettings(this.tokenSettings.settings());
 
 			return registeredClient;
 		}

+ 9 - 3
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
+import org.springframework.lang.Nullable;
+
 /**
  * A repository for OAuth 2.0 {@link RegisteredClient}(s).
  *
@@ -26,19 +28,23 @@ package org.springframework.security.oauth2.server.authorization.client;
 public interface RegisteredClientRepository {
 
 	/**
-	 * Returns the registered client identified by the provided {@code id}, or {@code null} if not found.
+	 * Returns the registered client identified by the provided {@code id},
+	 * or {@code null} if not found.
 	 *
 	 * @param id the registration identifier
 	 * @return the {@link RegisteredClient} if found, otherwise {@code null}
 	 */
+	@Nullable
 	RegisteredClient findById(String id);
 
 	/**
-	 * Returns the registered client identified by the provided {@code clientId}, or {@code null} if not found.
+	 * Returns the registered client identified by the provided {@code clientId},
+	 * or {@code null} if not found.
 	 *
 	 * @param clientId the client identifier
 	 * @return the {@link RegisteredClient} if found, otherwise {@code null}
 	 */
+	@Nullable
 	RegisteredClient findByClientId(String clientId);
 
 }

+ 12 - 11
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/RegisteredClientTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2020 the original author or authors.
+ * Copyright 2020-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,15 +15,16 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
-import org.junit.Test;
-import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-
 import java.util.Collections;
 import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import org.junit.Test;
+
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
 
@@ -231,7 +232,7 @@ public class RegisteredClientTests {
 				.build();
 
 		assertThat(registration.getAuthorizationGrantTypes())
-				.containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS);
+				.containsExactlyInAnyOrder(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS);
 	}
 
 	@Test
@@ -249,7 +250,7 @@ public class RegisteredClientTests {
 				.build();
 
 		assertThat(registration.getAuthorizationGrantTypes())
-				.containsExactly(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS);
+				.containsExactlyInAnyOrder(AuthorizationGrantType.AUTHORIZATION_CODE, AuthorizationGrantType.CLIENT_CREDENTIALS);
 	}
 
 	@Test
@@ -280,7 +281,7 @@ public class RegisteredClientTests {
 				.build();
 
 		assertThat(registration.getClientAuthenticationMethods())
-				.containsExactly(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST);
+				.containsExactlyInAnyOrder(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST);
 	}
 
 	@Test
@@ -298,7 +299,7 @@ public class RegisteredClientTests {
 				.build();
 
 		assertThat(registration.getClientAuthenticationMethods())
-				.containsExactly(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST);
+				.containsExactlyInAnyOrder(ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST);
 	}
 
 	@Test
@@ -320,7 +321,7 @@ public class RegisteredClientTests {
 	@Test
 	public void buildWhenRegisteredClientProvidedThenMakesACopy() {
 		RegisteredClient registration = TestRegisteredClients.registeredClient().build();
-		RegisteredClient updated = RegisteredClient.withRegisteredClient(registration).build();
+		RegisteredClient updated = RegisteredClient.from(registration).build();
 
 		assertThat(registration.getId()).isEqualTo(updated.getId());
 		assertThat(registration.getClientId()).isEqualTo(updated.getClientId());
@@ -345,7 +346,7 @@ public class RegisteredClientTests {
 		String newSecret = "new-secret";
 		String newScope = "new-scope";
 		String newRedirectUri = "https://another-redirect-uri.com";
-		RegisteredClient updated = RegisteredClient.withRegisteredClient(registration)
+		RegisteredClient updated = RegisteredClient.from(registration)
 				.clientSecret(newSecret)
 				.scopes(scopes -> {
 					scopes.clear();