Parcourir la source

JdbcOAuth2AuthorizationService improves support for large data columns

Closes gh-604
Joe Grandja il y a 3 ans
Parent
commit
58bac49f97

+ 116 - 66
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java

@@ -25,6 +25,7 @@ import java.sql.Types;
 import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -72,6 +73,7 @@ import org.springframework.util.StringUtils;
  * therefore MUST be defined in the database schema.
  *
  * @author Ovidiu Popa
+ * @author Joe Grandja
  * @since 0.1.2
  * @see OAuth2AuthorizationService
  * @see OAuth2Authorization
@@ -141,7 +143,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 	private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
 
-	private static int tokenColumnDataType;
+	private static Map<String, ColumnMetadata> columnMetadataMap;
 
 	private final JdbcOperations jdbcOperations;
 	private final LobHandler lobHandler;
@@ -177,7 +179,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 		authorizationRowMapper.setLobHandler(lobHandler);
 		this.authorizationRowMapper = authorizationRowMapper;
 		this.authorizationParametersMapper = new OAuth2AuthorizationParametersMapper();
-		tokenColumnDataType = getColumnDataType(jdbcOperations, "access_token_value", Types.BLOB);
+		initColumnMetadata(jdbcOperations);
 	}
 
 	@Override
@@ -237,32 +239,26 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 		List<SqlParameterValue> parameters = new ArrayList<>();
 		if (tokenType == null) {
 			parameters.add(new SqlParameterValue(Types.VARCHAR, token));
-			parameters.add(mapTokenToSqlParameter(token));
-			parameters.add(mapTokenToSqlParameter(token));
-			parameters.add(mapTokenToSqlParameter(token));
+			parameters.add(mapToSqlParameter("authorization_code_value", token));
+			parameters.add(mapToSqlParameter("access_token_value", token));
+			parameters.add(mapToSqlParameter("refresh_token_value", token));
 			return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters);
 		} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
 			parameters.add(new SqlParameterValue(Types.VARCHAR, token));
 			return findBy(STATE_FILTER, parameters);
 		} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
-			parameters.add(mapTokenToSqlParameter(token));
+			parameters.add(mapToSqlParameter("authorization_code_value", token));
 			return findBy(AUTHORIZATION_CODE_FILTER, parameters);
 		} else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
-			parameters.add(mapTokenToSqlParameter(token));
+			parameters.add(mapToSqlParameter("access_token_value", token));
 			return findBy(ACCESS_TOKEN_FILTER, parameters);
 		} else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
-			parameters.add(mapTokenToSqlParameter(token));
+			parameters.add(mapToSqlParameter("refresh_token_value", token));
 			return findBy(REFRESH_TOKEN_FILTER, parameters);
 		}
 		return null;
 	}
 
-	private SqlParameterValue mapTokenToSqlParameter(String token) {
-		return Types.BLOB == tokenColumnDataType ?
-				new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)) :
-				new SqlParameterValue(tokenColumnDataType, token);
-	}
-
 	private OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
 		try (LobCreator lobCreator = getLobHandler().getLobCreator()) {
 			PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
@@ -348,7 +344,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 			String id = rs.getString("id");
 			String principalName = rs.getString("principal_name");
 			String authorizationGrantType = rs.getString("authorization_grant_type");
-			Map<String, Object> attributes = parseMap(rs.getString("attributes"));
+			Map<String, Object> attributes = parseMap(getLobValue(rs, "attributes"));
 
 			builder.id(id)
 					.principalName(principalName)
@@ -362,23 +358,23 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 			Instant tokenIssuedAt;
 			Instant tokenExpiresAt;
-			String authorizationCodeValue = getTokenValue(rs, "authorization_code_value");
+			String authorizationCodeValue = getLobValue(rs, "authorization_code_value");
 
 			if (StringUtils.hasText(authorizationCodeValue)) {
 				tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant();
 				tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant();
-				Map<String, Object> authorizationCodeMetadata = parseMap(rs.getString("authorization_code_metadata"));
+				Map<String, Object> authorizationCodeMetadata = parseMap(getLobValue(rs, "authorization_code_metadata"));
 
 				OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
 						authorizationCodeValue, tokenIssuedAt, tokenExpiresAt);
 				builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
 			}
 
