浏览代码

Add test to override schema for JdbcOAuth2AuthorizationConsentService

Steve Riesenberg 4 年之前
父节点
当前提交
a949998664

+ 122 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationConsentServiceTests.java

@@ -15,11 +15,21 @@
  */
  */
 package org.springframework.security.oauth2.server.authorization;
 package org.springframework.security.oauth2.server.authorization;
 
 
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Types;
+import java.util.List;
+
 import org.junit.After;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
+
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.SqlParameterValue;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
@@ -27,6 +37,7 @@ import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.util.StringUtils;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -47,6 +58,7 @@ import static org.mockito.Mockito.when;
 public class JdbcOAuth2AuthorizationConsentServiceTests {
 public class JdbcOAuth2AuthorizationConsentServiceTests {
 
 
 	private static final String OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql";
 	private static final String OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-consent-schema.sql";
+	private static final String CUSTOM_OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-consent-schema.sql";
 	private static final String PRINCIPAL_NAME = "principal-name";
 	private static final String PRINCIPAL_NAME = "principal-name";
 	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
 	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
 
 
@@ -200,6 +212,23 @@ public class JdbcOAuth2AuthorizationConsentServiceTests {
 		assertThat(this.authorizationConsentService.findById(REGISTERED_CLIENT.getId(), "unknown-user")).isNull();
 		assertThat(this.authorizationConsentService.findById(REGISTERED_CLIENT.getId(), "unknown-user")).isNull();
 	}
 	}
 
 
+	@Test
+	public void tableDefinitionWhenCustomThenAbleToOverride() {
+		when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+
+		EmbeddedDatabase db = createDb(CUSTOM_OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE);
+		OAuth2AuthorizationConsentService authorizationConsentService =
+				new CustomJdbcOAuth2AuthorizationConsentService(new JdbcTemplate(db), this.registeredClientRepository);
+		authorizationConsentService.save(AUTHORIZATION_CONSENT);
+		OAuth2AuthorizationConsent foundAuthorizationConsent1 = authorizationConsentService.findById(AUTHORIZATION_CONSENT.getRegisteredClientId(), AUTHORIZATION_CONSENT.getPrincipalName());
+		assertThat(foundAuthorizationConsent1).isEqualTo(AUTHORIZATION_CONSENT);
+		authorizationConsentService.remove(AUTHORIZATION_CONSENT);
+		OAuth2AuthorizationConsent foundAuthorizationConsent2 = authorizationConsentService.findById(REGISTERED_CLIENT.getClientId(), AUTHORIZATION_CONSENT.getPrincipalName());
+		assertThat(foundAuthorizationConsent2).isNull();
+		db.shutdown();
+	}
+
 	@Before
 	@Before
 	public void setUp() {
 	public void setUp() {
 		this.db = createDb();
 		this.db = createDb();
@@ -216,6 +245,7 @@ public class JdbcOAuth2AuthorizationConsentServiceTests {
 	private static EmbeddedDatabase createDb() {
 	private static EmbeddedDatabase createDb() {
 		return createDb(OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE);
 		return createDb(OAUTH2_AUTHORIZATION_CONSENT_SCHEMA_SQL_RESOURCE);
 	}
 	}
+
 	private static EmbeddedDatabase createDb(String schema) {
 	private static EmbeddedDatabase createDb(String schema) {
 		// @formatter:off
 		// @formatter:off
 		return new EmbeddedDatabaseBuilder()
 		return new EmbeddedDatabaseBuilder()
@@ -226,4 +256,96 @@ public class JdbcOAuth2AuthorizationConsentServiceTests {
 				.build();
 				.build();
 		// @formatter:on
 		// @formatter:on
 	}
 	}
+
+	private static final class CustomJdbcOAuth2AuthorizationConsentService extends JdbcOAuth2AuthorizationConsentService {
+
+		// @formatter:off
+		private static final String COLUMN_NAMES = "registeredClientId, "
+				+ "principalName, "
+				+ "authorities";
+		// @formatter:on
+
+		private static final String TABLE_NAME = "oauth2AuthorizationConsent";
+
+		private static final String PK_FILTER = "registeredClientId = ? AND principalName = ?";
+
+		// @formatter:off
+		private static final String LOAD_AUTHORIZATION_CONSENT_SQL = "SELECT " + COLUMN_NAMES
+				+ " FROM " + TABLE_NAME
+				+ " WHERE " + PK_FILTER;
+		// @formatter:on
+
+		// @formatter:off
+		private static final String SAVE_AUTHORIZATION_CONSENT_SQL = "INSERT INTO " + TABLE_NAME
+				+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?)";
+		// @formatter:on
+
+		private static final String REMOVE_AUTHORIZATION_CONSENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+
+		CustomJdbcOAuth2AuthorizationConsentService(JdbcOperations jdbcOperations, RegisteredClientRepository registeredClientRepository) {
+			super(jdbcOperations, registeredClientRepository);
+			setAuthorizationConsentRowMapper(new CustomOAuth2AuthorizationConsentRowMapper(registeredClientRepository));
+		}
+
+		@Override
+		public void save(OAuth2AuthorizationConsent authorizationConsent) {
+			List<SqlParameterValue> parameters = getAuthorizationConsentParametersMapper().apply(authorizationConsent);
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+			getJdbcOperations().update(SAVE_AUTHORIZATION_CONSENT_SQL, pss);
+		}
+
+		@Override
+		public void remove(OAuth2AuthorizationConsent authorizationConsent) {
+			SqlParameterValue[] parameters = new SqlParameterValue[] {
+					new SqlParameterValue(Types.VARCHAR, authorizationConsent.getRegisteredClientId()),
+					new SqlParameterValue(Types.VARCHAR, authorizationConsent.getPrincipalName())
+			};
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+			getJdbcOperations().update(REMOVE_AUTHORIZATION_CONSENT_SQL, pss);
+		}
+
+		@Override
+		public OAuth2AuthorizationConsent findById(String registeredClientId, String principalName) {
+			SqlParameterValue[] parameters = new SqlParameterValue[] {
+					new SqlParameterValue(Types.VARCHAR, registeredClientId),
+					new SqlParameterValue(Types.VARCHAR, principalName)};
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+			List<OAuth2AuthorizationConsent> result = getJdbcOperations().query(LOAD_AUTHORIZATION_CONSENT_SQL, pss,
+					getAuthorizationConsentRowMapper());
+			return !result.isEmpty() ? result.get(0) : null;
+		}
+
+		private static final class CustomOAuth2AuthorizationConsentRowMapper extends JdbcOAuth2AuthorizationConsentService.OAuth2AuthorizationConsentRowMapper {
+
+			CustomOAuth2AuthorizationConsentRowMapper(RegisteredClientRepository registeredClientRepository) {
+				super(registeredClientRepository);
+			}
+
+			@Override
+			public OAuth2AuthorizationConsent mapRow(ResultSet rs, int rowNum) throws SQLException {
+				String registeredClientId = rs.getString("registeredClientId");
+
+				RegisteredClient registeredClient = getRegisteredClientRepository()
+						.findById(registeredClientId);
+				if (registeredClient == null) {
+					throw new DataRetrievalFailureException(
+							"The RegisteredClient with id '" + registeredClientId + "' was not found in the RegisteredClientRepository.");
+				}
+
+				String principalName = rs.getString("principalName");
+
+				OAuth2AuthorizationConsent.Builder builder = OAuth2AuthorizationConsent.withId(registeredClientId, principalName);
+				String authorizationConsentAuthorities = rs.getString("authorities");
+				if (authorizationConsentAuthorities != null) {
+					for (String authority : StringUtils.commaDelimitedListToSet(authorizationConsentAuthorities)) {
+						builder.authority(new SimpleGrantedAuthority(authority));
+					}
+				}
+				return builder.build();
+			}
+
+		}
+
+	}
+
 }
 }

+ 6 - 0
oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-consent-schema.sql

@@ -0,0 +1,6 @@
+CREATE TABLE oauth2AuthorizationConsent (
+    registeredClientId varchar(100) NOT NULL,
+    principalName varchar(200) NOT NULL,
+    authorities varchar(1000) NOT NULL,
+    PRIMARY KEY (registeredClientId, principalName)
+);