瀏覽代碼

Add test to override schema for JdbcRegisteredClientRepository

Steve Riesenberg 4 年之前
父節點
當前提交
623736d640

+ 2 - 1
oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql → oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql

@@ -11,4 +11,5 @@ CREATE TABLE oauth2_registered_client (
     scopes varchar(1000) NOT NULL,
     client_settings varchar(1000) DEFAULT NULL,
     token_settings varchar(1000) DEFAULT NULL,
-    PRIMARY KEY (id));
+    PRIMARY KEY (id)
+);

+ 211 - 61
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@@ -15,20 +15,37 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
-import java.io.InputStream;
-import java.nio.charset.Charset;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
 import java.time.Duration;
 import java.time.Instant;
-
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
 import org.springframework.jdbc.datasource.DriverManagerDataSource;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
-import org.springframework.util.StreamUtils;
+import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -37,51 +54,38 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
  * JDBC-backed registered client repository tests
  *
  * @author Rafal Lewczuk
+ * @author Steve Riesenberg
  * @since 0.1.2
  */
 public class JdbcRegisteredClientRepositoryTests {
 
-	private final String SCRIPT = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql";
+	private static final String REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client-schema.sql";
+	private static final String CUSTOM_REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE = "/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql";
 
 	private DriverManagerDataSource dataSource;
 
-	private JdbcRegisteredClientRepository clients;
+	private JdbcRegisteredClientRepository registeredClientRepository;
+
+	private RegisteredClient registeredClient;
 
-	private RegisteredClient registration;
+	private EmbeddedDatabase db;
 
-	private JdbcTemplate jdbc;
+	private JdbcOperations jdbcOperations;
 
 	@Before
 	public void setup() throws Exception {
-		this.dataSource = new DriverManagerDataSource();
-		this.dataSource.setDriverClassName("org.hsqldb.jdbcDriver");
-		this.dataSource.setUrl("jdbc:hsqldb:mem:oauthtest");
-		this.dataSource.setUsername("sa");
-		this.dataSource.setPassword("");
-
-		this.jdbc = new JdbcTemplate(this.dataSource);
-
-		// execute scripts
-		try (InputStream is = JdbcRegisteredClientRepositoryTests.class.getResourceAsStream(SCRIPT)) {
-			assertThat(is).isNotNull().describedAs("Cannot open resource file: " + SCRIPT);
-			String ddls = StreamUtils.copyToString(is, Charset.defaultCharset());
-			for (String ddl : ddls.split(";\n")) {
-				if (!ddl.trim().isEmpty()) {
-					this.jdbc.execute(ddl.trim());
-				}
-			}
-		}
+		this.db = createDb(REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE);
+		this.jdbcOperations = new JdbcTemplate(this.db);
 
-		this.clients = new JdbcRegisteredClientRepository(this.jdbc);
-		this.registration = TestRegisteredClients.registeredClient().build();
+		this.registeredClientRepository = new JdbcRegisteredClientRepository(this.jdbcOperations);
+		this.registeredClient = TestRegisteredClients.registeredClient().build();
 
-		this.clients.save(this.registration);
+		this.registeredClientRepository.save(this.registeredClient);
 	}
 
 	@After
 	public void destroyDatabase() {
-		this.jdbc.update("truncate table oauth2_registered_client");
-		new JdbcTemplate(this.dataSource).execute("SHUTDOWN");
+		this.db.shutdown();
 	}
 
 	@Test
@@ -97,7 +101,7 @@ public class JdbcRegisteredClientRepositoryTests {
 	public void whenSetNullRegisteredClientRowMapperThenThrow() {
 		// @formatter:off
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.setRegisteredClientRowMapper(null))
+				.isThrownBy(() -> this.registeredClientRepository.setRegisteredClientRowMapper(null))
 				.withMessage("registeredClientRowMapper cannot be null");
 		// @formatter:on
 	}
@@ -106,20 +110,20 @@ public class JdbcRegisteredClientRepositoryTests {
 	public void whenSetNullRegisteredClientParameterMapperThenThrow() {
 		// @formatter:off
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.setRegisteredClientParametersMapper(null))
+				.isThrownBy(() -> this.registeredClientRepository.setRegisteredClientParametersMapper(null))
 				.withMessage("registeredClientParameterMapper cannot be null");
 		// @formatter:on
 	}
 
 	@Test
 	public void findByIdWhenFoundThenFound() {
-		String id = this.registration.getId();
-		assertRegisteredClientIsEqualTo(this.clients.findById(id), this.registration);
+		String id = this.registeredClient.getId();
+		assertRegisteredClientIsEqualTo(this.registeredClientRepository.findById(id), this.registeredClient);
 	}
 
 	@Test
 	public void findByIdWhenNotFoundThenNull() {
-		RegisteredClient client = this.clients.findById("noooope");
+		RegisteredClient client = this.registeredClientRepository.findById("noooope");
 		assertThat(client).isNull();
 	}
 
@@ -127,20 +131,20 @@ public class JdbcRegisteredClientRepositoryTests {
 	public void findByIdWhenNullThenThrowIllegalArgumentException() {
 		// @formatter:off
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.findById(null))
+				.isThrownBy(() -> this.registeredClientRepository.findById(null))
 				.withMessage("id cannot be empty");
 		// @formatter:on
 	}
 
 	@Test
 	public void findByClientIdWhenFoundThenFound() {
-		String id = this.registration.getClientId();
-		assertRegisteredClientIsEqualTo(this.clients.findByClientId(id), this.registration);
+		String id = this.registeredClient.getClientId();
+		assertRegisteredClientIsEqualTo(this.registeredClientRepository.findByClientId(id), this.registeredClient);
 	}
 
 	@Test
 	public void findByClientIdWhenNotFoundThenNull() {
-		RegisteredClient client = this.clients.findByClientId("noooope");
+		RegisteredClient client = this.registeredClientRepository.findByClientId("noooope");
 		assertThat(client).isNull();
 	}
 
@@ -148,7 +152,7 @@ public class JdbcRegisteredClientRepositoryTests {
 	public void findByClientIdWhenNullThenThrowIllegalArgumentException() {
 		// @formatter:off
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.findByClientId(null))
+				.isThrownBy(() -> this.registeredClientRepository.findByClientId(null))
 				.withMessage("clientId cannot be empty");
 		// @formatter:on
 	}
@@ -156,58 +160,58 @@ public class JdbcRegisteredClientRepositoryTests {
 	@Test
 	public void saveWhenNullThenThrowIllegalArgumentException() {
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.save(null))
+				.isThrownBy(() -> this.registeredClientRepository.save(null))
 				.withMessageContaining("registeredClient cannot be null");
 	}
 
 	@Test
 	public void saveWhenExistingIdThenThrowIllegalArgumentException() {
 		RegisteredClient registeredClient = createRegisteredClient(
-				this.registration.getId(), "client-id-2", "client-secret-2");
+				this.registeredClient.getId(), "client-id-2", "client-secret-2");
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.save(registeredClient))
+				.isThrownBy(() -> this.registeredClientRepository.save(registeredClient))
 				.withMessage("Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
 	}
 
 	@Test
 	public void saveWhenExistingClientIdThenThrowIllegalArgumentException() {
 		RegisteredClient registeredClient = createRegisteredClient(
-				"client-2", this.registration.getClientId(), "client-secret-2");
+				"client-2", this.registeredClient.getClientId(), "client-secret-2");
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.save(registeredClient))
+				.isThrownBy(() -> this.registeredClientRepository.save(registeredClient))
 				.withMessage("Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
 	}
 
 	@Test
 	public void saveWhenExistingClientSecretThenSuccess() {
 		RegisteredClient registeredClient = createRegisteredClient(
-				"client-2", "client-id-2", this.registration.getClientSecret());
-		this.clients.save(registeredClient);
-		RegisteredClient savedClient = this.clients.findById(registeredClient.getId());
+				"client-2", "client-id-2", this.registeredClient.getClientSecret());
+		this.registeredClientRepository.save(registeredClient);
+		RegisteredClient savedClient = this.registeredClientRepository.findById(registeredClient.getId());
 		assertRegisteredClientIsEqualTo(savedClient, registeredClient);
 	}
 
 	@Test
 	public void saveWhenSavedAndFindByIdThenFound() {
 		RegisteredClient registeredClient = createRegisteredClient();
-		this.clients.save(registeredClient);
-		RegisteredClient savedClient = this.clients.findById(registeredClient.getId());
+		this.registeredClientRepository.save(registeredClient);
+		RegisteredClient savedClient = this.registeredClientRepository.findById(registeredClient.getId());
 		assertRegisteredClientIsEqualTo(savedClient, registeredClient);
 	}
 
 	@Test
 	public void saveWhenSavedAndFindByClientIdThenFound() {
 		RegisteredClient registeredClient = createRegisteredClient();
-		this.clients.save(registeredClient);
-		RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId());
+		this.registeredClientRepository.save(registeredClient);
+		RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient.getClientId());
 		assertRegisteredClientIsEqualTo(savedClient, registeredClient);
 	}
 
 	@Test
 	public void saveWhenPublicClientSavedAndFindByClientIdThenFound() {
 		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
-		this.clients.save(registeredClient);
-		RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId());
+		this.registeredClientRepository.save(registeredClient);
+		RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient.getClientId());
 		assertRegisteredClientIsEqualTo(savedClient, registeredClient);
 	}
 
