Răsfoiți Sursa

Add ClientRegistration.withClientRegistration

Fixes gh-7486
Josh Cummings 6 ani în urmă
părinte
comite
adf9769eed

+ 43 - 7
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/registration/ClientRegistration.java

@@ -15,22 +15,26 @@
  */
 package org.springframework.security.oauth2.client.registration;
 
-import org.springframework.security.core.SpringSecurityCoreVersion;
-import org.springframework.security.oauth2.core.AuthenticationMethod;
-import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-import org.springframework.util.Assert;
-import org.springframework.util.StringUtils;
-
 import java.io.Serializable;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 
+import org.springframework.security.core.SpringSecurityCoreVersion;
+import org.springframework.security.oauth2.core.AuthenticationMethod;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+import static java.util.Collections.EMPTY_MAP;
+
 /**
  * A representation of a client registration with an OAuth 2.0 or OpenID Connect 1.0 Provider.
  *
@@ -263,6 +267,17 @@ public final class ClientRegistration implements Serializable {
 		return new Builder(registrationId);
 	}
 
+	/**
+	 * Returns a new {@link Builder}, initialized with the provided {@link ClientRegistration}.
+	 *
+	 * @param clientRegistration the {@link ClientRegistration} to copy from
+	 * @return the {@link Builder}
+	 */
+	public static Builder withClientRegistration(ClientRegistration clientRegistration) {
+		Assert.notNull(clientRegistration, "clientRegistration cannot be null");
+		return new Builder(clientRegistration);
+	}
+
 	/**
 	 * A builder for {@link ClientRegistration}.
 	 */
@@ -288,6 +303,27 @@ public final class ClientRegistration implements Serializable {
 			this.registrationId = registrationId;
 		}
 
+		private Builder(ClientRegistration clientRegistration) {
+			this.registrationId = clientRegistration.registrationId;
+			this.clientId = clientRegistration.clientId;
+			this.clientSecret = clientRegistration.clientSecret;
+			this.clientAuthenticationMethod = clientRegistration.clientAuthenticationMethod;
+			this.authorizationGrantType = clientRegistration.authorizationGrantType;
+			this.redirectUriTemplate = clientRegistration.redirectUriTemplate;
+			this.scopes = clientRegistration.scopes == null ? null : new HashSet<>(clientRegistration.scopes);
+			this.authorizationUri = clientRegistration.providerDetails.authorizationUri;
+			this.tokenUri = clientRegistration.providerDetails.tokenUri;
+			this.userInfoUri = clientRegistration.providerDetails.userInfoEndpoint.uri;
+			this.userInfoAuthenticationMethod = clientRegistration.providerDetails.userInfoEndpoint.authenticationMethod;
+			this.userNameAttributeName = clientRegistration.providerDetails.userInfoEndpoint.userNameAttributeName;
+			this.jwkSetUri = clientRegistration.providerDetails.jwkSetUri;
+			Map<String, Object> configurationMetadata = clientRegistration.providerDetails.configurationMetadata;
+			if (configurationMetadata != EMPTY_MAP) {
+				this.configurationMetadata = new HashMap<>(configurationMetadata);
+			}
+			this.clientName = clientRegistration.clientName;
+		}
+
 		/**
 		 * Sets the registration id.
 		 *

+ 76 - 5
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/registration/ClientRegistrationTests.java

@@ -15,11 +15,6 @@
  */
 package org.springframework.security.oauth2.client.registration;
 
