Эх сурвалжийг харах

Support clob and text datatype for token columns

Closes gh-480
Ovidiu Popa 3 жил өмнө
parent
commit
66bc5a0e65

+ 84 - 27
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.server.authorization;
 
 import java.nio.charset.StandardCharsets;
+import java.sql.DatabaseMetaData;
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
@@ -35,6 +36,7 @@ import com.fasterxml.jackson.databind.ObjectMapper;
 
 import org.springframework.dao.DataRetrievalFailureException;
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.ConnectionCallback;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.PreparedStatementSetter;
 import org.springframework.jdbc.core.RowMapper;
@@ -141,6 +143,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 	private final JdbcOperations jdbcOperations;
 	private final LobHandler lobHandler;
+	private static int tokenColumnType;
 	private RowMapper<OAuth2Authorization> authorizationRowMapper;
 	private Function<OAuth2Authorization, List<SqlParameterValue>> authorizationParametersMapper;
 
@@ -169,12 +172,15 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 		Assert.notNull(lobHandler, "lobHandler cannot be null");
 		this.jdbcOperations = jdbcOperations;
 		this.lobHandler = lobHandler;
+		tokenColumnType = getColumnDataType(jdbcOperations, "access_token_value");
 		OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper(registeredClientRepository);
 		authorizationRowMapper.setLobHandler(lobHandler);
 		this.authorizationRowMapper = authorizationRowMapper;
-		this.authorizationParametersMapper = new OAuth2AuthorizationParametersMapper();
+		OAuth2AuthorizationParametersMapper authorizationParametersMapper = new OAuth2AuthorizationParametersMapper();
+		this.authorizationParametersMapper = authorizationParametersMapper;
 	}
 
+
 	@Override
 	public void save(OAuth2Authorization authorization) {
 		Assert.notNull(authorization, "authorization cannot be null");
@@ -232,26 +238,33 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 		List<SqlParameterValue> parameters = new ArrayList<>();
 		if (tokenType == null) {
 			parameters.add(new SqlParameterValue(Types.VARCHAR, token));
-			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
-			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
-			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			parameters.add(mapTokenToSqlParameter(token));
+			parameters.add(mapTokenToSqlParameter(token));
+			parameters.add(mapTokenToSqlParameter(token));
 			return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters);
 		} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
 			parameters.add(new SqlParameterValue(Types.VARCHAR, token));
 			return findBy(STATE_FILTER, parameters);
 		} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
-			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			parameters.add(mapTokenToSqlParameter(token));
 			return findBy(AUTHORIZATION_CODE_FILTER, parameters);
 		} else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
-			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			parameters.add(mapTokenToSqlParameter(token));
 			return findBy(ACCESS_TOKEN_FILTER, parameters);
 		} else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
-			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			parameters.add(mapTokenToSqlParameter(token));
 			return findBy(REFRESH_TOKEN_FILTER, parameters);
 		}
 		return null;
 	}
 
+	private SqlParameterValue mapTokenToSqlParameter(String token) {
+		if (Types.BLOB == tokenColumnType) {
+			return new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8));
+		}
+		return new SqlParameterValue(tokenColumnType, token);
+	}
+
 	private OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
 		try (LobCreator lobCreator = getLobHandler().getLobCreator()) {
 			PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
@@ -349,25 +362,22 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 				builder.attribute(OAuth2ParameterNames.STATE, state);
 			}
 
-			String tokenValue;
 			Instant tokenIssuedAt;
 			Instant tokenExpiresAt;
-			byte[] authorizationCodeValue = this.lobHandler.getBlobAsBytes(rs, "authorization_code_value");
+			String authorizationCodeValue = getTokenValue(rs, "authorization_code_value");
 
-			if (authorizationCodeValue != null) {
-				tokenValue = new String(authorizationCodeValue, StandardCharsets.UTF_8);
+			if (StringUtils.hasText(authorizationCodeValue)) {
 				tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant();
 				tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant();
 				Map<String, Object> authorizationCodeMetadata = parseMap(rs.getString("authorization_code_metadata"));
 
 				OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
-						tokenValue, tokenIssuedAt, tokenExpiresAt);
+						authorizationCodeValue, tokenIssuedAt, tokenExpiresAt);
 				builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
 			}
 
