Ver código fonte

Add update support in JdbcRegisteredClientRepository

Closes gh-356
Ovidiu Popa 4 anos atrás
pai
commit
41f8c9cd00

+ 29 - 6
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java

@@ -56,6 +56,7 @@ import org.springframework.util.StringUtils;
  *
  * @author Rafal Lewczuk
  * @author Joe Grandja
+ * @author Ovidiu Popa
  * @since 0.1.2
  * @see RegisteredClientRepository
  * @see RegisteredClient
@@ -81,6 +82,8 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 
 	private static final String TABLE_NAME = "oauth2_registered_client";
 
+	private static final String PK_FILTER = "id = ?";
+
 	private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE ";
 
 	// @formatter:off
@@ -88,6 +91,14 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 			+ "(" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
 	// @formatter:on
 
+	// @formatter:off
+	private static final String UPDATE_REGISTERED_CLIENT_SQL = "UPDATE " + TABLE_NAME
+			+ " SET client_secret = ?, client_secret_expires_at = ?,"
+			+ " client_name = ?, client_authentication_methods = ?, authorization_grant_types = ?,"
+			+ " redirect_uris = ?, scopes = ?, client_settings = ?, token_settings = ?"
+			+ " WHERE " + PK_FILTER;
+	// @formatter:on
+
 	private final JdbcOperations jdbcOperations;
 	private RowMapper<RegisteredClient> registeredClientRowMapper;
 	private Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper;
@@ -107,14 +118,26 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 	@Override
 	public void save(RegisteredClient registeredClient) {
 		Assert.notNull(registeredClient, "registeredClient cannot be null");
-		RegisteredClient existingRegisteredClient = findBy("id = ? OR client_id = ?",
-				registeredClient.getId(), registeredClient.getClientId());
+		RegisteredClient existingRegisteredClient = findBy(PK_FILTER,
+				registeredClient.getId());
 		if (existingRegisteredClient != null) {
-			Assert.isTrue(!existingRegisteredClient.getId().equals(registeredClient.getId()),
-					"Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
-			Assert.isTrue(!existingRegisteredClient.getClientId().equals(registeredClient.getClientId()),
-					"Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
+			updateRegisteredClient(registeredClient);
+		} else {
+			insertRegisteredClient(registeredClient);
 		}
+	}
+
+	private void updateRegisteredClient(RegisteredClient registeredClient) {
+		List<SqlParameterValue> parameters = new ArrayList<>(this.registeredClientParametersMapper.apply(registeredClient));
+		SqlParameterValue id = parameters.remove(0);
+		parameters.remove(0); // remove client_id
+		parameters.remove(0); // remove client_id_issued_at
+		parameters.add(id);
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+		this.jdbcOperations.update(UPDATE_REGISTERED_CLIENT_SQL, pss);
+	}
+
+	private void insertRegisteredClient(RegisteredClient registeredClient) {
 		List<SqlParameterValue> parameters = this.registeredClientParametersMapper.apply(registeredClient);
 		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
 		this.jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss);

+ 19 - 18
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@@ -18,6 +18,7 @@ package org.springframework.security.oauth2.server.authorization.client;
 import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Timestamp;
+import java.time.Instant;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -26,10 +27,10 @@ import java.util.function.Function;
 import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.Module;
 import com.fasterxml.jackson.databind.ObjectMapper;
+
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
-
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcTemplate;
@@ -60,6 +61,7 @@ import static org.mockito.Mockito.verify;
  * @author Rafal Lewczuk
  * @author Steve Riesenberg
  * @author Joe Grandja
+ * @author Ovidiu Popa
  */
 public class JdbcRegisteredClientRepositoryTests {
 	private static final String OAUTH2_REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql";
@@ -115,26 +117,25 @@ public class JdbcRegisteredClientRepositoryTests {
 	}
 
 	@Test
-	public void saveWhenExistingIdThenThrowIllegalArgumentException() {
-		RegisteredClient registeredClient = TestRegisteredClients.registeredClient().build();
-		this.registeredClientRepository.save(registeredClient);
+	public void saveWhenRegisteredClientExistsThenUpdated() {
+		RegisteredClient originalRegisteredClient = TestRegisteredClients.registeredClient().build();
+		this.registeredClientRepository.save(originalRegisteredClient);
 
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.registeredClientRepository.save(registeredClient))
-				.withMessage("Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
-	}
+		RegisteredClient registeredClient = this.registeredClientRepository.findById(
+				originalRegisteredClient.getId());
+		assertThat(registeredClient).isEqualTo(originalRegisteredClient);
 
-	@Test
-	public void saveWhenExistingClientIdThenThrowIllegalArgumentException() {
-		RegisteredClient existingRegisteredClient = TestRegisteredClients.registeredClient().build();
-		this.registeredClientRepository.save(existingRegisteredClient);
-		RegisteredClient registeredClient = RegisteredClient.from(existingRegisteredClient)
-				.id("registration-2")
-				.build();
+		RegisteredClient updatedRegisteredClient = RegisteredClient.from(originalRegisteredClient)
+				.clientId("test").clientIdIssuedAt(Instant.now()).clientName("clientName").scope("scope2").build();
 
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.registeredClientRepository.save(registeredClient))
-				.withMessage("Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
+		RegisteredClient expectedUpdatedRegisteredClient = RegisteredClient.from(originalRegisteredClient)
+				.clientName("clientName").scope("scope2").build();
+		this.registeredClientRepository.save(updatedRegisteredClient);
+
+		registeredClient = this.registeredClientRepository.findById(
+				updatedRegisteredClient.getId());
+		assertThat(registeredClient).isEqualTo(expectedUpdatedRegisteredClient);
+		assertThat(registeredClient).isNotEqualTo(originalRegisteredClient);
 	}
 
 	@Test