Browse Source

Provide JDBC implementation of OAuth2AuthorizationService

Add new JDBC implementation of the OAuth2AuthorizationService

Closes gh-245
Ovidiu Popa 4 years ago
parent
commit
552751bd93

+ 1 - 0
gradle/dependency-management.gradle

@@ -32,5 +32,6 @@ dependencyManagement {
 		dependency "com.squareup.okhttp3:mockwebserver:3.14.9"
 		dependency "com.squareup.okhttp3:okhttp:3.14.9"
 		dependency "com.jayway.jsonpath:json-path:2.4.0"
+		dependency "org.hsqldb:hsqldb:2.5.+"
 	}
 }

+ 4 - 0
oauth2-authorization-server/spring-security-oauth2-authorization-server.gradle

@@ -10,6 +10,8 @@ dependencies {
 	compile 'com.nimbusds:nimbus-jose-jwt'
 	compile 'com.fasterxml.jackson.core:jackson-databind'
 
+	optional 'org.springframework:spring-jdbc'
+
 	testCompile 'org.springframework.security:spring-security-test'
 	testCompile 'org.springframework:spring-webmvc'
 	testCompile 'junit:junit'
@@ -17,6 +19,8 @@ dependencies {
 	testCompile 'org.mockito:mockito-core'
 	testCompile 'com.jayway.jsonpath:json-path'
 
+	testRuntime 'org.hsqldb:hsqldb'
+
 	provided 'javax.servlet:javax.servlet-api'
 }
 

+ 554 - 0
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationService.java

@@ -0,0 +1,554 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
+import org.springframework.jdbc.support.lob.DefaultLobHandler;
+import org.springframework.jdbc.support.lob.LobCreator;
+import org.springframework.jdbc.support.lob.LobHandler;
+import org.springframework.lang.Nullable;
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
+import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+
+import java.nio.charset.StandardCharsets;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+
+/**
+ * A JDBC implementation of an {@link OAuth2AuthorizationService} that uses a
+ * <p>
+ * {@link JdbcOperations} for {@link OAuth2Authorization} persistence.
+ *
+ * <p>
+ * <b>NOTE:</b> This {@code OAuth2AuthorizationService} depends on the table definition
+ * described in
+ * "classpath:org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql" and
+ * therefore MUST be defined in the database schema.
+ *
+ * @author Ovidiu Popa
+ * @see OAuth2AuthorizationService
+ * @see OAuth2Authorization
+ * @see JdbcOperations
+ * @see RowMapper
+ * @since 0.1.2
+ */
+public final class JdbcOAuth2AuthorizationService implements OAuth2AuthorizationService {
+
+	// @formatter:off
+	private static final String COLUMN_NAMES = "id, "
+			+ "registered_client_id, "
+			+ "principal_name, "
+			+ "authorization_grant_type, "
+			+ "attributes, "
+			+ "state, "
+			+ "authorization_code_value, "
+			+ "authorization_code_issued_at, "
+			+ "authorization_code_expires_at,"
+			+ "authorization_code_metadata,"
+			+ "access_token_value,"
+			+ "access_token_issued_at,"
+			+ "access_token_expires_at,"
+			+ "access_token_metadata,"
+			+ "access_token_type,"
+			+ "access_token_scopes,"
+			+ "oidc_id_token_value,"
+			+ "oidc_id_token_issued_at,"
+			+ "oidc_id_token_expires_at,"
+			+ "oidc_id_token_metadata,"
+			+ "refresh_token_value,"
+			+ "refresh_token_issued_at,"
+			+ "refresh_token_expires_at,"
+			+ "refresh_token_metadata";
+	// @formatter:on
+
+	private static final String TABLE_NAME = "oauth2_authorization";
+
+	private static final String PK_FILTER = "id = ?";
+	private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorization_code_value = ? OR " +
+			"access_token_value = ? OR " +
+			"refresh_token_value = ?";
+
+	private static final String STATE_FILTER = "state = ?";
+	private static final String AUTHORIZATION_CODE_FILTER = "authorization_code_value = ?";
+	private static final String ACCESS_TOKEN_FILTER = "access_token_value = ?";
+	private static final String REFRESH_TOKEN_FILTER = "refresh_token_value = ?";
+
+	// @formatter:off
+	private static final String LOAD_AUTHORIZATION_SQL = "SELECT " + COLUMN_NAMES
+			+ " FROM " + TABLE_NAME
+			+ " WHERE ";
+	// @formatter:on
+
+	// @formatter:off
+	private static final String SAVE_AUTHORIZATION_SQL = "INSERT INTO " + TABLE_NAME
+			+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?)";
+	// @formatter:on
+
+	// @formatter:off
+	private static final String UPDATE_AUTHORIZATION_SQL = "UPDATE " + TABLE_NAME
+			+ " SET registered_client_id = ?, principal_name = ?, authorization_grant_type = ?, attributes = ?, state = ?,"
+			+ " authorization_code_value = ?, authorization_code_issued_at = ?, authorization_code_expires_at = ?, authorization_code_metadata = ?,"
+			+ " access_token_value = ?, access_token_issued_at = ?, access_token_expires_at = ?, access_token_metadata = ?, access_token_type = ?, access_token_scopes = ?,"
+			+ " oidc_id_token_value = ?, oidc_id_token_issued_at = ?, oidc_id_token_expires_at = ?, oidc_id_token_metadata = ?,"
+			+ " refresh_token_value = ?, refresh_token_issued_at = ?, refresh_token_expires_at = ?, refresh_token_metadata = ?"
+			+ " WHERE " + PK_FILTER;
+	// @formatter:on
+
+	private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+
+	private final JdbcOperations jdbcOperations;
+	private final LobHandler lobHandler;
+	private RowMapper<OAuth2Authorization> authorizationRowMapper;
+	private Function<OAuth2Authorization, List<SqlParameterValue>> authorizationParametersMapper;
+
+	/**
+	 * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters.
+	 *
+	 * @param jdbcOperations             the JDBC operations
+	 * @param registeredClientRepository the registered client repository
+	 */
+	public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
+			RegisteredClientRepository registeredClientRepository) {
+		this(jdbcOperations, registeredClientRepository, new DefaultLobHandler());
+	}
+
+	/**
+	 * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters.
+	 *
+	 * @param jdbcOperations             the JDBC operations
+	 * @param registeredClientRepository the registered client repository
+	 * @param lobHandler                 the handler for large binary fields and large text fields
+	 */
+	public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
+			RegisteredClientRepository registeredClientRepository, LobHandler lobHandler) {
+		this(jdbcOperations, registeredClientRepository, lobHandler, new ObjectMapper());
+	}
+
+	/**
+	 * Constructs a {@code JdbcOAuth2AuthorizationService} using the provided parameters.
+	 *
+	 * @param jdbcOperations             the JDBC operations
+	 * @param registeredClientRepository the registered client repository
+	 * @param lobHandler                 the handler for large binary fields and large text fields
+	 * @param objectMapper               the object mapper
+	 */
+	public JdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
+			RegisteredClientRepository registeredClientRepository,
+			LobHandler lobHandler, ObjectMapper objectMapper) {
+		Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
+		Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+		Assert.notNull(lobHandler, "lobHandler cannot be null");
+		Assert.notNull(objectMapper, "objectMapper cannot be null");
+		this.jdbcOperations = jdbcOperations;
+		this.lobHandler = lobHandler;
+		OAuth2AuthorizationRowMapper authorizationRowMapper = new OAuth2AuthorizationRowMapper(registeredClientRepository, objectMapper);
+		authorizationRowMapper.setLobHandler(lobHandler);
+		this.authorizationRowMapper = authorizationRowMapper;
+		this.authorizationParametersMapper = new OAuth2AuthorizationParametersMapper(objectMapper);
+	}
+
+
+	@Override
+	public void save(OAuth2Authorization authorization) {
+		Assert.notNull(authorization, "authorization cannot be null");
+
+		OAuth2Authorization existingAuthorization = findById(authorization.getId());
+		if (existingAuthorization == null) {
+			insertAuthorization(authorization);
+		} else {
+			updateAuthorization(authorization);
+		}
+	}
+
+	private void updateAuthorization(OAuth2Authorization authorization) {
+		List<SqlParameterValue> parameters = this.authorizationParametersMapper.apply(authorization);
+		SqlParameterValue id = parameters.remove(0);
+		parameters.add(id);
+		try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
+			PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
+					parameters.toArray());
+			this.jdbcOperations.update(UPDATE_AUTHORIZATION_SQL, pss);
+		}
+	}
+
+	private void insertAuthorization(OAuth2Authorization authorization) {
+		List<SqlParameterValue> parameters = this.authorizationParametersMapper.apply(authorization);
+		try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
+			PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator,
+					parameters.toArray());
+			this.jdbcOperations.update(SAVE_AUTHORIZATION_SQL, pss);
+		}
+	}
+
+	@Override
+	public void remove(OAuth2Authorization authorization) {
+		Assert.notNull(authorization, "authorization cannot be null");
+		SqlParameterValue[] parameters = new SqlParameterValue[]{
+				new SqlParameterValue(Types.VARCHAR, authorization.getId())
+		};
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+		this.jdbcOperations.update(REMOVE_AUTHORIZATION_SQL, pss);
+	}
+
+	@Nullable
+	@Override
+	public OAuth2Authorization findById(String id) {
+		Assert.hasText(id, "id cannot be empty");
+		List<SqlParameterValue> parameters = new ArrayList<>();
+		parameters.add(new SqlParameterValue(Types.VARCHAR, id));
+		return findBy(PK_FILTER, parameters);
+	}
+
+	@Nullable
+	@Override
+	public OAuth2Authorization findByToken(String token, @Nullable OAuth2TokenType tokenType) {
+		Assert.hasText(token, "token cannot be empty");
+		List<SqlParameterValue> parameters = new ArrayList<>();
+		if (tokenType == null) {
+			parameters.add(new SqlParameterValue(Types.VARCHAR, token));
+			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			return findBy(UNKNOWN_TOKEN_TYPE_FILTER, parameters);
+		} else if (OAuth2ParameterNames.STATE.equals(tokenType.getValue())) {
+			parameters.add(new SqlParameterValue(Types.VARCHAR, token));
+			return findBy(STATE_FILTER, parameters);
+		} else if (OAuth2ParameterNames.CODE.equals(tokenType.getValue())) {
+			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			return findBy(AUTHORIZATION_CODE_FILTER, parameters);
+		} else if (OAuth2TokenType.ACCESS_TOKEN.equals(tokenType)) {
+			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			return findBy(ACCESS_TOKEN_FILTER, parameters);
+		} else if (OAuth2TokenType.REFRESH_TOKEN.equals(tokenType)) {
+			parameters.add(new SqlParameterValue(Types.BLOB, token.getBytes(StandardCharsets.UTF_8)));
+			return findBy(REFRESH_TOKEN_FILTER, parameters);
+		}
+		return null;
+	}
+
+	private OAuth2Authorization findBy(String filter, List<SqlParameterValue> parameters) {
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+		List<OAuth2Authorization> result = this.jdbcOperations.query(LOAD_AUTHORIZATION_SQL + filter, pss, this.authorizationRowMapper);
+		return !result.isEmpty() ? result.get(0) : null;
+	}
+
+	/**
+	 * Sets the {@link RowMapper} used for mapping the current row in
+	 * {@code java.sql.ResultSet} to {@link OAuth2Authorization}. The default is
+	 * {@link OAuth2AuthorizationRowMapper}.
+	 *
+	 * @param authorizationRowMapper the {@link RowMapper} used for mapping the current
+	 *                               row in {@code ResultSet} to {@link OAuth2Authorization}
+	 */
+	public void setAuthorizationRowMapper(RowMapper<OAuth2Authorization> authorizationRowMapper) {
+		Assert.notNull(authorizationRowMapper, "authorizationRowMapper cannot be null");
+		this.authorizationRowMapper = authorizationRowMapper;
+	}
+
+	/**
+	 * Sets the {@code Function} used for mapping {@link OAuth2Authorization} to
+	 * a {@code List} of {@link SqlParameterValue}. The default is
+	 * {@link OAuth2AuthorizationParametersMapper}.
+	 *
+	 * @param authorizationParametersMapper the {@code Function} used for mapping
+	 *                                      {@link OAuth2Authorization} to a {@code List} of {@link SqlParameterValue}
+	 */
+	public void setAuthorizationParametersMapper(
+			Function<OAuth2Authorization, List<SqlParameterValue>> authorizationParametersMapper) {
+		Assert.notNull(authorizationParametersMapper, "authorizationParametersMapper cannot be null");
+		this.authorizationParametersMapper = authorizationParametersMapper;
+	}
+
+	/**
+	 * The default {@link RowMapper} that maps the current row in
+	 * {@code java.sql.ResultSet} to {@link OAuth2Authorization}.
+	 */
+	public static class OAuth2AuthorizationRowMapper implements RowMapper<OAuth2Authorization> {
+
+		private final RegisteredClientRepository registeredClientRepository;
+		private final ObjectMapper objectMapper;
+		private LobHandler lobHandler = new DefaultLobHandler();
+
+
+		public OAuth2AuthorizationRowMapper(RegisteredClientRepository registeredClientRepository, ObjectMapper objectMapper) {
+			Assert.notNull(registeredClientRepository, "registeredClientRepository cannot be null");
+			Assert.notNull(objectMapper, "objectMapper cannot be null");
+			this.registeredClientRepository = registeredClientRepository;
+			this.objectMapper = objectMapper;
+		}
+
+		@Override
+		@SuppressWarnings("unchecked")
+		public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException {
+			try {
+				String registeredClientId = rs.getString("registered_client_id");
+				RegisteredClient registeredClient = this.registeredClientRepository
+						.findById(registeredClientId);
+				if (registeredClient == null) {
+					throw new DataRetrievalFailureException(
+							"The RegisteredClient with id '" + registeredClientId + "' it was not found in the RegisteredClientRepository.");
+				}
+
+				OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient);
+				String id = rs.getString("id");
+				String principalName = rs.getString("principal_name");
+				String authorizationGrantType = rs.getString("authorization_grant_type");
+				Map<String, Object> attributes = this.objectMapper.readValue(rs.getString("attributes"), Map.class);
+
+				builder.id(id)
+						.principalName(principalName)
+						.authorizationGrantType(new AuthorizationGrantType(authorizationGrantType))
+						.attributes(attrs -> attrs.putAll(attributes));
+
+				String state = rs.getString("state");
+				if (StringUtils.hasText(state)) {
+					builder.attribute(OAuth2ParameterNames.STATE, state);
+				}
+
+				String tokenValue;
+				Instant tokenIssuedAt;
+				Instant tokenExpiresAt;
+				byte[] authorizationCodeValue = this.lobHandler.getBlobAsBytes(rs, "authorization_code_value");
+
+				if (authorizationCodeValue != null) {
+					tokenValue = new String(authorizationCodeValue,
+							StandardCharsets.UTF_8);
+					tokenIssuedAt = rs.getTimestamp("authorization_code_issued_at").toInstant();
+					tokenExpiresAt = rs.getTimestamp("authorization_code_expires_at").toInstant();
+					Map<String, Object> authorizationCodeMetadata = this.objectMapper.readValue(rs.getString("authorization_code_metadata"), Map.class);
+
+					OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
+							tokenValue, tokenIssuedAt, tokenExpiresAt);
+					builder
+							.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
+				}
+
+				byte[] accessTokenValue = this.lobHandler.getBlobAsBytes(rs, "access_token_value");
+				if (accessTokenValue != null) {
+					tokenValue = new String(accessTokenValue,
+							StandardCharsets.UTF_8);
+					tokenIssuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
+					tokenExpiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
+					Map<String, Object> accessTokenMetadata = this.objectMapper.readValue(rs.getString("access_token_metadata"), Map.class);
+					OAuth2AccessToken.TokenType tokenType = null;
+					if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("access_token_type"))) {
+						tokenType = OAuth2AccessToken.TokenType.BEARER;
+					}
+
+					Set<String> scopes = Collections.emptySet();
+					String accessTokenScopes = rs.getString("access_token_scopes");
+					if (accessTokenScopes != null) {
+						scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
+					}
+					OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
+					builder
+							.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
+				}
+
+				byte[] oidcIdTokenValue = this.lobHandler.getBlobAsBytes(rs, "oidc_id_token_value");
+
+				if (oidcIdTokenValue != null) {
+					tokenValue = new String(oidcIdTokenValue,
+							StandardCharsets.UTF_8);
+					tokenIssuedAt = rs.getTimestamp("oidc_id_token_issued_at").toInstant();
+					tokenExpiresAt = rs.getTimestamp("oidc_id_token_expires_at").toInstant();
+					Map<String, Object> oidcTokenMetadata = this.objectMapper.readValue(rs.getString("oidc_id_token_metadata"), Map.class);
+
+					OidcIdToken oidcToken = new OidcIdToken(
+							tokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
+					builder
+							.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
+				}
+
+				byte[] refreshTokenValue = this.lobHandler.getBlobAsBytes(rs, "refresh_token_value");
+				if (refreshTokenValue != null) {
+					tokenValue = new String(refreshTokenValue,
+							StandardCharsets.UTF_8);
+					tokenIssuedAt = rs.getTimestamp("refresh_token_issued_at").toInstant();
+					tokenExpiresAt = null;
+					Timestamp refreshTokenExpiresAt = rs.getTimestamp("refresh_token_expires_at");
+					if (refreshTokenExpiresAt != null) {
+						tokenExpiresAt = refreshTokenExpiresAt.toInstant();
+					}
+					Map<String, Object> refreshTokenMetadata = this.objectMapper.readValue(rs.getString("refresh_token_metadata"), Map.class);
+
+					OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2(
+							tokenValue, tokenIssuedAt, tokenExpiresAt);
+					builder
+							.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
+				}
+				return builder.build();
+			} catch (JsonProcessingException e) {
+				throw new IllegalArgumentException(e.getMessage(), e);
+			}
+		}
+
+		public final void setLobHandler(LobHandler lobHandler) {
+			Assert.notNull(lobHandler, "lobHandler cannot be null");
+			this.lobHandler = lobHandler;
+		}
+	}
+
+	/**
+	 * The default {@code Function} that maps {@link OAuth2Authorization} to a
+	 * {@code List} of {@link SqlParameterValue}.
+	 */
+	public static class OAuth2AuthorizationParametersMapper implements Function<OAuth2Authorization, List<SqlParameterValue>> {
+		private final ObjectMapper objectMapper;
+
+		public OAuth2AuthorizationParametersMapper(ObjectMapper objectMapper) {
+			Assert.notNull(objectMapper, "objectMapper cannot be null");
+			this.objectMapper = objectMapper;
+		}
+
+		@Override
+		public List<SqlParameterValue> apply(OAuth2Authorization authorization) {
+
+			try {
+				List<SqlParameterValue> parameters = new ArrayList<>();
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getId()));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getRegisteredClientId()));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getPrincipalName()));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getAuthorizationGrantType().getValue()));
+
+				String attributes = this.objectMapper.writeValueAsString(authorization.getAttributes());
+				parameters.add(new SqlParameterValue(Types.VARCHAR, attributes));
+
+				String state = null;
+				String authorizationState = authorization.getAttribute(OAuth2ParameterNames.STATE);
+				if (StringUtils.hasText(authorizationState)) {
+					state = authorizationState;
+				}
+				parameters.add(new SqlParameterValue(Types.VARCHAR, state));
+
+				OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
+						authorization.getToken(OAuth2AuthorizationCode.class);
+				List<SqlParameterValue> authorizationCodeSqlParameters = toSqlParameterList(authorizationCode);
+				parameters.addAll(authorizationCodeSqlParameters);
+
+				OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
+						authorization.getToken(OAuth2AccessToken.class);
+				List<SqlParameterValue> accessTokenSqlParameters = toSqlParameterList(accessToken);
+				parameters.addAll(accessTokenSqlParameters);
+				String accessTokenType = null;
+				String accessTokenScopes = null;
+
+				if (accessToken != null) {
+					accessTokenType = accessToken.getToken().getTokenType().getValue();
+					if (!CollectionUtils.isEmpty(accessToken.getToken().getScopes())) {
+						accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ",");
+					}
+				}
+
+				parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenType));
+				parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes));
+				OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
+				List<SqlParameterValue> oidcTokenSqlParameters = toSqlParameterList(oidcIdToken);
+				parameters.addAll(oidcTokenSqlParameters);
+
+				OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
+
+				List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList(refreshToken);
+				parameters.addAll(refreshTokenSqlParameters);
+				return parameters;
+			} catch (JsonProcessingException e) {
+				throw new IllegalArgumentException(e.getMessage(), e);
+			}
+
+		}
+
+		private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(OAuth2Authorization.Token<T> token) throws JsonProcessingException {
+			List<SqlParameterValue> parameters = new ArrayList<>();
+			byte[] tokenValue = null;
+			Timestamp tokenIssuedAt = null;
+			Timestamp tokenExpiresAt = null;
+			String codeMetadata = null;
+			if (token != null) {
+
+				tokenValue = token.getToken().getTokenValue().getBytes(StandardCharsets.UTF_8);
+				if (token.getToken().getIssuedAt() != null) {
+					tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt());
+				}
+
+				if (token.getToken().getExpiresAt() != null) {
+					tokenExpiresAt = Timestamp.from(token.getToken().getExpiresAt());
+				}
+				codeMetadata = this.objectMapper.writeValueAsString(token.getMetadata());
+			}
+			parameters.add(new SqlParameterValue(Types.BLOB, tokenValue));
+			parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt));
+			parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt));
+			parameters.add(new SqlParameterValue(Types.VARCHAR, codeMetadata));
+			return parameters;
+		}
+	}
+
+	private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
+
+		protected final LobCreator lobCreator;
+
+		private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) {
+			super(args);
+			this.lobCreator = lobCreator;
+		}
+
+		@Override
+		protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
+			if (argValue instanceof SqlParameterValue) {
+				SqlParameterValue paramValue = (SqlParameterValue) argValue;
+				if (paramValue.getSqlType() == Types.BLOB) {
+					if (paramValue.getValue() != null) {
+						Assert.isInstanceOf(byte[].class, paramValue.getValue(),
+								"Value of blob parameter must be byte[]");
+					}
+					byte[] valueBytes = (byte[]) paramValue.getValue();
+					this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
+					return;
+				}
+			}
+			super.doSetValue(ps, parameterPosition, argValue);
+		}
+
+	}
+}