-			byte[] accessTokenValue = this.lobHandler.getBlobAsBytes(rs, "access_token_value");
-			if (accessTokenValue != null) {
-				tokenValue = new String(accessTokenValue, StandardCharsets.UTF_8);
+			String accessTokenValue = getTokenValue(rs, "access_token_value");
+			if (StringUtils.hasText(accessTokenValue)) {
 				tokenIssuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
 				tokenExpiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
 				Map<String, Object> accessTokenMetadata = parseMap(rs.getString("access_token_metadata"));
@@ -381,25 +391,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 				if (accessTokenScopes != null) {
 					scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
 				}
-				OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
+				OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, accessTokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
 				builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
 			}
 
-			byte[] oidcIdTokenValue = this.lobHandler.getBlobAsBytes(rs, "oidc_id_token_value");
-			if (oidcIdTokenValue != null) {
-				tokenValue = new String(oidcIdTokenValue, StandardCharsets.UTF_8);
+			String oidcIdTokenValue = getTokenValue(rs, "oidc_id_token_value");
+			if (StringUtils.hasText(oidcIdTokenValue)) {
 				tokenIssuedAt = rs.getTimestamp("oidc_id_token_issued_at").toInstant();
 				tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
 				Map<String, Object> oidcTokenMetadata = parseMap(rs.getString("oidc_id_token_metadata"));
 
 				OidcIdToken oidcToken = new OidcIdToken(
-						tokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
+						oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
 				builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
 			}
 
-			byte[] refreshTokenValue = this.lobHandler.getBlobAsBytes(rs, "refresh_token_value");
-			if (refreshTokenValue != null) {
-				tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
+			String refreshTokenValue = getTokenValue(rs, "refresh_token_value");
+			if (StringUtils.hasText(refreshTokenValue)) {
 				tokenIssuedAt = rs.getTimestamp("refresh_token_issued_at").toInstant();
 				tokenExpiresAt = null;
 				Timestamp refreshTokenExpiresAt = rs.getTimestamp("refresh_token_expires_at");
@@ -409,12 +417,29 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 				Map<String, Object> refreshTokenMetadata = parseMap(rs.getString("refresh_token_metadata"));
 
 				OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
-						tokenValue, tokenIssuedAt, tokenExpiresAt);
+						refreshTokenValue, tokenIssuedAt, tokenExpiresAt);
 				builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
 			}
 			return builder.build();
 		}
 
+		private String getTokenValue(ResultSet rs, String tokenColumn) throws SQLException {
+			String tokenValue = null;
+			if (Types.CLOB == tokenColumnType) {
+				tokenValue = this.lobHandler.getClobAsString(rs, tokenColumn);
+			}
+			if (Types.VARCHAR == tokenColumnType) {
+				tokenValue = rs.getString(tokenColumn);
+			}
+			if (Types.BLOB == tokenColumnType) {
+				byte[] tokenValueByte = this.lobHandler.getBlobAsBytes(rs, tokenColumn);
+				if (tokenValueByte != null) {
+					tokenValue = new String(tokenValueByte, StandardCharsets.UTF_8);
+				}
+			}
+			return tokenValue;
+		}
+
 		public final void setLobHandler(LobHandler lobHandler) {
 			Assert.notNull(lobHandler, "lobHandler cannot be null");
 			this.lobHandler = lobHandler;
@@ -520,12 +545,12 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 		private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(OAuth2Authorization.Token<T> token) {
 			List<SqlParameterValue> parameters = new ArrayList<>();
-			byte[] tokenValue = null;
+			String tokenValue = null;
 			Timestamp tokenIssuedAt = null;
 			Timestamp tokenExpiresAt = null;
 			String metadata = null;
 			if (token != null) {
-				tokenValue = token.getToken().getTokenValue().getBytes(StandardCharsets.UTF_8);
+				tokenValue = token.getToken().getTokenValue();
 				if (token.getToken().getIssuedAt() != null) {
 					tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt());
 				}
@@ -534,7 +559,13 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 				}
 				metadata = writeMap(token.getMetadata());
 			}
-			parameters.add(new SqlParameterValue(Types.BLOB, tokenValue));
+			if (Types.BLOB == tokenColumnType && StringUtils.hasText(tokenValue)) {
+				byte[] tokenValueAsBytes = tokenValue.getBytes(StandardCharsets.UTF_8);
+				parameters.add(new SqlParameterValue(tokenColumnType, tokenValueAsBytes));
+			} else {
+				parameters.add(new SqlParameterValue(tokenColumnType, tokenValue));
+			}
+
 			parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt));
 			parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt));
 			parameters.add(new SqlParameterValue(Types.VARCHAR, metadata));