-			String accessTokenValue = getTokenValue(rs, "access_token_value");
+			String accessTokenValue = getLobValue(rs, "access_token_value");
 			if (StringUtils.hasText(accessTokenValue)) {
 				tokenIssuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
 				tokenExpiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
-				Map<String, Object> accessTokenMetadata = parseMap(rs.getString("access_token_metadata"));
+				Map<String, Object> accessTokenMetadata = parseMap(getLobValue(rs, "access_token_metadata"));
 				OAuth2AccessToken.TokenType tokenType = null;
 				if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("access_token_type"))) {
 					tokenType = OAuth2AccessToken.TokenType.BEARER;
@@ -393,18 +389,18 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 				builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
 			}
 
-			String oidcIdTokenValue = getTokenValue(rs, "oidc_id_token_value");
+			String oidcIdTokenValue = getLobValue(rs, "oidc_id_token_value");
 			if (StringUtils.hasText(oidcIdTokenValue)) {
 				tokenIssuedAt = rs.getTimestamp("oidc_id_token_issued_at").toInstant();
 				tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
-				Map<String, Object> oidcTokenMetadata = parseMap(rs.getString("oidc_id_token_metadata"));
+				Map<String, Object> oidcTokenMetadata = parseMap(getLobValue(rs, "oidc_id_token_metadata"));
 
 				OidcIdToken oidcToken = new OidcIdToken(
 						oidcIdTokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
 				builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
 			}
 
-			String refreshTokenValue = getTokenValue(rs, "refresh_token_value");
+			String refreshTokenValue = getLobValue(rs, "refresh_token_value");
 			if (StringUtils.hasText(refreshTokenValue)) {
 				tokenIssuedAt = rs.getTimestamp("refresh_token_issued_at").toInstant();
 				tokenExpiresAt = null;
@@ -412,7 +408,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 				if (refreshTokenExpiresAt != null) {
 					tokenExpiresAt = refreshTokenExpiresAt.toInstant();
 				}
-				Map<String, Object> refreshTokenMetadata = parseMap(rs.getString("refresh_token_metadata"));
+				Map<String, Object> refreshTokenMetadata = parseMap(getLobValue(rs, "refresh_token_metadata"));
 
 				OAuth2RefreshToken refreshToken = new OAuth2RefreshToken(
 						refreshTokenValue, tokenIssuedAt, tokenExpiresAt);
@@ -421,19 +417,20 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 			return builder.build();
 		}
 
-		private String getTokenValue(ResultSet rs, String tokenColumn) throws SQLException {
-			String tokenValue = null;
-			if (Types.BLOB == tokenColumnDataType) {
-				byte[] tokenValueBytes = this.lobHandler.getBlobAsBytes(rs, tokenColumn);
-				if (tokenValueBytes != null) {
-					tokenValue = new String(tokenValueBytes, StandardCharsets.UTF_8);
+		private String getLobValue(ResultSet rs, String columnName) throws SQLException {
+			String columnValue = null;
+			ColumnMetadata columnMetadata = columnMetadataMap.get(columnName);
+			if (Types.BLOB == columnMetadata.getDataType()) {
+				byte[] columnValueBytes = this.lobHandler.getBlobAsBytes(rs, columnName);
+				if (columnValueBytes != null) {
+					columnValue = new String(columnValueBytes, StandardCharsets.UTF_8);
 				}
-			} else if (Types.CLOB == tokenColumnDataType) {
-				tokenValue = this.lobHandler.getClobAsString(rs, tokenColumn);
+			} else if (Types.CLOB == columnMetadata.getDataType()) {
+				columnValue = this.lobHandler.getClobAsString(rs, columnName);
 			} else {
-				tokenValue = rs.getString(tokenColumn);
+				columnValue = rs.getString(columnName);
 			}
-			return tokenValue;
+			return columnValue;
 		}
 
 		public final void setLobHandler(LobHandler lobHandler) {
@@ -491,7 +488,7 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 			parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getAuthorizationGrantType().getValue()));
 
 			String attributes = writeMap(authorization.getAttributes());
-			parameters.add(new SqlParameterValue(Types.VARCHAR, attributes));
+			parameters.add(mapToSqlParameter("attributes", attributes));
 
 			String state = null;
 			String authorizationState = authorization.getAttribute(OAuth2ParameterNames.STATE);
@@ -502,12 +499,14 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 			OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
 					authorization.getToken(OAuth2AuthorizationCode.class);
-			List<SqlParameterValue> authorizationCodeSqlParameters = toSqlParameterList(authorizationCode);
+			List<SqlParameterValue> authorizationCodeSqlParameters = toSqlParameterList(
+					"authorization_code_value", "authorization_code_metadata", authorizationCode);
 			parameters.addAll(authorizationCodeSqlParameters);
 
 			OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
 					authorization.getToken(OAuth2AccessToken.class);
-			List<SqlParameterValue> accessTokenSqlParameters = toSqlParameterList(accessToken);
+			List<SqlParameterValue> accessTokenSqlParameters = toSqlParameterList(
+					"access_token_value", "access_token_metadata", accessToken);
 			parameters.addAll(accessTokenSqlParameters);
 			String accessTokenType = null;
 			String accessTokenScopes = null;
@@ -521,11 +520,13 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 			parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes));
 
 			OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
-			List<SqlParameterValue> oidcIdTokenSqlParameters = toSqlParameterList(oidcIdToken);
+			List<SqlParameterValue> oidcIdTokenSqlParameters = toSqlParameterList(
+					"oidc_id_token_value", "oidc_id_token_metadata", oidcIdToken);
 			parameters.addAll(oidcIdTokenSqlParameters);
 
 			OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
-			List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList(refreshToken);
+			List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList(
+					"refresh_token_value", "refresh_token_metadata", refreshToken);
 			parameters.addAll(refreshTokenSqlParameters);
 			return parameters;
 		}