@@ -217,9 +221,9 @@ public class JdbcRegisteredClientRepositoryTests {
 				.id("1").clientId("a").build();
 		RegisteredClient registeredClient2 = TestRegisteredClients.registeredPublicClient()
 				.id("2").clientId("b").build();
-		this.clients.save(registeredClient1);
-		this.clients.save(registeredClient2);
-		RegisteredClient savedClient = this.clients.findByClientId(registeredClient2.getClientId());
+		this.registeredClientRepository.save(registeredClient1);
+		this.registeredClientRepository.save(registeredClient2);
+		RegisteredClient savedClient = this.registeredClientRepository.findByClientId(registeredClient2.getClientId());
 		assertRegisteredClientIsEqualTo(savedClient, registeredClient2);
 	}
 
@@ -243,13 +247,28 @@ public class JdbcRegisteredClientRepositoryTests {
 				})
 				.build();
 
-		this.clients.save(client);
+		this.registeredClientRepository.save(client);
 
-		RegisteredClient retrievedClient = this.clients.findById(client.getId());
+		RegisteredClient retrievedClient = this.registeredClientRepository.findById(client.getId());
 
 		assertRegisteredClientIsEqualTo(retrievedClient, client);
 	}
 
+	@Test
+	public void tableDefinitionWhenCustomThenAbleToOverride() {
+		EmbeddedDatabase db = createDb(CUSTOM_REGISTERED_CLIENT_SCHEMA_SQL_RESOURCE);
+		CustomJdbcRegisteredClientRepository registeredClientRepository =
+				new CustomJdbcRegisteredClientRepository(new JdbcTemplate(db));
+		registeredClientRepository.save(this.registeredClient);
+		RegisteredClient foundClient1 = registeredClientRepository.findById(this.registeredClient.getId());
+		assertThat(foundClient1).isNotNull();
+		assertRegisteredClientIsEqualTo(foundClient1, this.registeredClient);
+		RegisteredClient foundClient2 = registeredClientRepository.findByClientId(this.registeredClient.getClientId());
+		assertThat(foundClient2).isNotNull();
+		assertRegisteredClientIsEqualTo(foundClient2, this.registeredClient);
+		db.shutdown();
+	}
+
 	private void assertRegisteredClientIsEqualTo(RegisteredClient rc, RegisteredClient ref) {
 		assertThat(rc).isNotNull();
 		assertThat(rc.getId()).isEqualTo(ref.getId());
@@ -282,11 +301,21 @@ public class JdbcRegisteredClientRepositoryTests {
 		assertThat(rc.getTokenSettings().refreshTokenTimeToLive()).isEqualTo(ref.getTokenSettings().refreshTokenTimeToLive());
 	}
 
+	private static EmbeddedDatabase createDb(String schema) {
+		// @formatter:off
+		return new EmbeddedDatabaseBuilder()
+				.generateUniqueName(true)
+				.setType(EmbeddedDatabaseType.HSQL)
+				.setScriptEncoding("UTF-8")
+				.addScript(schema)
+				.build();
+		// @formatter:on
+	}
+
 	private static RegisteredClient createRegisteredClient() {
 		return createRegisteredClient("client-2", "client-id-2", "client-secret-2");
 	}
 
-
 	private static RegisteredClient createRegisteredClient(String id, String clientId, String clientSecret) {
 		// @formatter:off
 		return RegisteredClient.withId(id)
@@ -300,4 +329,125 @@ public class JdbcRegisteredClientRepositoryTests {
 		// @formatter:on
 	}
 
+	private static final class CustomJdbcRegisteredClientRepository extends JdbcRegisteredClientRepository {
+
+		private static final String COLUMN_NAMES = "id, "
+				+ "clientId, "
+				+ "clientIdIssuedAt, "
+				+ "clientSecret, "
+				+ "clientSecretExpiresAt, "
+				+ "clientName, "
+				+ "clientAuthenticationMethods, "
+				+ "authorizationGrantTypes, "
+				+ "redirectUris, "
+				+ "scopes, "
+				+ "clientSettings,"
+				+ "tokenSettings";
+
+		private static final String TABLE_NAME = "oauth2RegisteredClient";
+
+		private static final String LOAD_REGISTERED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME + " WHERE ";
+
+		private static final String INSERT_REGISTERED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME
+				+ " (" + COLUMN_NAMES + ") values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
+
+		public CustomJdbcRegisteredClientRepository(JdbcOperations jdbcOperations) {
+			super(jdbcOperations);
+			setRegisteredClientRowMapper(new CustomRegisteredClientRowMapper());
+		}
+
+		@Override
+		public void save(RegisteredClient registeredClient) {
+			List<SqlParameterValue> parameters = getRegisteredClientParametersMapper().apply(registeredClient);
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+			getJdbcOperations().update(INSERT_REGISTERED_CLIENT_SQL, pss);
+		}
+
+		@Override
+		public RegisteredClient findById(String id) {
+			return findBy("id = ?", id);
+		}
+
+		@Override
+		public RegisteredClient findByClientId(String clientId) {
+			return findBy("clientId = ?", clientId);
+		}
+
+		private RegisteredClient findBy(String filter, Object... args) {
+			List<RegisteredClient> result = getJdbcOperations()
+					.query(LOAD_REGISTERED_CLIENT_SQL + filter, getRegisteredClientRowMapper(), args);
+			return !result.isEmpty() ? result.get(0) : null;
+		}
+
+		private static final class CustomRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
+
+			private static final Map<String, AuthorizationGrantType> AUTHORIZATION_GRANT_TYPE_MAP;
+			private static final Map<String, ClientAuthenticationMethod> CLIENT_AUTHENTICATION_METHOD_MAP;
+
+			private final ObjectMapper objectMapper = new ObjectMapper();
+
+			@Override
+			public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException {
+				Set<String> clientScopes = StringUtils.commaDelimitedListToSet(rs.getString("scopes"));
+				Set<String> authGrantTypes = StringUtils.commaDelimitedListToSet(rs.getString("authorizationGrantTypes"));
+				Set<String> clientAuthMethods = StringUtils.commaDelimitedListToSet(rs.getString("clientAuthenticationMethods"));
+				Set<String> redirectUris = StringUtils.commaDelimitedListToSet(rs.getString("redirectUris"));
+				Timestamp clientIssuedAt = rs.getTimestamp("clientIdIssuedAt");
+				Timestamp clientSecretExpiresAt = rs.getTimestamp("clientSecretExpiresAt");
+				String clientSecret = rs.getString("clientSecret");
+				RegisteredClient.Builder builder = RegisteredClient
+						.withId(rs.getString("id"))
+						.clientId(rs.getString("clientId"))
+						.clientIdIssuedAt(clientIssuedAt != null ? clientIssuedAt.toInstant() : null)
+						.clientSecret(clientSecret)
+						.clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null)
+						.clientName(rs.getString("clientName"))
+						.authorizationGrantTypes((grantTypes) -> authGrantTypes.forEach(authGrantType ->
+								grantTypes.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType))))
+						.clientAuthenticationMethods((authenticationMethods) -> clientAuthMethods.forEach(clientAuthMethod ->
+								authenticationMethods.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod))))
+						.redirectUris((uris) -> uris.addAll(redirectUris))
+						.scopes((scopes) -> scopes.addAll(clientScopes));
+
+				RegisteredClient registeredClient = builder.build();
+				registeredClient.getClientSettings().settings().putAll(parseMap(rs.getString("clientSettings")));
+				registeredClient.getTokenSettings().settings().putAll(parseMap(rs.getString("tokenSettings")));
+
+				return registeredClient;
+			}
+
+			private Map<String, Object> parseMap(String data) {
+				try {
+					return this.objectMapper.readValue(data, new TypeReference<Map<String, Object>>() {});
+				} catch (Exception ex) {
+					throw new IllegalArgumentException(ex.getMessage(), ex);
+				}
+			}
+
+			static {
+				Map<String, AuthorizationGrantType> am = new HashMap<>();
+				for (AuthorizationGrantType a : Arrays.asList(
+						AuthorizationGrantType.AUTHORIZATION_CODE,
+						AuthorizationGrantType.REFRESH_TOKEN,
+						AuthorizationGrantType.CLIENT_CREDENTIALS,
+						AuthorizationGrantType.PASSWORD,
+						AuthorizationGrantType.IMPLICIT)) {
+					am.put(a.getValue(), a);
+				}
+				AUTHORIZATION_GRANT_TYPE_MAP = Collections.unmodifiableMap(am);
+
+				Map<String, ClientAuthenticationMethod> cm = new HashMap<>();
+				for (ClientAuthenticationMethod c : Arrays.asList(
+						ClientAuthenticationMethod.NONE,
+						ClientAuthenticationMethod.BASIC,
+						ClientAuthenticationMethod.POST)) {
+					cm.put(c.getValue(), c);
+				}
+				CLIENT_AUTHENTICATION_METHOD_MAP = Collections.unmodifiableMap(cm);
+			}
+
+		}
+
+	}
+
 }

+ 15 - 0
oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/client/custom-oauth2-registered-client-schema.sql

@@ -0,0 +1,15 @@
+CREATE TABLE oauth2RegisteredClient (
+    id varchar(100) NOT NULL,
+    clientId varchar(100) NOT NULL,
+    clientIdIssuedAt timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
+    clientSecret varchar(200) DEFAULT NULL,
+    clientSecretExpiresAt timestamp DEFAULT NULL,
+    clientName varchar(200),
+    clientAuthenticationMethods varchar(1000) NOT NULL,
+    authorizationGrantTypes varchar(1000) NOT NULL,
+    redirectUris varchar(1000) NOT NULL,
+    scopes varchar(1000) NOT NULL,
+    clientSettings varchar(1000) DEFAULT NULL,
+    tokenSettings varchar(1000) DEFAULT NULL,
+    PRIMARY KEY (id)
+);