@@ -551,6 +582,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 	}
 
+	private static int getColumnDataType(JdbcOperations jdbcOperations, String columnName){
+			return jdbcOperations.execute((ConnectionCallback<Integer>) con -> {
+				DatabaseMetaData databaseMetaData = con.getMetaData();
+				ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
+				if (rs.next()) {
+					return rs.getInt("DATA_TYPE");
+				}
+				// NOTE: When using HSQL: When a database object is created with one of the CREATE statements if the name is enclosed in double quotes, the exact name is used as the case-normal form.
+				// But if it is not enclosed in double quotes, the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form
+				rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(), columnName.toUpperCase());
+				if (rs.next()) {
+					return rs.getInt("DATA_TYPE");
+				}
+				return Types.NULL;
+			});
+		}
+
 	private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
 		private final LobCreator lobCreator;
 
@@ -572,6 +620,15 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 					this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
 					return;
 				}
+				if (paramValue.getSqlType() == Types.CLOB) {
+					if (paramValue.getValue() != null) {
+						Assert.isInstanceOf(String.class, paramValue.getValue(),
+								"Value of clob parameter must be String");
+					}
+					String valueString = (String) paramValue.getValue();
+					this.lobCreator.setClobAsString(ps, parameterPosition, valueString);
+					return;
+				}
 			}
 			super.doSetValue(ps, parameterPosition, argValue);
 		}

+ 42 - 0
oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema-postgres.sql

@@ -0,0 +1,42 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE oauth2_authorization (
+    id varchar(100) NOT NULL,
+    registered_client_id varchar(100) NOT NULL,
+    principal_name varchar(200) NOT NULL,
+    authorization_grant_type varchar(100) NOT NULL,
+    attributes varchar(15000) DEFAULT NULL,
+    state varchar(500) DEFAULT NULL,
+    authorization_code_value text DEFAULT NULL,
+    authorization_code_issued_at timestamp DEFAULT NULL,
+    authorization_code_expires_at timestamp DEFAULT NULL,
+    authorization_code_metadata varchar(2000) DEFAULT NULL,
+    access_token_value text DEFAULT NULL,
+    access_token_issued_at timestamp DEFAULT NULL,
+    access_token_expires_at timestamp DEFAULT NULL,
+    access_token_metadata varchar(2000) DEFAULT NULL,
+    access_token_type varchar(100) DEFAULT NULL,
+    access_token_scopes varchar(1000) DEFAULT NULL,
+    oidc_id_token_value text DEFAULT NULL,
+    oidc_id_token_issued_at timestamp DEFAULT NULL,
+    oidc_id_token_expires_at timestamp DEFAULT NULL,
+    oidc_id_token_metadata varchar(2000) DEFAULT NULL,
+    refresh_token_value text DEFAULT NULL,
+    refresh_token_issued_at timestamp DEFAULT NULL,
+    refresh_token_expires_at timestamp DEFAULT NULL,
+    refresh_token_metadata varchar(2000) DEFAULT NULL,
+    PRIMARY KEY (id)
+);

+ 37 - 1
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java

@@ -43,6 +43,7 @@ import org.springframework.jdbc.core.SqlParameterValue;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.jdbc.support.lob.DefaultLobHandler;
 import org.springframework.security.oauth2.core.AbstractOAuth2Token;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.OAuth2AccessToken;
