Browse Source

Add JDBC implementation of OAuth2AuthorizedClientService

Fixes gh-7655
Joe Grandja 5 years ago
parent
commit
de8b558561

+ 3 - 0
oauth2/oauth2-client/spring-security-oauth2-client.gradle

@@ -12,6 +12,7 @@ dependencies {
 	optional 'org.springframework:spring-webflux'
 	optional 'com.fasterxml.jackson.core:jackson-databind'
 	optional 'com.fasterxml.jackson.datatype:jackson-datatype-jsr310'
+	optional 'org.springframework:spring-jdbc'
 
 	testCompile project(path: ':spring-security-oauth2-core', configuration: 'tests')
 	testCompile project(path: ':spring-security-oauth2-jose', configuration: 'tests')
@@ -22,5 +23,7 @@ dependencies {
 	testCompile 'io.projectreactor.tools:blockhound'
 	testCompile 'org.skyscreamer:jsonassert'
 
+	testRuntime 'org.hsqldb:hsqldb'
+
 	provided 'javax.servlet:javax.servlet-api'
 }

+ 312 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java

@@ -0,0 +1,312 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.util.Assert;
+import org.springframework.util.CollectionUtils;
+import org.springframework.util.StringUtils;
+
+import java.nio.charset.StandardCharsets;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+import java.util.function.Function;
+
+/**
+ * A JDBC implementation of an {@link OAuth2AuthorizedClientService}
+ * that uses a {@link JdbcOperations} for {@link OAuth2AuthorizedClient} persistence.
+ *
+ * <p>
+ * <b>NOTE:</b> This {@code OAuth2AuthorizedClientService} depends on the table definition
+ * described in "classpath:org/springframework/security/oauth2/client/oauth2-client-schema.sql"
+ * and therefore MUST be defined in the database schema.
+ *
+ * @author Joe Grandja
+ * @since 5.3
+ * @see OAuth2AuthorizedClientService
+ * @see OAuth2AuthorizedClient
+ * @see JdbcOperations
+ * @see RowMapper
+ */
+public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {
+	private static final String COLUMN_NAMES =
+			"client_registration_id, " +
+			"principal_name, " +
+			"access_token_type, " +
+			"access_token_value, " +
+			"access_token_issued_at, " +
+			"access_token_expires_at, " +
+			"access_token_scopes, " +
+			"refresh_token_value, " +
+			"refresh_token_issued_at";
+	private static final String TABLE_NAME = "oauth2_authorized_client";
+	private static final String PK_FILTER = "client_registration_id = ? AND principal_name = ?";
+	private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES +
+			" FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+	private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME +
+			" (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
+	private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME +
+			" WHERE " + PK_FILTER;
+	protected final JdbcOperations jdbcOperations;
+	protected RowMapper<OAuth2AuthorizedClient> authorizedClientRowMapper;
+	protected Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> authorizedClientParametersMapper;
+
+	/**
+	 * Constructs a {@code JdbcOAuth2AuthorizedClientService} using the provided parameters.
+	 *
+	 * @param jdbcOperations the JDBC operations
+	 * @param clientRegistrationRepository the repository of client registrations
+	 */
+	public JdbcOAuth2AuthorizedClientService(
+			JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) {
+
+		Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
+		Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+		this.jdbcOperations = jdbcOperations;
+		this.authorizedClientRowMapper = new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository);
+		this.authorizedClientParametersMapper = new OAuth2AuthorizedClientParametersMapper();
+	}
+
+	@Override
+	@SuppressWarnings("unchecked")
+	public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) {
+		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+		Assert.hasText(principalName, "principalName cannot be empty");
+
+		SqlParameterValue[] parameters = new SqlParameterValue[] {
+				new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+				new SqlParameterValue(Types.VARCHAR, principalName)
+		};
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+
+		List<OAuth2AuthorizedClient> result = this.jdbcOperations.query(
+				LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper);
+
+		return !result.isEmpty() ? (T) result.get(0) : null;
+	}
+
+	@Override
+	public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+		Assert.notNull(authorizedClient, "authorizedClient cannot be null");
+		Assert.notNull(principal, "principal cannot be null");
+
+		List<SqlParameterValue> parameters = this.authorizedClientParametersMapper.apply(
+				new OAuth2AuthorizedClientHolder(authorizedClient, principal));
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+
+		this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss);
+	}
+
+	@Override
+	public void removeAuthorizedClient(String clientRegistrationId, String principalName) {
+		Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
+		Assert.hasText(principalName, "principalName cannot be empty");
+
+		SqlParameterValue[] parameters = new SqlParameterValue[] {
+				new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+				new SqlParameterValue(Types.VARCHAR, principalName)
+		};
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+
+		this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss);
+	}
+
+	/**
+	 * Sets the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}.
+	 * The default is {@link OAuth2AuthorizedClientRowMapper}.
+	 *
+	 * @param authorizedClientRowMapper the {@link RowMapper} used for mapping the current row in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}
+	 */
+	public final void setAuthorizedClientRowMapper(RowMapper<OAuth2AuthorizedClient> authorizedClientRowMapper) {
+		Assert.notNull(authorizedClientRowMapper, "authorizedClientRowMapper cannot be null");
+		this.authorizedClientRowMapper = authorizedClientRowMapper;
+	}
+
+	/**
+	 * Sets the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue}.
+	 * The default is {@link OAuth2AuthorizedClientParametersMapper}.
+	 *
+	 * @param authorizedClientParametersMapper the {@code Function} used for mapping {@link OAuth2AuthorizedClientHolder} to a {@code List} of {@link SqlParameterValue}
+	 */
+	public final void setAuthorizedClientParametersMapper(Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> authorizedClientParametersMapper) {
+		Assert.notNull(authorizedClientParametersMapper, "authorizedClientParametersMapper cannot be null");
+		this.authorizedClientParametersMapper = authorizedClientParametersMapper;
+	}
+
+	/**
+	 * The default {@link RowMapper} that maps the current row
+	 * in {@code java.sql.ResultSet} to {@link OAuth2AuthorizedClient}.
+	 */
+	public static class OAuth2AuthorizedClientRowMapper implements RowMapper<OAuth2AuthorizedClient> {
+		protected final ClientRegistrationRepository clientRegistrationRepository;
+
+		public OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) {
+			Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+			this.clientRegistrationRepository = clientRegistrationRepository;
+		}
+
+		@Override
+		public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException {
+			String clientRegistrationId = rs.getString("client_registration_id");
+			ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(
+					clientRegistrationId);
+			if (clientRegistration == null) {
+				throw new DataRetrievalFailureException("The ClientRegistration with id '" +
+						clientRegistrationId + "' exists in the data source, " +
+						"however, it was not found in the ClientRegistrationRepository.");
+			}
+
+			OAuth2AccessToken.TokenType tokenType = null;
+			if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
+					rs.getString("access_token_type"))) {
+				tokenType = OAuth2AccessToken.TokenType.BEARER;
+			}
+			String tokenValue = new String(rs.getBytes("access_token_value"), StandardCharsets.UTF_8);
+			Instant issuedAt = rs.getTimestamp("access_token_issued_at").toInstant();
+			Instant expiresAt = rs.getTimestamp("access_token_expires_at").toInstant();
+			Set<String> scopes = Collections.emptySet();
+			String accessTokenScopes = rs.getString("access_token_scopes");
+			if (accessTokenScopes != null) {
+				scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
+			}
+			OAuth2AccessToken accessToken = new OAuth2AccessToken(
+					tokenType, tokenValue, issuedAt, expiresAt, scopes);
+
+			OAuth2RefreshToken refreshToken = null;
+			byte[] refreshTokenValue = rs.getBytes("refresh_token_value");
+			if (refreshTokenValue != null) {
+				tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
+				issuedAt = null;
+				Timestamp refreshTokenIssuedAt = rs.getTimestamp("refresh_token_issued_at");
+				if (refreshTokenIssuedAt != null) {
+					issuedAt = refreshTokenIssuedAt.toInstant();
+				}
+				refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt);
+			}
+
+			String principalName = rs.getString("principal_name");
+
+			return new OAuth2AuthorizedClient(
+					clientRegistration, principalName, accessToken, refreshToken);
+		}
+	}
+
+	/**
+	 * The default {@code Function} that maps {@link OAuth2AuthorizedClientHolder}
+	 * to a {@code List} of {@link SqlParameterValue}.
+	 */
+	public static class OAuth2AuthorizedClientParametersMapper implements Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> {
+
+		@Override
+		public List<SqlParameterValue> apply(OAuth2AuthorizedClientHolder authorizedClientHolder) {
+			OAuth2AuthorizedClient authorizedClient = authorizedClientHolder.getAuthorizedClient();
+			Authentication principal = authorizedClientHolder.getPrincipal();
+			ClientRegistration clientRegistration = authorizedClient.getClientRegistration();
+			OAuth2AccessToken accessToken = authorizedClient.getAccessToken();
+			OAuth2RefreshToken refreshToken = authorizedClient.getRefreshToken();
+
+			List<SqlParameterValue> parameters = new ArrayList<>();
+			parameters.add(new SqlParameterValue(
+					Types.VARCHAR, clientRegistration.getRegistrationId()));
+			parameters.add(new SqlParameterValue(
+					Types.VARCHAR, principal.getName()));
+			parameters.add(new SqlParameterValue(
+					Types.VARCHAR, accessToken.getTokenType().getValue()));
+			parameters.add(new SqlParameterValue(
+					Types.BLOB, accessToken.getTokenValue().getBytes(StandardCharsets.UTF_8)));
+			parameters.add(new SqlParameterValue(
+					Types.TIMESTAMP, Timestamp.from(accessToken.getIssuedAt())));
+			parameters.add(new SqlParameterValue(
+					Types.TIMESTAMP, Timestamp.from(accessToken.getExpiresAt())));
+			String accessTokenScopes = null;
+			if (!CollectionUtils.isEmpty(accessToken.getScopes())) {
+				accessTokenScopes = StringUtils.collectionToDelimitedString(accessToken.getScopes(), ",");
+			}
+			parameters.add(new SqlParameterValue(
+					Types.VARCHAR, accessTokenScopes));
+			byte[] refreshTokenValue = null;
+			Timestamp refreshTokenIssuedAt = null;
+			if (refreshToken != null) {
+				refreshTokenValue = refreshToken.getTokenValue().getBytes(StandardCharsets.UTF_8);
+				if (refreshToken.getIssuedAt() != null) {
+					refreshTokenIssuedAt = Timestamp.from(refreshToken.getIssuedAt());
+				}
+			}
+			parameters.add(new SqlParameterValue(
+					Types.BLOB, refreshTokenValue));
+			parameters.add(new SqlParameterValue(
+					Types.TIMESTAMP, refreshTokenIssuedAt));
+
+			return parameters;
+		}
+	}
+
+	/**
+	 * A holder for an {@link OAuth2AuthorizedClient} and End-User {@link Authentication} (Resource Owner).
+	 */
+	public static final class OAuth2AuthorizedClientHolder {
+		private final OAuth2AuthorizedClient authorizedClient;
+		private final Authentication principal;
+
+		/**
+		 * Constructs an {@code OAuth2AuthorizedClientHolder} using the provided parameters.
+		 *
+		 * @param authorizedClient the authorized client
+		 * @param principal the End-User {@link Authentication} (Resource Owner)
+		 */
+		public OAuth2AuthorizedClientHolder(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+			Assert.notNull(authorizedClient, "authorizedClient cannot be null");
+			Assert.notNull(principal, "principal cannot be null");
+			this.authorizedClient = authorizedClient;
+			this.principal = principal;
+		}
+
+		/**
+		 * Returns the {@link OAuth2AuthorizedClient}.
+		 *
+		 * @return the {@link OAuth2AuthorizedClient}
+		 */
+		public OAuth2AuthorizedClient getAuthorizedClient() {
+			return this.authorizedClient;
+		}
+
+		/**
+		 * Returns the End-User {@link Authentication} (Resource Owner).
+		 *
+		 * @return the End-User {@link Authentication} (Resource Owner)
+		 */
+		public Authentication getPrincipal() {
+			return this.principal;
+		}
+	}
+}