@@ -539,7 +540,9 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 			return this.objectMapper;
 		}
 
-		private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(OAuth2Authorization.Token<T> token) {
+		private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(
+				String tokenColumnName, String tokenMetadataColumnName, OAuth2Authorization.Token<T> token) {
+
 			List<SqlParameterValue> parameters = new ArrayList<>();
 			String tokenValue = null;
 			Timestamp tokenIssuedAt = null;
@@ -555,15 +558,11 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 				}
 				metadata = writeMap(token.getMetadata());
 			}
-			if (Types.BLOB == tokenColumnDataType && StringUtils.hasText(tokenValue)) {
-				byte[] tokenValueBytes = tokenValue.getBytes(StandardCharsets.UTF_8);
-				parameters.add(new SqlParameterValue(Types.BLOB, tokenValueBytes));
-			} else {
-				parameters.add(new SqlParameterValue(tokenColumnDataType, tokenValue));
-			}
+
+			parameters.add(mapToSqlParameter(tokenColumnName, tokenValue));
 			parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt));
 			parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt));
-			parameters.add(new SqlParameterValue(Types.VARCHAR, metadata));
+			parameters.add(mapToSqlParameter(tokenMetadataColumnName, metadata));
 			return parameters;
 		}
 
@@ -577,26 +576,6 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 	}
 
-	private static Integer getColumnDataType(JdbcOperations jdbcOperations, String columnName, int defaultDataType) {
-		return jdbcOperations.execute((ConnectionCallback<Integer>) conn -> {
-			DatabaseMetaData databaseMetaData = conn.getMetaData();
-			ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
-			if (rs.next()) {
-				return rs.getInt("DATA_TYPE");
-			}
-			// NOTE: (Applies to HSQL)
-			// When a database object is created with one of the CREATE statements or renamed with the ALTER statement,
-			// if the name is enclosed in double quotes, the exact name is used as the case-normal form.
-			// But if it is not enclosed in double quotes,
-			// the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form.
-			rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(), columnName.toUpperCase());
-			if (rs.next()) {
-				return rs.getInt("DATA_TYPE");
-			}
-			return defaultDataType;
-		});
-	}
-
 	private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
 		private final LobCreator lobCreator;
 
@@ -633,4 +612,75 @@ public class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationServic
 
 	}
 