@@ -75,6 +76,7 @@ import static org.mockito.Mockito.when;
 public class JdbcOAuth2AuthorizationServiceTests {
 	private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql";
 	private static final String CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema.sql";
+	private static final String OAUTH2_AUTHORIZATION_SCHEMA_CLOB_COLUMN_TYPE_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema-clob-data-type.sql";
 	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
 	private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
 	private static final String ID = "id";
@@ -414,6 +416,37 @@ public class JdbcOAuth2AuthorizationServiceTests {
 		db.shutdown();
 	}
 
+	@Test
+	public void tableDefinitionWhenClobSqlTypeThenUpdateAuthorization() {
+		EmbeddedDatabase db = createDb(OAUTH2_AUTHORIZATION_SCHEMA_CLOB_COLUMN_TYPE_SQL_RESOURCE);
+		OAuth2AuthorizationService authorizationService =
+				new JdbcOAuth2AuthorizationService(new JdbcTemplate(db), this.registeredClientRepository);
+		when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.token(AUTHORIZATION_CODE)
+				.build();
+		authorizationService.save(originalAuthorization);
+
+		OAuth2Authorization authorization = authorizationService.findById(
+				originalAuthorization.getId());
+		assertThat(authorization).isEqualTo(originalAuthorization);
+
+		OAuth2Authorization updatedAuthorization = OAuth2Authorization.from(authorization)
+				.attribute("custom-name-1", "custom-value-1")
+				.build();
+		authorizationService.save(updatedAuthorization);
+
+		authorization = authorizationService.findById(
+				updatedAuthorization.getId());
+		assertThat(authorization).isEqualTo(updatedAuthorization);
+		assertThat(authorization).isNotEqualTo(originalAuthorization);
+		db.shutdown();
+	}
+
 	private static EmbeddedDatabase createDb() {
 		return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
 	}
@@ -479,11 +512,14 @@ public class JdbcOAuth2AuthorizationServiceTests {
 
 		private CustomJdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
 				RegisteredClientRepository registeredClientRepository) {
-			super(jdbcOperations, registeredClientRepository);
+			super(jdbcOperations, registeredClientRepository, new DefaultLobHandler());
 			setAuthorizationRowMapper(new CustomOAuth2AuthorizationRowMapper(registeredClientRepository));
 			setAuthorizationParametersMapper(new CustomOAuth2AuthorizationParametersMapper());
+
 		}
 
+
+
 		@Override
 		public void save(OAuth2Authorization authorization) {
 			List<SqlParameterValue> parameters = getAuthorizationParametersMapper().apply(authorization);

+ 42 - 0
oauth2-authorization-server/src/test/resources/org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema-clob-data-type.sql

@@ -0,0 +1,42 @@
+/*
+ * Copyright 2020-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+CREATE TABLE oauth2_authorization (
+    id varchar(100) NOT NULL,
+    registered_client_id varchar(100) NOT NULL,
+    principal_name varchar(200) NOT NULL,
+    authorization_grant_type varchar(100) NOT NULL,
+    attributes varchar(15000) DEFAULT NULL,
+    state varchar(500) DEFAULT NULL,
+    authorization_code_value clob DEFAULT NULL,
+    authorization_code_issued_at timestamp DEFAULT NULL,
+    authorization_code_expires_at timestamp DEFAULT NULL,
+    authorization_code_metadata varchar(2000) DEFAULT NULL,
+    access_token_value clob DEFAULT NULL,
+    access_token_issued_at timestamp DEFAULT NULL,
+    access_token_expires_at timestamp DEFAULT NULL,
+    access_token_metadata varchar(2000) DEFAULT NULL,
+    access_token_type varchar(100) DEFAULT NULL,
+    access_token_scopes varchar(1000) DEFAULT NULL,
+    oidc_id_token_value clob DEFAULT NULL,
+    oidc_id_token_issued_at timestamp DEFAULT NULL,
+    oidc_id_token_expires_at timestamp DEFAULT NULL,
+    oidc_id_token_metadata varchar(2000) DEFAULT NULL,
+    refresh_token_value clob DEFAULT NULL,
+    refresh_token_issued_at timestamp DEFAULT NULL,
+    refresh_token_expires_at timestamp DEFAULT NULL,
+    refresh_token_metadata varchar(2000) DEFAULT NULL,
+    PRIMARY KEY (id)
+);