ソースを参照

Add test to override schema for JdbcOAuth2AuthorizationService

Steve Riesenberg 4 年 前
コミット
99fb4c8a5f

+ 319 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java

@@ -15,32 +15,48 @@
  */
 package org.springframework.security.oauth2.server.authorization;
 
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
 import java.time.Instant;
 import java.time.temporal.ChronoUnit;
+import java.util.ArrayList;
+import java.util.Collections;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
 import java.util.function.Function;
 
+import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.core.PreparedStatementSetter;
 import org.springframework.jdbc.core.RowMapper;
 import org.springframework.jdbc.core.SqlParameterValue;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken;
 import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
 import org.springframework.security.oauth2.core.OAuth2TokenType;
 import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
 import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
 import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatThrownBy;
@@ -59,6 +75,7 @@ import static org.mockito.Mockito.when;
  */
 public class JdbcOAuth2AuthorizationServiceTests {
 	private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql";
+	private static final String CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema.sql";
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
 	private static final String ID = "id";
@@ -374,6 +391,30 @@ public class JdbcOAuth2AuthorizationServiceTests {
 		assertThat(result).isNull();
 	}
 
+	@Test
+	public void tableDefinitionWhenCustomThenAbleToOverride() {
+		when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+
+		EmbeddedDatabase db = createDb(CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
+		OAuth2AuthorizationService authorizationService =
+				new CustomJdbcOAuth2AuthorizationService(new JdbcTemplate(db), this.registeredClientRepository);
+		String state = "state";
+		OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.attribute(OAuth2ParameterNames.STATE, state)
+				.token(AUTHORIZATION_CODE)
+				.build();
+		authorizationService.save(originalAuthorization);
+		OAuth2Authorization foundAuthorization1 = authorizationService.findById(originalAuthorization.getId());
+		assertThat(foundAuthorization1).isEqualTo(originalAuthorization);
+		OAuth2Authorization foundAuthorization2 = authorizationService.findByToken(state, STATE_TOKEN_TYPE);
+		assertThat(foundAuthorization2).isEqualTo(originalAuthorization);
+		db.shutdown();
+	}
+
 	private static EmbeddedDatabase createDb() {
 		return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
 	}
@@ -388,4 +429,282 @@ public class JdbcOAuth2AuthorizationServiceTests {
 				.build();
 		// @formatter:on
 	}
+
+	private static final class CustomJdbcOAuth2AuthorizationService extends JdbcOAuth2AuthorizationService {
+
+		// @formatter:off
+		private static final String COLUMN_NAMES = "id, "
+				+ "registeredClientId, "
+				+ "principalName, "
+				+ "authorizationGrantType, "
+				+ "attributes, "
+				+ "state, "
+				+ "authorizationCodeValue, "
+				+ "authorizationCodeIssuedAt, "
+				+ "authorizationCodeExpiresAt,"
+				+ "authorizationCodeMetadata,"
+				+ "accessTokenValue,"
+				+ "accessTokenIssuedAt,"
+				+ "accessTokenExpiresAt,"
+				+ "accessTokenMetadata,"
+				+ "accessTokenType,"
+				+ "accessTokenScopes,"
+				+ "oidcIdTokenValue,"
+				+ "oidcIdTokenIssuedAt,"
+				+ "oidcIdTokenExpiresAt,"
+				+ "oidcIdTokenMetadata,"
+				+ "refreshTokenValue,"
+				+ "refreshTokenIssuedAt,"
+				+ "refreshTokenExpiresAt,"
+				+ "refreshTokenMetadata";
+		// @formatter:on
+
+		private static final String TABLE_NAME = "oauth2Authorization";
+
+		private static final String PK_FILTER = "id = ?";
+		private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorizationCodeValue = ? OR " +
+				"accessTokenValue = ? OR " +
+				"refreshTokenValue = ?";
+
+		// @formatter:off
+		private static final String LOAD_AUTHORIZATION_SQL = "SELECT " + COLUMN_NAMES
+				+ " FROM " + TABLE_NAME
+				+ " WHERE ";
+		// @formatter:on
+
+		// @formatter:off
+		private static final String SAVE_AUTHORIZATION_SQL = "INSERT INTO " + TABLE_NAME
+				+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?)";
+		// @formatter:on
+
+		private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+
+		CustomJdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
+				RegisteredClientRepository registeredClientRepository) {
+			super(jdbcOperations, registeredClientRepository);
+			setAuthorizationRowMapper(new CustomOAuth2AuthorizationRowMapper(registeredClientRepository));
+			setAuthorizationParametersMapper(new CustomOAuth2AuthorizationParametersMapper());
+		}
+
+		@Override
+		public void save(OAuth2Authorization authorization) {
+			List<SqlParameterValue> parameters = getAuthorizationParametersMapper().apply(authorization);
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+			getJdbcOperations().update(SAVE_AUTHORIZATION_SQL, pss);
+		}
+
+		@Override
+		public void remove(OAuth2Authorization authorization) {
+			SqlParameterValue[] parameters = new SqlParameterValue[] {
+					new SqlParameterValue(Types.VARCHAR, authorization.getId())
+			};
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+			getJdbcOperations().update(REMOVE_AUTHORIZATION_SQL, pss);
+		}
+
+		@Override
+		public OAuth2Authorization findById(String id) {
+			return findBy(PK_FILTER, id);
+		}
+
+		@Override
+		public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
+			return findBy(UNKNOWN_TOKEN_TYPE_FILTER, token, token, token, token);
+		}
+
+		private OAuth2Authorization findBy(String filter, Object... args) {
+			List<OAuth2Authorization> result = getJdbcOperations()
+					.query(LOAD_AUTHORIZATION_SQL + filter, getAuthorizationRowMapper(), args);
+			return !result.isEmpty() ? result.get(0) : null;
+		}
+
+		private static final class CustomOAuth2AuthorizationRowMapper extends JdbcOAuth2AuthorizationService.OAuth2AuthorizationRowMapper {
+
+			CustomOAuth2AuthorizationRowMapper(RegisteredClientRepository registeredClientRepository) {
+				super(registeredClientRepository);
+			}
+
+			@Override
+			@SuppressWarnings("unchecked")
+			public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException {
+				String registeredClientId = rs.getString("registeredClientId");
+				RegisteredClient registeredClient = getRegisteredClientRepository().findById(registeredClientId);
+				if (registeredClient == null) {
+					throw new DataRetrievalFailureException(
+							"The RegisteredClient with id '" + registeredClientId + "' was not found in the RegisteredClientRepository.");
+				}
+
+				OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient);
+				String id = rs.getString("id");
+				String principalName = rs.getString("principalName");
+				String authorizationGrantType = rs.getString("authorizationGrantType");
+				Map<String, Object> attributes = parseMap(rs.getString("attributes"));
+
+				builder.id(id)
+						.principalName(principalName)
+						.authorizationGrantType(new AuthorizationGrantType(authorizationGrantType))
+						.attributes((attrs) -> attrs.putAll(attributes));
+
+				String state = rs.getString("state");
+				if (StringUtils.hasText(state)) {
+					builder.attribute(OAuth2ParameterNames.STATE, state);
+				}
+
+				String tokenValue = rs.getString("authorizationCodeValue");
+				Instant tokenIssuedAt;
+				Instant tokenExpiresAt;
+				if (tokenValue != null) {
+					tokenIssuedAt = rs.getTimestamp("authorizationCodeIssuedAt").toInstant();
+					tokenExpiresAt = rs.getTimestamp("authorizationCodeExpiresAt").toInstant();
+					Map<String, Object> authorizationCodeMetadata = parseMap(rs.getString("authorizationCodeMetadata"));
+
+					OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+							tokenValue, tokenIssuedAt, tokenExpiresAt);
+					builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
+				}
+
+				tokenValue = rs.getString("accessTokenValue");
+				if (tokenValue != null) {
+					tokenIssuedAt = rs.getTimestamp("accessTokenIssuedAt").toInstant();
+					tokenExpiresAt = rs.getTimestamp("accessTokenExpiresAt").toInstant();
+					Map<String, Object> accessTokenMetadata = parseMap(rs.getString("accessTokenMetadata"));
+					OAuth2AccessToken.TokenType tokenType = null;
+					if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("accessTokenType"))) {
+						tokenType = OAuth2AccessToken.TokenType.BEARER;
+					}
+
+					Set<String> scopes = Collections.emptySet();
+					String accessTokenScopes = rs.getString("accessTokenScopes");
+					if (accessTokenScopes != null) {
+						scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
+					}
+					OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
+					builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
+				}
+
+				tokenValue = rs.getString("oidcIdTokenValue");
+				if (tokenValue != null) {
+					tokenIssuedAt = rs.getTimestamp("oidcIdTokenIssuedAt").toInstant();
+					tokenExpiresAt = rs.getTimestamp("oidcIdTokenExpiresAt").toInstant();
+					Map<String, Object> oidcTokenMetadata = parseMap(rs.getString("oidcIdTokenMetadata"));
+
+					OidcIdToken oidcToken = new OidcIdToken(
+							tokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
+					builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
+				}
+
+				tokenValue = rs.getString("refreshTokenValue");
+				if (tokenValue != null) {
+					tokenIssuedAt = rs.getTimestamp("refreshTokenIssuedAt").toInstant();
+					tokenExpiresAt = null;
+					Timestamp refreshTokenExpiresAt = rs.getTimestamp("refreshTokenExpiresAt");
+					if (refreshTokenExpiresAt != null) {
+						tokenExpiresAt = refreshTokenExpiresAt.toInstant();
+					}
+					Map<String, Object> refreshTokenMetadata = parseMap(rs.getString("refreshTokenMetadata"));
+
+					OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2(
+							tokenValue, tokenIssuedAt, tokenExpiresAt);
+					builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
+				}
+
+				return builder.build();
+			}
+
+			private Map<String, Object> parseMap(String data) {
+				try {
+					return getObjectMapper().readValue(data, new TypeReference<Map<String, Object>>() {});
+				} catch (Exception ex) {
+					throw new IllegalArgumentException(ex.getMessage(), ex);
+				}
+			}
+
+		}
+
+		private static final class CustomOAuth2AuthorizationParametersMapper extends JdbcOAuth2AuthorizationService.OAuth2AuthorizationParametersMapper {
+
+			@Override
+			public List<SqlParameterValue> apply(OAuth2Authorization authorization) {
+				List<SqlParameterValue> parameters = new ArrayList<>();
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getId()));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getRegisteredClientId()));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getPrincipalName()));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getAuthorizationGrantType().getValue()));
+
+				String attributes = writeMap(authorization.getAttributes());
+				parameters.add(new SqlParameterValue(Types.VARCHAR, attributes));
+
+				String state = null;
+				String authorizationState = authorization.getAttribute(OAuth2ParameterNames.STATE);
+				if (StringUtils.hasText(authorizationState)) {
+					state = authorizationState;
+				}
+				parameters.add(new SqlParameterValue(Types.VARCHAR, state));
+
+				OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
+						authorization.getToken(OAuth2AuthorizationCode.class);
+				List<SqlParameterValue> authorizationCodeSqlParameters = toSqlParameterList(authorizationCode);
+				parameters.addAll(authorizationCodeSqlParameters);
+
+				OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
+						authorization.getToken(OAuth2AccessToken.class);
+				List<SqlParameterValue> accessTokenSqlParameters = toSqlParameterList(accessToken);
+				parameters.addAll(accessTokenSqlParameters);
+				String accessTokenType = null;
+				String accessTokenScopes = null;
+				if (accessToken != null) {
+					accessTokenType = accessToken.getToken().getTokenType().getValue();
+					if (!CollectionUtils.isEmpty(accessToken.getToken().getScopes())) {
+						accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ",");
+					}
+				}
+				parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenType));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes));
+
+				OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
+				List<SqlParameterValue> oidcIdTokenSqlParameters = toSqlParameterList(oidcIdToken);
+				parameters.addAll(oidcIdTokenSqlParameters);
+
+				OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
+				List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList(refreshToken);
+				parameters.addAll(refreshTokenSqlParameters);
+				return parameters;
+			}
+
+			private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(OAuth2Authorization.Token<T> token) {
+				List<SqlParameterValue> parameters = new ArrayList<>();
+				String tokenValue = null;
+				Timestamp tokenIssuedAt = null;
+				Timestamp tokenExpiresAt = null;
+				String metadata = null;
+				if (token != null) {
+					tokenValue = token.getToken().getTokenValue();
+					if (token.getToken().getIssuedAt() != null) {
+						tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt());
+					}
+
+					if (token.getToken().getExpiresAt() != null) {
+						tokenExpiresAt = Timestamp.from(token.getToken().getExpiresAt());
+					}
+					metadata = writeMap(token.getMetadata());
+				}
+				parameters.add(new SqlParameterValue(Types.VARCHAR, tokenValue));
+				parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt));
+				parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, metadata));
+				return parameters;
+			}
+
+			private String writeMap(Map<String, Object> data) {
+				try {
+					return getObjectMapper().writeValueAsString(data);
+				} catch (Exception ex) {
+					throw new IllegalArgumentException(ex.getMessage(), ex);
+				}
+			}
+
+		}
+
+	}
+
 }

+ 6 - 0
oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-consent-schema.sql

@@ -0,0 +1,6 @@
+CREATE TABLE oauth2_authorization_consent (
+    registered_client_id varchar(100) NOT NULL,
+    principal_name varchar(200) NOT NULL,
+    authorities varchar(1000) NOT NULL,
+    PRIMARY KEY (registered_client_id, principal_name)
+);