+ 27 - 0
oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql

@@ -0,0 +1,27 @@
+CREATE TABLE oauth2_authorization (
+    id varchar(100) NOT NULL,
+    registered_client_id varchar(100) NOT NULL,
+    principal_name varchar(200) NOT NULL,
+    authorization_grant_type varchar(100) NOT NULL,
+    attributes varchar(1000) DEFAULT NULL,
+    state varchar(1000) DEFAULT NULL,
+    authorization_code_value blob DEFAULT NULL,
+    authorization_code_issued_at timestamp DEFAULT NULL,
+    authorization_code_expires_at timestamp DEFAULT NULL,
+    authorization_code_metadata varchar(1000) DEFAULT NULL,
+    access_token_value blob DEFAULT NULL,
+    access_token_issued_at timestamp DEFAULT NULL,
+    access_token_expires_at timestamp DEFAULT NULL,
+    access_token_metadata varchar(1000) DEFAULT NULL,
+    access_token_type varchar(100) DEFAULT NULL,
+    access_token_scopes varchar(1000) DEFAULT NULL,
+    oidc_id_token_value blob DEFAULT NULL,
+    oidc_id_token_issued_at timestamp DEFAULT NULL,
+    oidc_id_token_expires_at timestamp DEFAULT NULL,
+    oidc_id_token_metadata varchar(1000) DEFAULT NULL,
+    refresh_token_value blob DEFAULT NULL,
+    refresh_token_issued_at timestamp DEFAULT NULL,
+    refresh_token_expires_at timestamp DEFAULT NULL,
+    refresh_token_metadata varchar(1000) DEFAULT NULL,
+    PRIMARY KEY (id)
+);