+ 13 - 0
oauth2/oauth2-client/src/main/resources/org/springframework/security/oauth2/client/oauth2-client-schema.sql

@@ -0,0 +1,13 @@
+CREATE TABLE oauth2_authorized_client (
+  client_registration_id varchar(100) NOT NULL,
+  principal_name varchar(200) NOT NULL,
+  access_token_type varchar(100) NOT NULL,
+  access_token_value blob NOT NULL,
+  access_token_issued_at timestamp NOT NULL,
+  access_token_expires_at timestamp NOT NULL,
+  access_token_scopes varchar(1000) DEFAULT NULL,
+  refresh_token_value blob DEFAULT NULL,
+  refresh_token_issued_at timestamp DEFAULT NULL,
+  created_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
+  PRIMARY KEY (client_registration_id, principal_name)
+);

+ 474 - 0
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java

@@ -0,0 +1,474 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.oauth2.client;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.dao.DuplicateKeyException;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.oauth2.client.registration.ClientRegistration;
+import org.springframework.security.oauth2.client.registration.ClientRegistrationRepository;
+import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
+import org.springframework.security.oauth2.core.OAuth2AccessToken;
+import org.springframework.security.oauth2.core.OAuth2RefreshToken;
+import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
+import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+
+import java.nio.charset.StandardCharsets;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
+import java.time.Instant;
+import java.util.Collections;
+import java.util.List;
+import java.util.Set;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+/**
+ * Tests for {@link JdbcOAuth2AuthorizedClientService}.
+ *
+ * @author Joe Grandja
+ */
+public class JdbcOAuth2AuthorizedClientServiceTests {
+	private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql";
+	private static int principalId = 1000;
+	private ClientRegistration clientRegistration;
+	private ClientRegistrationRepository clientRegistrationRepository;
+	private EmbeddedDatabase db;
+	private JdbcOperations jdbcOperations;
+	private JdbcOAuth2AuthorizedClientService authorizedClientService;
+
+	@Before
+	public void setUp() {
+		this.clientRegistration = TestClientRegistrations.clientRegistration().build();
+		this.clientRegistrationRepository = mock(ClientRegistrationRepository.class);
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(this.clientRegistration);
+		this.db = createDb();
+		this.jdbcOperations = new JdbcTemplate(this.db);
+		this.authorizedClientService = new JdbcOAuth2AuthorizedClientService(
+				this.jdbcOperations, this.clientRegistrationRepository);
+	}
+
+	@After
+	public void tearDown() {
+		this.db.shutdown();
+	}
+
+	@Test
+	public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("jdbcOperations cannot be null");
+	}
+
+	@Test
+	public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> new JdbcOAuth2AuthorizedClientService(this.jdbcOperations, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientRegistrationRepository cannot be null");
+	}
+
+	@Test
+	public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizedClientRowMapper cannot be null");
+	}
+
+	@Test
+	public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizedClientParametersMapper cannot be null");
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName"))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientRegistrationId cannot be empty");
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("principalName cannot be empty");
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenDoesNotExistThenReturnNull() {
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+				"registration-not-found", "principalName");
+		assertThat(authorizedClient).isNull();
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenExistsThenReturnAuthorizedClient() {
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+		this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+
+		assertThat(authorizedClient).isNotNull();
+		assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+		assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+		assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+		assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+		assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+		assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
+		assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue());
+		assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+	}
+
+	@Test
+	public void loadAuthorizedClientWhenExistsButNotFoundInClientRegistrationRepositoryThenThrowDataRetrievalFailureException() {
+		when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(null);
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+		this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+		assertThatThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), principal.getName()))
+				.isInstanceOf(DataRetrievalFailureException.class)
+				.hasMessage("The ClientRegistration with id '" + this.clientRegistration.getRegistrationId() +
+						"' exists in the data source, however, it was not found in the ClientRegistrationRepository.");
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenAuthorizedClientIsNullThenThrowIllegalArgumentException() {
+		Authentication principal = createPrincipal();
+
+		assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(null, principal))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("authorizedClient cannot be null");
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenPrincipalIsNullThenThrowIllegalArgumentException() {
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+		assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("principal cannot be null");
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenSaveThenLoadReturnsSaved() {
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient expected = createAuthorizedClient(principal, this.clientRegistration);
+
+		this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+		OAuth2AuthorizedClient authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+
+		assertThat(authorizedClient).isNotNull();
+		assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+		assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+		assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+		assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+		assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+		assertThat(authorizedClient.getAccessToken().getScopes()).isEqualTo(expected.getAccessToken().getScopes());
+		assertThat(authorizedClient.getRefreshToken().getTokenValue()).isEqualTo(expected.getRefreshToken().getTokenValue());
+		assertThat(authorizedClient.getRefreshToken().getIssuedAt()).isEqualTo(expected.getRefreshToken().getIssuedAt());
+
+		// Test save/load of NOT NULL attributes only
+		principal = createPrincipal();
+		expected = createAuthorizedClient(principal, this.clientRegistration, true);
+
+		this.authorizedClientService.saveAuthorizedClient(expected, principal);
+
+		authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+
+		assertThat(authorizedClient).isNotNull();
+		assertThat(authorizedClient.getClientRegistration()).isEqualTo(expected.getClientRegistration());
+		assertThat(authorizedClient.getPrincipalName()).isEqualTo(expected.getPrincipalName());
+		assertThat(authorizedClient.getAccessToken().getTokenType()).isEqualTo(expected.getAccessToken().getTokenType());
+		assertThat(authorizedClient.getAccessToken().getTokenValue()).isEqualTo(expected.getAccessToken().getTokenValue());
+		assertThat(authorizedClient.getAccessToken().getIssuedAt()).isEqualTo(expected.getAccessToken().getIssuedAt());
+		assertThat(authorizedClient.getAccessToken().getExpiresAt()).isEqualTo(expected.getAccessToken().getExpiresAt());
+		assertThat(authorizedClient.getAccessToken().getScopes()).isEmpty();
+		assertThat(authorizedClient.getRefreshToken()).isNull();
+	}
+
+	@Test
+	public void saveAuthorizedClientWhenSaveDuplicateThenThrowDuplicateKeyException() {
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+		this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+
+		assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal))
+				.isInstanceOf(DuplicateKeyException.class);
+	}
+
+	@Test
+	public void saveLoadAuthorizedClientWhenCustomStrategiesSetThenCalled() throws Exception {
+		JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper authorizedClientRowMapper =
+				spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientRowMapper(this.clientRegistrationRepository));
+		this.authorizedClientService.setAuthorizedClientRowMapper(authorizedClientRowMapper);
+		JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper authorizedClientParametersMapper =
+				spy(new JdbcOAuth2AuthorizedClientService.OAuth2AuthorizedClientParametersMapper());
+		this.authorizedClientService.setAuthorizedClientParametersMapper(authorizedClientParametersMapper);
+
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+		this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+		this.authorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+
+		verify(authorizedClientRowMapper).mapRow(any(), anyInt());
+		verify(authorizedClientParametersMapper).apply(any());
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(null, "principalName"))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("clientRegistrationId cannot be empty");
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
+		assertThatThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistration.getRegistrationId(), null))
+				.isInstanceOf(IllegalArgumentException.class)
+				.hasMessage("principalName cannot be empty");
+	}
+
+	@Test
+	public void removeAuthorizedClientWhenExistsThenRemoved() {
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+		this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+
+		authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+		assertThat(authorizedClient).isNotNull();
+
+		this.authorizedClientService.removeAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+
+		authorizedClient = this.authorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+		assertThat(authorizedClient).isNull();
+	}
+
+	@Test
+	public void tableDefinitionWhenCustomThenAbleToOverride() {
+		CustomTableDefinitionJdbcOAuth2AuthorizedClientService customAuthorizedClientService =
+				new CustomTableDefinitionJdbcOAuth2AuthorizedClientService(
+						new JdbcTemplate(createDb("custom-oauth2-client-schema.sql")),
+						this.clientRegistrationRepository);
+
+		Authentication principal = createPrincipal();
+		OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
+
+		customAuthorizedClientService.saveAuthorizedClient(authorizedClient, principal);
+
+		authorizedClient = customAuthorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+		assertThat(authorizedClient).isNotNull();
+
+		customAuthorizedClientService.removeAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+
+		authorizedClient = customAuthorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+		assertThat(authorizedClient).isNull();
+	}
+
+	private static EmbeddedDatabase createDb() {
+		return createDb(OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE);
+	}
+
+	private static EmbeddedDatabase createDb(String schema) {
+		return new EmbeddedDatabaseBuilder()
+				.generateUniqueName(true)
+				.setType(EmbeddedDatabaseType.HSQL)
+				.setScriptEncoding("UTF-8")
+				.addScript(schema)
+				.build();
+	}
+
+	private static Authentication createPrincipal() {
+		return new TestingAuthenticationToken("principal-" + principalId++, "password");
+	}
+
+	private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal, ClientRegistration clientRegistration) {
+		return createAuthorizedClient(principal, clientRegistration, false);
+	}
+
+	private static OAuth2AuthorizedClient createAuthorizedClient(Authentication principal,
+			ClientRegistration clientRegistration, boolean requiredAttributesOnly) {
+		OAuth2AccessToken accessToken;
+		if (!requiredAttributesOnly) {
+			accessToken = TestOAuth2AccessTokens.scopes("read", "write");
+		} else {
+			accessToken = TestOAuth2AccessTokens.noScopes();
+		}
+		OAuth2RefreshToken refreshToken = null;
+		if (!requiredAttributesOnly) {
+			refreshToken = TestOAuth2RefreshTokens.refreshToken();
+		}
+		return new OAuth2AuthorizedClient(
+				clientRegistration, principal.getName(), accessToken, refreshToken);
+	}
+
+	private static class CustomTableDefinitionJdbcOAuth2AuthorizedClientService extends JdbcOAuth2AuthorizedClientService {
+		private static final String COLUMN_NAMES =
+				"clientRegistrationId, " +
+				"principalName, " +
+				"accessTokenType, " +
+				"accessTokenValue, " +
+				"accessTokenIssuedAt, " +
+				"accessTokenExpiresAt, " +
+				"accessTokenScopes, " +
+				"refreshTokenValue, " +
+				"refreshTokenIssuedAt";
+		private static final String TABLE_NAME = "oauth2AuthorizedClient";
+		private static final String PK_FILTER = "clientRegistrationId = ? AND principalName = ?";
+		private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES +
+				" FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
+		private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME +
+				" (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
+		private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME +
+				" WHERE " + PK_FILTER;
+
+		private CustomTableDefinitionJdbcOAuth2AuthorizedClientService(
+				JdbcOperations jdbcOperations, ClientRegistrationRepository clientRegistrationRepository) {
+			super(jdbcOperations, clientRegistrationRepository);
+			setAuthorizedClientRowMapper(new OAuth2AuthorizedClientRowMapper(clientRegistrationRepository));
+		}
+
+		@Override
+		@SuppressWarnings("unchecked")
+		public <T extends OAuth2AuthorizedClient> T loadAuthorizedClient(String clientRegistrationId, String principalName) {
+			SqlParameterValue[] parameters = new SqlParameterValue[] {
+					new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+					new SqlParameterValue(Types.VARCHAR, principalName)
+			};
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+			List<OAuth2AuthorizedClient> result = this.jdbcOperations.query(
+					LOAD_AUTHORIZED_CLIENT_SQL, pss, this.authorizedClientRowMapper);
+			return !result.isEmpty() ? (T) result.get(0) : null;
+		}
+
+		@Override
+		public void saveAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+			List<SqlParameterValue> parameters = this.authorizedClientParametersMapper.apply(
+					new OAuth2AuthorizedClientHolder(authorizedClient, principal));
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+			this.jdbcOperations.update(SAVE_AUTHORIZED_CLIENT_SQL, pss);
+		}
+
+		@Override
+		public void removeAuthorizedClient(String clientRegistrationId, String principalName) {
+			SqlParameterValue[] parameters = new SqlParameterValue[] {
+					new SqlParameterValue(Types.VARCHAR, clientRegistrationId),
+					new SqlParameterValue(Types.VARCHAR, principalName)
+			};
+			PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+			this.jdbcOperations.update(REMOVE_AUTHORIZED_CLIENT_SQL, pss);
+		}
+
+		private static class OAuth2AuthorizedClientRowMapper implements RowMapper<OAuth2AuthorizedClient> {
+			private final ClientRegistrationRepository clientRegistrationRepository;
+
+			private OAuth2AuthorizedClientRowMapper(ClientRegistrationRepository clientRegistrationRepository) {
+				Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
+				this.clientRegistrationRepository = clientRegistrationRepository;
+			}
+
+			@Override
+			public OAuth2AuthorizedClient mapRow(ResultSet rs, int rowNum) throws SQLException {
+				String clientRegistrationId = rs.getString("clientRegistrationId");
+				ClientRegistration clientRegistration = this.clientRegistrationRepository.findByRegistrationId(
+						clientRegistrationId);
+				if (clientRegistration == null) {
+					throw new DataRetrievalFailureException("The ClientRegistration with id '" +
+							clientRegistrationId + "' exists in the data source, " +
+							"however, it was not found in the ClientRegistrationRepository.");
+				}
+
+				OAuth2AccessToken.TokenType tokenType = null;
+				if (OAuth2AccessToken.TokenType.BEARER.getValue().equalsIgnoreCase(
+						rs.getString("accessTokenType"))) {
+					tokenType = OAuth2AccessToken.TokenType.BEARER;
+				}
+				String tokenValue = new String(rs.getBytes("accessTokenValue"), StandardCharsets.UTF_8);
+				Instant issuedAt = rs.getTimestamp("accessTokenIssuedAt").toInstant();
+				Instant expiresAt = rs.getTimestamp("accessTokenExpiresAt").toInstant();
+				Set<String> scopes = Collections.emptySet();
+				String accessTokenScopes = rs.getString("accessTokenScopes");
+				if (accessTokenScopes != null) {
+					scopes = StringUtils.commaDelimitedListToSet(accessTokenScopes);
+				}
+				OAuth2AccessToken accessToken = new OAuth2AccessToken(
+						tokenType, tokenValue, issuedAt, expiresAt, scopes);
+
+				OAuth2RefreshToken refreshToken = null;
+				byte[] refreshTokenValue = rs.getBytes("refreshTokenValue");
+				if (refreshTokenValue != null) {
+					tokenValue = new String(refreshTokenValue, StandardCharsets.UTF_8);
+					issuedAt = null;
+					Timestamp refreshTokenIssuedAt = rs.getTimestamp("refreshTokenIssuedAt");
+					if (refreshTokenIssuedAt != null) {
+						issuedAt = refreshTokenIssuedAt.toInstant();
+					}
+					refreshToken = new OAuth2RefreshToken(tokenValue, issuedAt);
+				}
+
+				String principalName = rs.getString("principalName");
+
+				return new OAuth2AuthorizedClient(
+						clientRegistration, principalName, accessToken, refreshToken);
+			}
+		}
+	}
+}

+ 13 - 0
oauth2/oauth2-client/src/test/resources/custom-oauth2-client-schema.sql

@@ -0,0 +1,13 @@
+CREATE TABLE oauth2AuthorizedClient (
+  clientRegistrationId varchar(100) NOT NULL,
+  principalName varchar(200) NOT NULL,
+  accessTokenType varchar(100) NOT NULL,
+  accessTokenValue blob NOT NULL,
+  accessTokenIssuedAt timestamp NOT NULL,
+  accessTokenExpiresAt timestamp NOT NULL,
+  accessTokenScopes varchar(1000) DEFAULT NULL,
+  refreshTokenValue blob DEFAULT NULL,
+  refreshTokenIssuedAt timestamp DEFAULT NULL,
+  createdAt timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
+  PRIMARY KEY (clientRegistrationId, principalName)
+);