+	private static final class ColumnMetadata {
+		private final String columnName;
+		private final int dataType;
+
+		private ColumnMetadata(String columnName, int dataType) {
+			this.columnName = columnName;
+			this.dataType = dataType;
+		}
+
+		private String getColumnName() {
+			return this.columnName;
+		}
+
+		private int getDataType() {
+			return this.dataType;
+		}
+
+	}
+
+	private static void initColumnMetadata(JdbcOperations jdbcOperations) {
+		columnMetadataMap = new HashMap<>();
+		ColumnMetadata columnMetadata;
+
+		columnMetadata = getColumnMetadata(jdbcOperations, "attributes", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "authorization_code_value", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "authorization_code_metadata", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "access_token_value", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "access_token_metadata", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "oidc_id_token_value", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "oidc_id_token_metadata", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "refresh_token_value", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+		columnMetadata = getColumnMetadata(jdbcOperations, "refresh_token_metadata", Types.BLOB);
+		columnMetadataMap.put(columnMetadata.getColumnName(), columnMetadata);
+	}
+
+	private static ColumnMetadata getColumnMetadata(JdbcOperations jdbcOperations, String columnName, int defaultDataType) {
+		Integer dataType = jdbcOperations.execute((ConnectionCallback<Integer>) conn -> {
+			DatabaseMetaData databaseMetaData = conn.getMetaData();
+			ResultSet rs = databaseMetaData.getColumns(null, null, TABLE_NAME, columnName);
+			if (rs.next()) {
+				return rs.getInt("DATA_TYPE");
+			}
+			// NOTE: (Applies to HSQL)
+			// When a database object is created with one of the CREATE statements or renamed with the ALTER statement,
+			// if the name is enclosed in double quotes, the exact name is used as the case-normal form.
+			// But if it is not enclosed in double quotes,
+			// the name is converted to uppercase and this uppercase version is stored in the database as the case-normal form.
+			rs = databaseMetaData.getColumns(null, null, TABLE_NAME.toUpperCase(), columnName.toUpperCase());
+			if (rs.next()) {
+				return rs.getInt("DATA_TYPE");
+			}
+			return null;
+		});
+		return new ColumnMetadata(columnName, dataType != null ? dataType : defaultDataType);
+	}
+
+	private static SqlParameterValue mapToSqlParameter(String columnName, String value) {
+		ColumnMetadata columnMetadata = columnMetadataMap.get(columnName);
+		return Types.BLOB == columnMetadata.getDataType() && StringUtils.hasText(value) ?
+				new SqlParameterValue(Types.BLOB, value.getBytes(StandardCharsets.UTF_8)) :
+				new SqlParameterValue(columnMetadata.getDataType(), value);
+	}
+
 }

+ 5 - 5
oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql

@@ -8,25 +8,25 @@ CREATE TABLE oauth2_authorization (
     registered_client_id varchar(100) NOT NULL,
     principal_name varchar(200) NOT NULL,
     authorization_grant_type varchar(100) NOT NULL,
-    attributes varchar(4000) DEFAULT NULL,
+    attributes blob DEFAULT NULL,
     state varchar(500) DEFAULT NULL,
     authorization_code_value blob DEFAULT NULL,
     authorization_code_issued_at timestamp DEFAULT NULL,
     authorization_code_expires_at timestamp DEFAULT NULL,
-    authorization_code_metadata varchar(2000) DEFAULT NULL,
+    authorization_code_metadata blob DEFAULT NULL,
     access_token_value blob DEFAULT NULL,
     access_token_issued_at timestamp DEFAULT NULL,
     access_token_expires_at timestamp DEFAULT NULL,
-    access_token_metadata varchar(2000) DEFAULT NULL,
+    access_token_metadata blob DEFAULT NULL,
     access_token_type varchar(100) DEFAULT NULL,
     access_token_scopes varchar(1000) DEFAULT NULL,
     oidc_id_token_value blob DEFAULT NULL,
     oidc_id_token_issued_at timestamp DEFAULT NULL,
     oidc_id_token_expires_at timestamp DEFAULT NULL,
-    oidc_id_token_metadata varchar(2000) DEFAULT NULL,
+    oidc_id_token_metadata blob DEFAULT NULL,
     refresh_token_value blob DEFAULT NULL,
     refresh_token_issued_at timestamp DEFAULT NULL,
     refresh_token_expires_at timestamp DEFAULT NULL,
-    refresh_token_metadata varchar(2000) DEFAULT NULL,
+    refresh_token_metadata blob DEFAULT NULL,
     PRIMARY KEY (id)
 );