|
@@ -15,32 +15,48 @@
|
|
|
*/
|
|
|
package org.springframework.security.oauth2.server.authorization;
|
|
|
|
|
|
+import java.sql.ResultSet;
|
|
|
+import java.sql.SQLException;
|
|
|
+import java.sql.Timestamp;
|
|
|
+import java.sql.Types;
|
|
|
import java.time.Instant;
|
|
|
import java.time.temporal.ChronoUnit;
|
|
|
+import java.util.ArrayList;
|
|
|
+import java.util.Collections;
|
|
|
import java.util.List;
|
|
|
+import java.util.Map;
|
|
|
+import java.util.Set;
|
|
|
import java.util.function.Function;
|
|
|
|
|
|
+import com.fasterxml.jackson.core.type.TypeReference;
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
import org.junit.After;
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
|
|
|
+import org.springframework.dao.DataRetrievalFailureException;
|
|
|
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
|
|
|
import org.springframework.jdbc.core.JdbcOperations;
|
|
|
import org.springframework.jdbc.core.JdbcTemplate;
|
|
|
+import org.springframework.jdbc.core.PreparedStatementSetter;
|
|
|
import org.springframework.jdbc.core.RowMapper;
|
|
|
import org.springframework.jdbc.core.SqlParameterValue;
|
|
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
|
|
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
|
|
|
import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
|
|
|
+import org.springframework.security.oauth2.core.AbstractOAuth2Token;
|
|
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
|
|
import org.springframework.security.oauth2.core.OAuth2AccessToken;
|
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
|
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken2;
|
|
|
import org.springframework.security.oauth2.core.OAuth2TokenType;
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
|
|
|
+import org.springframework.security.oauth2.core.oidc.OidcIdToken;
|
|
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient;
|
|
|
import org.springframework.security.oauth2.server.authorization.client.RegisteredClientRepository;
|
|
|
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients;
|
|
|
+import org.springframework.util.CollectionUtils;
|
|
|
+import org.springframework.util.StringUtils;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
|
@@ -59,6 +75,7 @@ import static org.mockito.Mockito.when;
|
|
|
*/
|
|
|
public class JdbcOAuth2AuthorizationServiceTests {
|
|
|
private static final String OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/oauth2-authorization-schema.sql";
|
|
|
+ private static final String CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/server/authorization/custom-oauth2-authorization-schema.sql";
|
|
|
private static final OAuth2TokenType AUTHORIZATION_CODE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.CODE);
|
|
|
private static final OAuth2TokenType STATE_TOKEN_TYPE = new OAuth2TokenType(OAuth2ParameterNames.STATE);
|
|
|
private static final String ID = "id";
|
|
@@ -374,6 +391,30 @@ public class JdbcOAuth2AuthorizationServiceTests {
|
|
|
assertThat(result).isNull();
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void tableDefinitionWhenCustomThenAbleToOverride() {
|
|
|
+ when(this.registeredClientRepository.findById(eq(REGISTERED_CLIENT.getId())))
|
|
|
+ .thenReturn(REGISTERED_CLIENT);
|
|
|
+
|
|
|
+ EmbeddedDatabase db = createDb(CUSTOM_OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
|
|
|
+ OAuth2AuthorizationService authorizationService =
|
|
|
+ new CustomJdbcOAuth2AuthorizationService(new JdbcTemplate(db), this.registeredClientRepository);
|
|
|
+ String state = "state";
|
|
|
+ OAuth2Authorization originalAuthorization = OAuth2Authorization.withRegisteredClient(REGISTERED_CLIENT)
|
|
|
+ .id(ID)
|
|
|
+ .principalName(PRINCIPAL_NAME)
|
|
|
+ .authorizationGrantType(AUTHORIZATION_GRANT_TYPE)
|
|
|
+ .attribute(OAuth2ParameterNames.STATE, state)
|
|
|
+ .token(AUTHORIZATION_CODE)
|
|
|
+ .build();
|
|
|
+ authorizationService.save(originalAuthorization);
|
|
|
+ OAuth2Authorization foundAuthorization1 = authorizationService.findById(originalAuthorization.getId());
|
|
|
+ assertThat(foundAuthorization1).isEqualTo(originalAuthorization);
|
|
|
+ OAuth2Authorization foundAuthorization2 = authorizationService.findByToken(state, STATE_TOKEN_TYPE);
|
|
|
+ assertThat(foundAuthorization2).isEqualTo(originalAuthorization);
|
|
|
+ db.shutdown();
|
|
|
+ }
|
|
|
+
|
|
|
private static EmbeddedDatabase createDb() {
|
|
|
return createDb(OAUTH2_AUTHORIZATION_SCHEMA_SQL_RESOURCE);
|
|
|
}
|
|
@@ -388,4 +429,282 @@ public class JdbcOAuth2AuthorizationServiceTests {
|
|
|
.build();
|
|
|
// @formatter:on
|
|
|
}
|
|
|
+
|
|
|
+ private static final class CustomJdbcOAuth2AuthorizationService extends JdbcOAuth2AuthorizationService {
|
|
|
+
|
|
|
+ // @formatter:off
|
|
|
+ private static final String COLUMN_NAMES = "id, "
|
|
|
+ + "registeredClientId, "
|
|
|
+ + "principalName, "
|
|
|
+ + "authorizationGrantType, "
|
|
|
+ + "attributes, "
|
|
|
+ + "state, "
|
|
|
+ + "authorizationCodeValue, "
|
|
|
+ + "authorizationCodeIssuedAt, "
|
|
|
+ + "authorizationCodeExpiresAt,"
|
|
|
+ + "authorizationCodeMetadata,"
|
|
|
+ + "accessTokenValue,"
|
|
|
+ + "accessTokenIssuedAt,"
|
|
|
+ + "accessTokenExpiresAt,"
|
|
|
+ + "accessTokenMetadata,"
|
|
|
+ + "accessTokenType,"
|
|
|
+ + "accessTokenScopes,"
|
|
|
+ + "oidcIdTokenValue,"
|
|
|
+ + "oidcIdTokenIssuedAt,"
|
|
|
+ + "oidcIdTokenExpiresAt,"
|
|
|
+ + "oidcIdTokenMetadata,"
|
|
|
+ + "refreshTokenValue,"
|
|
|
+ + "refreshTokenIssuedAt,"
|
|
|
+ + "refreshTokenExpiresAt,"
|
|
|
+ + "refreshTokenMetadata";
|
|
|
+ // @formatter:on
|
|
|
+
|
|
|
+ private static final String TABLE_NAME = "oauth2Authorization";
|
|
|
+
|
|
|
+ private static final String PK_FILTER = "id = ?";
|
|
|
+ private static final String UNKNOWN_TOKEN_TYPE_FILTER = "state = ? OR authorizationCodeValue = ? OR " +
|
|
|
+ "accessTokenValue = ? OR " +
|
|
|
+ "refreshTokenValue = ?";
|
|
|
+
|
|
|
+ // @formatter:off
|
|
|
+ private static final String LOAD_AUTHORIZATION_SQL = "SELECT " + COLUMN_NAMES
|
|
|
+ + " FROM " + TABLE_NAME
|
|
|
+ + " WHERE ";
|
|
|
+ // @formatter:on
|
|
|
+
|
|
|
+ // @formatter:off
|
|
|
+ private static final String SAVE_AUTHORIZATION_SQL = "INSERT INTO " + TABLE_NAME
|
|
|
+ + " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?,?, ?, ?, ?, ?, ?, ?, ?)";
|
|
|
+ // @formatter:on
|
|
|
+
|
|
|
+ private static final String REMOVE_AUTHORIZATION_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
|
|
|
+
|
|
|
+ CustomJdbcOAuth2AuthorizationService(JdbcOperations jdbcOperations,
|
|
|
+ RegisteredClientRepository registeredClientRepository) {
|
|
|
+ super(jdbcOperations, registeredClientRepository);
|
|
|
+ setAuthorizationRowMapper(new CustomOAuth2AuthorizationRowMapper(registeredClientRepository));
|
|
|
+ setAuthorizationParametersMapper(new CustomOAuth2AuthorizationParametersMapper());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void save(OAuth2Authorization authorization) {
|
|
|
+ List<SqlParameterValue> parameters = getAuthorizationParametersMapper().apply(authorization);
|
|
|
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
|
|
|
+ getJdbcOperations().update(SAVE_AUTHORIZATION_SQL, pss);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void remove(OAuth2Authorization authorization) {
|
|
|
+ SqlParameterValue[] parameters = new SqlParameterValue[] {
|
|
|
+ new SqlParameterValue(Types.VARCHAR, authorization.getId())
|
|
|
+ };
|
|
|
+ PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
|
|
|
+ getJdbcOperations().update(REMOVE_AUTHORIZATION_SQL, pss);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public OAuth2Authorization findById(String id) {
|
|
|
+ return findBy(PK_FILTER, id);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public OAuth2Authorization findByToken(String token, OAuth2TokenType tokenType) {
|
|
|
+ return findBy(UNKNOWN_TOKEN_TYPE_FILTER, token, token, token, token);
|
|
|
+ }
|
|
|
+
|
|
|
+ private OAuth2Authorization findBy(String filter, Object... args) {
|
|
|
+ List<OAuth2Authorization> result = getJdbcOperations()
|
|
|
+ .query(LOAD_AUTHORIZATION_SQL + filter, getAuthorizationRowMapper(), args);
|
|
|
+ return !result.isEmpty() ? result.get(0) : null;
|
|
|
+ }
|
|
|
+
|
|
|
+ private static final class CustomOAuth2AuthorizationRowMapper extends JdbcOAuth2AuthorizationService.OAuth2AuthorizationRowMapper {
|
|
|
+
|
|
|
+ CustomOAuth2AuthorizationRowMapper(RegisteredClientRepository registeredClientRepository) {
|
|
|
+ super(registeredClientRepository);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
+ public OAuth2Authorization mapRow(ResultSet rs, int rowNum) throws SQLException {
|
|
|
+ String registeredClientId = rs.getString("registeredClientId");
|
|
|
+ RegisteredClient registeredClient = getRegisteredClientRepository().findById(registeredClientId);
|
|
|
+ if (registeredClient == null) {
|
|
|
+ throw new DataRetrievalFailureException(
|
|
|
+ "The RegisteredClient with id '" + registeredClientId + "' was not found in the RegisteredClientRepository.");
|
|
|
+ }
|
|
|
+
|
|
|
+ OAuth2Authorization.Builder builder = OAuth2Authorization.withRegisteredClient(registeredClient);
|
|
|
+ String id = rs.getString("id");
|
|
|
+ String principalName = rs.getString("principalName");
|
|
|
+ String authorizationGrantType = rs.getString("authorizationGrantType");
|
|
|
+ Map<String, Object> attributes = parseMap(rs.getString("attributes"));
|
|
|
+
|
|
|
+ builder.id(id)
|
|
|
+ .principalName(principalName)
|
|
|
+ .authorizationGrantType(new AuthorizationGrantType(authorizationGrantType))
|
|
|
+ .attributes((attrs) -> attrs.putAll(attributes));
|
|
|
+
|
|
|
+ String state = rs.getString("state");
|
|
|
+ if (StringUtils.hasText(state)) {
|
|
|
+ builder.attribute(OAuth2ParameterNames.STATE, state);
|
|
|
+ }
|
|
|
+
|
|
|
+ String tokenValue = rs.getString("authorizationCodeValue");
|
|
|
+ Instant tokenIssuedAt;
|
|
|
+ Instant tokenExpiresAt;
|
|
|
+ if (tokenValue != null) {
|
|
|
+ tokenIssuedAt = rs.getTimestamp("authorizationCodeIssuedAt").toInstant();
|
|
|
+ tokenExpiresAt = rs.getTimestamp("authorizationCodeExpiresAt").toInstant();
|
|
|
+ Map<String, Object> authorizationCodeMetadata = parseMap(rs.getString("authorizationCodeMetadata"));
|
|
|
+
|
|
|
+ OAuth2AuthorizationCode authorizationCode = new OAuth2AuthorizationCode(
|
|
|
+ tokenValue, tokenIssuedAt, tokenExpiresAt);
|
|
|
+ builder.token(authorizationCode, (metadata) -> metadata.putAll(authorizationCodeMetadata));
|
|
|
+ }
|
|
|
+
|
|
|
+ tokenValue = rs.getString("accessTokenValue");
|
|
|
+ if (tokenValue != null) {
|
|
|
+ tokenIssuedAt = rs.getTimestamp("accessTokenIssuedAt").toInstant();
|
|
|
+ tokenExpiresAt = rs.getTimestamp("accessTokenExpiresAt").toInstant();
|
|
|
+ Map<String, Object> accessTokenMetadata = parseMap(rs.getString("accessTokenMetadata"));
|
|
|
+ OAuth2AccessToken.TokenType tokenType = null;
|
|
|
+ if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(rs.getString("accessTokenType"))) {
|
|
|
+ tokenType = OAuth2AccessToken.TokenType.BEARER;
|
|
|
+ }
|
|
|
+
|
|
|
+ Set<String> scopes = Collections.emptySet();
|
|
|
+ String accessTokenScopes = rs.getString("accessTokenScopes");
|
|
|
+ if (accessTokenScopes != null) {
|
|
|
+ scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
|
|
|
+ }
|
|
|
+ OAuth2AccessToken accessToken = new OAuth2AccessToken(tokenType, tokenValue, tokenIssuedAt, tokenExpiresAt, scopes);
|
|
|
+ builder.token(accessToken, (metadata) -> metadata.putAll(accessTokenMetadata));
|
|
|
+ }
|
|
|
+
|
|
|
+ tokenValue = rs.getString("oidcIdTokenValue");
|
|
|
+ if (tokenValue != null) {
|
|
|
+ tokenIssuedAt = rs.getTimestamp("oidcIdTokenIssuedAt").toInstant();
|
|
|
+ tokenExpiresAt = rs.getTimestamp("oidcIdTokenExpiresAt").toInstant();
|
|
|
+ Map<String, Object> oidcTokenMetadata = parseMap(rs.getString("oidcIdTokenMetadata"));
|
|
|
+
|
|
|
+ OidcIdToken oidcToken = new OidcIdToken(
|
|
|
+ tokenValue, tokenIssuedAt, tokenExpiresAt, (Map<String, Object>) oidcTokenMetadata.get(OAuth2Authorization.Token.CLAIMS_METADATA_NAME));
|
|
|
+ builder.token(oidcToken, (metadata) -> metadata.putAll(oidcTokenMetadata));
|
|
|
+ }
|
|
|
+
|
|
|
+ tokenValue = rs.getString("refreshTokenValue");
|
|
|
+ if (tokenValue != null) {
|
|
|
+ tokenIssuedAt = rs.getTimestamp("refreshTokenIssuedAt").toInstant();
|
|
|
+ tokenExpiresAt = null;
|
|
|
+ Timestamp refreshTokenExpiresAt = rs.getTimestamp("refreshTokenExpiresAt");
|
|
|
+ if (refreshTokenExpiresAt != null) {
|
|
|
+ tokenExpiresAt = refreshTokenExpiresAt.toInstant();
|
|
|
+ }
|
|
|
+ Map<String, Object> refreshTokenMetadata = parseMap(rs.getString("refreshTokenMetadata"));
|
|
|
+
|
|
|
+ OAuth2RefreshToken refreshToken = new OAuth2RefreshToken2(
|
|
|
+ tokenValue, tokenIssuedAt, tokenExpiresAt);
|
|
|
+ builder.token(refreshToken, (metadata) -> metadata.putAll(refreshTokenMetadata));
|
|
|
+ }
|
|
|
+
|
|
|
+ return builder.build();
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<String, Object> parseMap(String data) {
|
|
|
+ try {
|
|
|
+ return getObjectMapper().readValue(data, new TypeReference<Map<String, Object>>() {});
|
|
|
+ } catch (Exception ex) {
|
|
|
+ throw new IllegalArgumentException(ex.getMessage(), ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ private static final class CustomOAuth2AuthorizationParametersMapper extends JdbcOAuth2AuthorizationService.OAuth2AuthorizationParametersMapper {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public List<SqlParameterValue> apply(OAuth2Authorization authorization) {
|
|
|
+ List<SqlParameterValue> parameters = new ArrayList<>();
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getId()));
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getRegisteredClientId()));
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getPrincipalName()));
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, authorization.getAuthorizationGrantType().getValue()));
|
|
|
+
|
|
|
+ String attributes = writeMap(authorization.getAttributes());
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, attributes));
|
|
|
+
|
|
|
+ String state = null;
|
|
|
+ String authorizationState = authorization.getAttribute(OAuth2ParameterNames.STATE);
|
|
|
+ if (StringUtils.hasText(authorizationState)) {
|
|
|
+ state = authorizationState;
|
|
|
+ }
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, state));
|
|
|
+
|
|
|
+ OAuth2Authorization.Token<OAuth2AuthorizationCode> authorizationCode =
|
|
|
+ authorization.getToken(OAuth2AuthorizationCode.class);
|
|
|
+ List<SqlParameterValue> authorizationCodeSqlParameters = toSqlParameterList(authorizationCode);
|
|
|
+ parameters.addAll(authorizationCodeSqlParameters);
|
|
|
+
|
|
|
+ OAuth2Authorization.Token<OAuth2AccessToken> accessToken =
|
|
|
+ authorization.getToken(OAuth2AccessToken.class);
|
|
|
+ List<SqlParameterValue> accessTokenSqlParameters = toSqlParameterList(accessToken);
|
|
|
+ parameters.addAll(accessTokenSqlParameters);
|
|
|
+ String accessTokenType = null;
|
|
|
+ String accessTokenScopes = null;
|
|
|
+ if (accessToken != null) {
|
|
|
+ accessTokenType = accessToken.getToken().getTokenType().getValue();
|
|
|
+ if (!CollectionUtils.isEmpty(accessToken.getToken().getScopes())) {
|
|
|
+ accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getToken().getScopes(), ",");
|
|
|
+ }
|
|
|
+ }
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenType));
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, accessTokenScopes));
|
|
|
+
|
|
|
+ OAuth2Authorization.Token<OidcIdToken> oidcIdToken = authorization.getToken(OidcIdToken.class);
|
|
|
+ List<SqlParameterValue> oidcIdTokenSqlParameters = toSqlParameterList(oidcIdToken);
|
|
|
+ parameters.addAll(oidcIdTokenSqlParameters);
|
|
|
+
|
|
|
+ OAuth2Authorization.Token<OAuth2RefreshToken> refreshToken = authorization.getRefreshToken();
|
|
|
+ List<SqlParameterValue> refreshTokenSqlParameters = toSqlParameterList(refreshToken);
|
|
|
+ parameters.addAll(refreshTokenSqlParameters);
|
|
|
+ return parameters;
|
|
|
+ }
|
|
|
+
|
|
|
+ private <T extends AbstractOAuth2Token> List<SqlParameterValue> toSqlParameterList(OAuth2Authorization.Token<T> token) {
|
|
|
+ List<SqlParameterValue> parameters = new ArrayList<>();
|
|
|
+ String tokenValue = null;
|
|
|
+ Timestamp tokenIssuedAt = null;
|
|
|
+ Timestamp tokenExpiresAt = null;
|
|
|
+ String metadata = null;
|
|
|
+ if (token != null) {
|
|
|
+ tokenValue = token.getToken().getTokenValue();
|
|
|
+ if (token.getToken().getIssuedAt() != null) {
|
|
|
+ tokenIssuedAt = Timestamp.from(token.getToken().getIssuedAt());
|
|
|
+ }
|
|
|
+
|
|
|
+ if (token.getToken().getExpiresAt() != null) {
|
|
|
+ tokenExpiresAt = Timestamp.from(token.getToken().getExpiresAt());
|
|
|
+ }
|
|
|
+ metadata = writeMap(token.getMetadata());
|
|
|
+ }
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, tokenValue));
|
|
|
+ parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenIssuedAt));
|
|
|
+ parameters.add(new SqlParameterValue(Types.TIMESTAMP, tokenExpiresAt));
|
|
|
+ parameters.add(new SqlParameterValue(Types.VARCHAR, metadata));
|
|
|
+ return parameters;
|
|
|
+ }
|
|
|
+
|
|
|
+ private String writeMap(Map<String, Object> data) {
|
|
|
+ try {
|
|
|
+ return getObjectMapper().writeValueAsString(data);
|
|
|
+ } catch (Exception ex) {
|
|
|
+ throw new IllegalArgumentException(ex.getMessage(), ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
+ }
|
|
|
+
|
|
|
}
|