2
0
Эх сурвалжийг харах

Support update when saving with JdbcOAuth2AuthorizedClientService

Before this commit, JdbcOAuth2AuthorizedClientService threw DuplicateKeyException when re-authorizing or when authorizing the same user from a different client.

This commit makes JdbcOAuth2AuthorizedClientService's saveAuthorizedClient method consistent with that of InMemoryOAuth2AuthorizedClientService.

Fixes gh-8425
Stav Shamir 5 жил өмнө
parent
commit
6f2359ccae

+ 36 - 0
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientService.java

@@ -16,6 +16,7 @@
 package org.springframework.security.oauth2.client;
 
 import org.springframework.dao.DataRetrievalFailureException;
+import org.springframework.dao.DuplicateKeyException;
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.PreparedStatementSetter;
@@ -52,6 +53,7 @@ import java.util.function.Function;
  * and therefore MUST be defined in the database schema.
  *
  * @author Joe Grandja
+ * @author Stav Shamir
  * @since 5.3
  * @see OAuth2AuthorizedClientService
  * @see OAuth2AuthorizedClient
@@ -77,6 +79,11 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient
 			" (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
 	private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME +
 			" WHERE " + PK_FILTER;
+	private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME +
+			" SET access_token_type = ?, access_token_value = ?, access_token_issued_at = ?," +
+			" access_token_expires_at = ?, access_token_scopes = ?," +
+			" refresh_token_value = ?, refresh_token_issued_at = ?" +
+			" WHERE " + PK_FILTER;
 	protected final JdbcOperations jdbcOperations;
 	protected RowMapper<OAuth2AuthorizedClient> authorizedClientRowMapper;
 	protected Function<OAuth2AuthorizedClientHolder, List<SqlParameterValue>> authorizedClientParametersMapper;
@@ -120,6 +127,35 @@ public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClient
 		Assert.notNull(authorizedClient, "authorizedClient cannot be null");
 		Assert.notNull(principal, "principal cannot be null");
 
+		boolean existsAuthorizedClient = null != this.loadAuthorizedClient(
+				authorizedClient.getClientRegistration().getRegistrationId(), principal.getName());
+
+		if (existsAuthorizedClient) {
+			updateAuthorizedClient(authorizedClient, principal);
+		} else {
+			try {
+				insertAuthorizedClient(authorizedClient, principal);
+			} catch (DuplicateKeyException e) {
+				updateAuthorizedClient(authorizedClient, principal);
+			}
+		}
+	}
+
+	private void updateAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
+		List<SqlParameterValue> parameters = this.authorizedClientParametersMapper.apply(
+				new OAuth2AuthorizedClientHolder(authorizedClient, principal));
+
+		SqlParameterValue clientRegistrationIdParameter = parameters.remove(0);
+		SqlParameterValue principalNameParameter = parameters.remove(0);
+		parameters.add(clientRegistrationIdParameter);
+		parameters.add(principalNameParameter);
+
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+
+		this.jdbcOperations.update(UPDATE_AUTHORIZED_CLIENT_SQL, pss);
+	}
+
+	private void insertAuthorizedClient(OAuth2AuthorizedClient authorizedClient, Authentication principal) {
 		List<SqlParameterValue> parameters = this.authorizedClientParametersMapper.apply(
 				new OAuth2AuthorizedClientHolder(authorizedClient, principal));
 		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());

+ 21 - 5
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/JdbcOAuth2AuthorizedClientServiceTests.java

@@ -19,7 +19,6 @@ import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.springframework.dao.DataRetrievalFailureException;
-import org.springframework.dao.DuplicateKeyException;
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcTemplate;
@@ -64,6 +63,7 @@ import static org.mockito.Mockito.when;
  * Tests for {@link JdbcOAuth2AuthorizedClientService}.
  *
  * @author Joe Grandja
+ * @author Stav Shamir
  */
 public class JdbcOAuth2AuthorizedClientServiceTests {
 	private static final String OAUTH2_CLIENT_SCHEMA_SQL_RESOURCE = "org/springframework/security/oauth2/client/oauth2-client-schema.sql";
@@ -236,14 +236,30 @@ public class JdbcOAuth2AuthorizedClientServiceTests {
 	}
 
 	@Test
-	public void saveAuthorizedClientWhenSaveDuplicateThenThrowDuplicateKeyException() {
+	public void saveAuthorizedClientWhenSaveClientWithExistingPrimaryKeyThenUpdate() {
+		// Given a saved authorized client
 		Authentication principal = createPrincipal();
 		OAuth2AuthorizedClient authorizedClient = createAuthorizedClient(principal, this.clientRegistration);
-
 		this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal);
 
-		assertThatThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, principal))
-				.isInstanceOf(DuplicateKeyException.class);
+		// When a client with the same principal and registration id is saved
+		OAuth2AuthorizedClient updatedClient = createAuthorizedClient(principal, this.clientRegistration);
+		this.authorizedClientService.saveAuthorizedClient(updatedClient, principal);
+
+		// Then the saved client is updated
+		OAuth2AuthorizedClient savedClient = this.authorizedClientService.loadAuthorizedClient(
+				this.clientRegistration.getRegistrationId(), principal.getName());
+
+		assertThat(savedClient).isNotNull();
+		assertThat(savedClient.getClientRegistration()).isEqualTo(updatedClient.getClientRegistration());
+		assertThat(savedClient.getPrincipalName()).isEqualTo(updatedClient.getPrincipalName());
+		assertThat(savedClient.getAccessToken().getTokenType()).isEqualTo(updatedClient.getAccessToken().getTokenType());
+		assertThat(savedClient.getAccessToken().getTokenValue()).isEqualTo(updatedClient.getAccessToken().getTokenValue());
+		assertThat(savedClient.getAccessToken().getIssuedAt()).isEqualTo(updatedClient.getAccessToken().getIssuedAt());
+		assertThat(savedClient.getAccessToken().getExpiresAt()).isEqualTo(updatedClient.getAccessToken().getExpiresAt());
+		assertThat(savedClient.getAccessToken().getScopes()).isEqualTo(updatedClient.getAccessToken().getScopes());
+		assertThat(savedClient.getRefreshToken().getTokenValue()).isEqualTo(updatedClient.getRefreshToken().getTokenValue());
+		assertThat(savedClient.getRefreshToken().getIssuedAt()).isEqualTo(updatedClient.getRefreshToken().getIssuedAt());
 	}
 
 	@Test