-import org.junit.Test;
-import org.springframework.security.oauth2.core.AuthenticationMethod;
-import org.springframework.security.oauth2.core.AuthorizationGrantType;
-import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -27,8 +22,16 @@ import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import org.junit.Test;
+
+import org.springframework.security.oauth2.core.AuthenticationMethod;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
+
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.springframework.security.oauth2.client.registration.ClientRegistration.withClientRegistration;
+import static org.springframework.security.oauth2.client.registration.TestClientRegistrations.clientRegistration;
 
 /**
  * Tests for {@link ClientRegistration}.
@@ -696,4 +699,72 @@ public class ClientRegistrationTests {
 		assertThat(registration.getProviderDetails().getTokenUri()).isEqualTo(TOKEN_URI);
 		assertThat(registration.getClientName()).isEqualTo(CLIENT_NAME);
 	}
+
+	@Test
+	public void buildWhenClientRegistrationProvidedThenMakesACopy() {
+		ClientRegistration clientRegistration = clientRegistration().build();
+		ClientRegistration updated = withClientRegistration(clientRegistration).build();
+		assertThat(clientRegistration.getScopes()).isEqualTo(updated.getScopes());
+		assertThat(clientRegistration.getScopes()).isNotSameAs(updated.getScopes());
+		assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata())
+				.isEqualTo(updated.getProviderDetails().getConfigurationMetadata());
+		assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata())
+				.isNotSameAs(updated.getProviderDetails().getConfigurationMetadata());
+	}
+
+	@Test
+	public void buildWhenClientRegistrationProvidedThenEachPropertyMatches() {
+		ClientRegistration clientRegistration = clientRegistration().build();
+		ClientRegistration updated = withClientRegistration(clientRegistration).build();
+		assertThat(clientRegistration.getRegistrationId()).isEqualTo(updated.getRegistrationId());
+		assertThat(clientRegistration.getClientId()).isEqualTo(updated.getClientId());
+		assertThat(clientRegistration.getClientSecret()).isEqualTo(updated.getClientSecret());
+		assertThat(clientRegistration.getClientAuthenticationMethod())
+				.isEqualTo(updated.getClientAuthenticationMethod());
+		assertThat(clientRegistration.getAuthorizationGrantType())
+				.isEqualTo(updated.getAuthorizationGrantType());
+		assertThat(clientRegistration.getRedirectUriTemplate())
+				.isEqualTo(updated.getRedirectUriTemplate());
+		assertThat(clientRegistration.getScopes()).isEqualTo(updated.getScopes());
+
+		ClientRegistration.ProviderDetails providerDetails = clientRegistration.getProviderDetails();
+		ClientRegistration.ProviderDetails updatedProviderDetails = updated.getProviderDetails();
+		assertThat(providerDetails.getAuthorizationUri())
+				.isEqualTo(updatedProviderDetails.getAuthorizationUri());
+		assertThat(providerDetails.getTokenUri())
+				.isEqualTo(updatedProviderDetails.getTokenUri());
+
+		ClientRegistration.ProviderDetails.UserInfoEndpoint userInfoEndpoint = providerDetails.getUserInfoEndpoint();
+		ClientRegistration.ProviderDetails.UserInfoEndpoint updatedUserInfoEndpoint = updatedProviderDetails.getUserInfoEndpoint();
+		assertThat(userInfoEndpoint.getUri()).isEqualTo(updatedUserInfoEndpoint.getUri());
+		assertThat(userInfoEndpoint.getAuthenticationMethod())
+				.isEqualTo(updatedUserInfoEndpoint.getAuthenticationMethod());
+		assertThat(userInfoEndpoint.getUserNameAttributeName())
+				.isEqualTo(updatedUserInfoEndpoint.getUserNameAttributeName());
+
+		assertThat(providerDetails.getJwkSetUri()).isEqualTo(updatedProviderDetails.getJwkSetUri());
+		assertThat(providerDetails.getConfigurationMetadata())
+				.isEqualTo(updatedProviderDetails.getConfigurationMetadata());
+
+		assertThat(clientRegistration.getClientName()).isEqualTo(updated.getClientName());
+	}
+
+	@Test
+	public void buildWhenClientRegistrationValuesOverriddenThenPropagated() {
+		ClientRegistration clientRegistration = clientRegistration().build();
+		ClientRegistration updated = withClientRegistration(clientRegistration)
+				.clientSecret("a-new-secret")
+				.scope("a-new-scope")
+				.providerConfigurationMetadata(Collections.singletonMap("a-new-config", "a-new-value"))
+				.build();
+
+		assertThat(clientRegistration.getClientSecret()).isNotEqualTo(updated.getClientSecret());
+		assertThat(updated.getClientSecret()).isEqualTo("a-new-secret");
+		assertThat(clientRegistration.getScopes()).doesNotContain("a-new-scope");
+		assertThat(updated.getScopes()).containsExactly("a-new-scope");
+		assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata())
+				.doesNotContainKey("a-new-config").doesNotContainValue("a-new-value");
+		assertThat(updated.getProviderDetails().getConfigurationMetadata())
+				.containsOnlyKeys("a-new-config").containsValue("a-new-value");
+	}
 }