+ 397 - 0
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/JdbcOAuth2AuthorizationServiceTests.java

@@ -0,0 +1,397 @@
+/*
+ * Copyright 2020-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.server.authorization;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.jdbc.support.lob.DefaultLobHandler;
+import org.springframework.security.oauth2.core.AuthorizationGrantType;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
+import org.springframework.security.oauth2.core.OAuth2TokenType;
+import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
+import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
+import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
+
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link JdbcOAuth2AuthorizationService}.
+ *
+ * @author Ovidiu Popa
+ */
+public class JdbcOAuth2AuthorizationServiceTests {
+	private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql";
+	private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
+	private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
+	private static final String ID = "id";
+	private static final RegisteredClient REGISTERED_CLIENT = TestRegisteredClients.registeredClient().build();
+	private static final String PRINCIPAL_NAME = "principal";
+	private static final AuthorizationGrantType AUTHORIZATION_GRANT_TYPE = AuthorizationGrantType.AUTHORIZATION_CODE;
+	private static final OAuth2AuthorizationCode AUTHORIZATION_CODE = new OAuth2AuthorizationCode(
+			"code", Instant.now().truncatedTo(ChronoUnit.MILLIS), Instant.now().plus(5, ChronoUnit.MINUTES).truncatedTo(ChronoUnit.MILLIS));
+
+	private EmbeddedDatabase db;
+	private JdbcOperations jdbcOperations;
+	private RegisteredClientRepository registeredClientRepository;
+	private JdbcOAuth2AuthorizationService authorizationService;
+
+
+	@Before
+	public void setUp() {
+		this.db = createDb();
+		this.registeredClientRepository = mock(RegisteredClientRepository.class);
+		this.jdbcOperations = new JdbcTemplate(this.db);
+		this.authorizationService = new JdbcOAuth2AuthorizationService(this.jdbcOperations, this.registeredClientRepository);
+	}
+
+	@After
+	public void tearDown() {
+		this.db.shutdown();
+	}
+
+	@Test
+	public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(null, this.registeredClientRepository))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jdbcOperations cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void constructorWhenRegisteredClientRepositoryIsNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(this.jdbcOperations, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("registeredClientRepository cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void constructorWhenLobHandlerIsNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(this.jdbcOperations, this.registeredClientRepository, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("lobHandler cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void constructorWhenObjectMapperIsNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> new JdbcOAuth2AuthorizationService(this.jdbcOperations, this.registeredClientRepository, new DefaultLobHandler(), null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("objectMapper cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void setAuthorizationRowMapperWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authorizationService.setAuthorizationRowMapper(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationRowMapper cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void setAuthorizationParametersMapperWhenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authorizationService.setAuthorizationParametersMapper(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizationParametersMapper cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void saveWhenAuthorizationNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authorizationService.save(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorization cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void saveWhenAuthorizationNewThenSaved() {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.token(AUTHORIZATION_CODE)
+				.build();
+		this.authorizationService.save(expectedAuthorization);
+
+		OAuth2Authorization authorization = this.authorizationService.findById(ID);
+		assertThat(authorization).isEqualTo(expectedAuthorization);
+	}
+
+	@Test
+	public void saveWhenAuthorizationExistsThenUpdated() {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.token(AUTHORIZATION_CODE)
+				.build();
+		this.authorizationService.save(originalAuthorization);
+
+		OAuth2Authorization authorization = this.authorizationService.findById(
+				originalAuthorization.getId());
+		assertThat(authorization).isEqualTo(originalAuthorization);
+
+		OAuth2Authorization updatedAuthorization = OAuth2Authorization.from(authorization)
+				.attribute("custom-name-1", "custom-value-1")
+				.build();
+		this.authorizationService.save(updatedAuthorization);
+
+		authorization = this.authorizationService.findById(
+				updatedAuthorization.getId());
+		assertThat(authorization).isEqualTo(updatedAuthorization);
+		assertThat(authorization).isNotEqualTo(originalAuthorization);
+	}
+
+	@Test
+	public void saveLoadAuthorizationWhenCustomStrategiesSetThenCalled() throws Exception {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.token(AUTHORIZATION_CODE)
+				.build();
+		ObjectMapper objectMapper = new ObjectMapper();
+		JdbcOAuth2AuthorizationService.OAuth2AuthorizationRowMapper authorizationRowMapper = spy(
+				new JdbcOAuth2AuthorizationService.OAuth2AuthorizationRowMapper(
+						this.registeredClientRepository, objectMapper));
+		this.authorizationService.setAuthorizationRowMapper(authorizationRowMapper);
+		JdbcOAuth2AuthorizationService.OAuth2AuthorizationParametersMapper authorizationParametersMapper = spy(
+				new JdbcOAuth2AuthorizationService.OAuth2AuthorizationParametersMapper(objectMapper));
+		this.authorizationService.setAuthorizationParametersMapper(authorizationParametersMapper);
+
+		this.authorizationService.save(originalAuthorization);
+		OAuth2Authorization authorization = this.authorizationService.findById(
+				originalAuthorization.getId());
+		assertThat(authorization).isEqualTo(originalAuthorization);
+		verify(authorizationRowMapper).mapRow(any(), anyInt());
+		verify(authorizationParametersMapper).apply(any());
+	}
+
+	@Test
+	public void removeWhenAuthorizationNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authorizationService.remove(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorization cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	public void removeWhenAuthorizationProvidedThenRemoved() {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2Authorization expectedAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.token(AUTHORIZATION_CODE)
+				.build();
+
+		this.authorizationService.save(expectedAuthorization);
+		OAuth2Authorization authorization = this.authorizationService.findByToken(
+				AUTHORIZATION_CODE.getTokenValue(), AUTHORIZATION_CODE_TOKEN_TYPE);
+		assertThat(authorization).isEqualTo(expectedAuthorization);
+
+		this.authorizationService.remove(expectedAuthorization);
+		authorization = this.authorizationService.findByToken(
+				AUTHORIZATION_CODE.getTokenValue(), AUTHORIZATION_CODE_TOKEN_TYPE);
+		assertThat(authorization).isNull();
+	}
+
+	@Test
+	public void findByIdWhenIdNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authorizationService.findById(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("id cannot be empty");
+		// @formatter:on
+	}
+
+	@Test
+	public void findByIdWhenIdEmptyThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authorizationService.findById(" "))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("id cannot be empty");
+		// @formatter:on
+	}
+
+	@Test
+	public void findByTokenWhenTokenNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatThrownBy(() -> this.authorizationService.findByToken(null, AUTHORIZATION_CODE_TOKEN_TYPE))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("token cannot be empty");
+		// @formatter:on
+	}
+
+	@Test
+	public void findByTokenWhenStateExistsThenFound() {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		String state = "state";
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.attribute(OAuth2ParameterNames.STATE, state)
+				.build();
+		this.authorizationService.save(authorization);
+
+		OAuth2Authorization result = this.authorizationService.findByToken(
+				state, STATE_TOKEN_TYPE);
+		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(state, null);
+		assertThat(authorization).isEqualTo(result);
+	}
+
+	@Test
+	public void findByTokenWhenAuthorizationCodeExistsThenFound() {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.token(AUTHORIZATION_CODE)
+				.build();
+		this.authorizationService.save(authorization);
+
+		OAuth2Authorization result = this.authorizationService.findByToken(
+				AUTHORIZATION_CODE.getTokenValue(), AUTHORIZATION_CODE_TOKEN_TYPE);
+		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(AUTHORIZATION_CODE.getTokenValue(), null);
+		assertThat(authorization).isEqualTo(result);
+	}
+
+	@Test
+	public void findByTokenWhenAccessTokenExistsThenFound() {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER,
+				"access-token", Instant.now().minusSeconds(60).truncatedTo(ChronoUnit.MILLIS), Instant.now().truncatedTo(ChronoUnit.MILLIS));
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.token(AUTHORIZATION_CODE)
+				.accessToken(accessToken)
+				.build();
+		this.authorizationService.save(authorization);
+
+		OAuth2Authorization result = this.authorizationService.findByToken(
+				accessToken.getTokenValue(), OAuth2TokenType.ACCESS_TOKEN);
+		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(accessToken.getTokenValue(), null);
+		assertThat(authorization).isEqualTo(result);
+	}
+
+	@Test
+	public void findByTokenWhenRefreshTokenExistsThenFound() {
+		when(registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
+				.thenReturn(REGISTERED_CLIENT);
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2("refresh-token",
+				Instant.now().truncatedTo(ChronoUnit.MILLIS),
+				Instant.now().plus(5, ChronoUnit.MINUTES).truncatedTo(ChronoUnit.MILLIS));
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.refreshToken(refreshToken)
+				.build();
+		this.authorizationService.save(authorization);
+
+		OAuth2Authorization result = this.authorizationService.findByToken(
+				refreshToken.getTokenValue(), OAuth2TokenType.REFRESH_TOKEN);
+		assertThat(authorization).isEqualTo(result);
+		result = this.authorizationService.findByToken(refreshToken.getTokenValue(), null);
+		assertThat(authorization).isEqualTo(result);
+	}
+
+	@Test
+	public void findByTokenWhenWrongTokenTypeThenNotFound() {
+		OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", Instant.now().truncatedTo(ChronoUnit.MILLIS));
+		OAuth2Authorization authorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
+				.id(ID)
+				.principalName(PRINCIPAL_NAME)
+				.authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
+				.refreshToken(refreshToken)
+				.build();
+		this.authorizationService.save(authorization);
+
+		OAuth2Authorization result = this.authorizationService.findByToken(
+				refreshToken.getTokenValue(), OAuth2TokenType.ACCESS_TOKEN);
+		assertThat(result).isNull();
+	}
+
+	@Test
+	public void findByTokenWhenTokenDoesNotExistThenNull() {
+		OAuth2Authorization result = this.authorizationService.findByToken(
+				"access-token", OAuth2TokenType.ACCESS_TOKEN);
+		assertThat(result).isNull();
+	}
+
+	private static EmbeddedDatabase createDb() {
+		return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
+	}
+
+	private static EmbeddedDatabase createDb(String schema) {
+		// @formatter:off
+		return new EmbeddedDatabaseBuilder()
+				.generateUniqueName(true)
+				.setType(EmbeddedDatabaseType.HSQL)
+				.setScriptEncoding("UTF-8")
+				.addScript(schema)
+				.build();
+		// @formatter:on
+	}
+}