فهرست منبع

Polish gh-291

Steve Riesenberg 4 سال پیش
والد
کامیت
763ef2224b

+ 97 - 39
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java

@@ -15,9 +15,31 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
+import java.nio.charset.StandardCharsets;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Timestamp;
+import java.sql.Types;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.function.Function;
+
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
-import org.springframework.jdbc.core.*;
+
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
 import org.springframework.jdbc.support.lob.DefaultLobHandler;
 import org.springframework.jdbc.support.lob.LobCreator;
 import org.springframework.jdbc.support.lob.LobHandler;
@@ -26,14 +48,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
 import org.springframework.security.oauth2.server.authorization.config.TokenSettings;
 import org.springframework.util.Assert;
-
-import java.nio.charset.StandardCharsets;
-import java.sql.*;
-import java.time.Duration;
-import java.time.Instant;
-import java.util.*;
-import java.util.function.Function;
-import java.util.stream.Collectors;
+import org.springframework.util.StringUtils;
 
 /**
  * JDBC-backed registered client repository
@@ -72,17 +87,44 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 
 	private final JdbcOperations jdbcOperations;
 
-	private final LobHandler lobHandler = new DefaultLobHandler();
+	private final LobHandler lobHandler;
 
-	private final ObjectMapper objectMapper;
+	/**
+	 * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
+	 *
+	 * @param jdbcOperations the JDBC operations
+	 */
+	public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations) {
+		this(jdbcOperations, new ObjectMapper());
+	}
 
+	/**
+	 * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
+	 *
+	 * @param jdbcOperations the JDBC operations
+	 * @param objectMapper the object mapper
+	 */
 	public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) {
+		this(jdbcOperations, new DefaultLobHandler(), objectMapper);
+	}
+
+	/**
+	 * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
+	 *
+	 * @param jdbcOperations the JDBC operations
+	 * @param lobHandler the handler for large binary fields and large text fields
+	 * @param objectMapper the object mapper
+	 */
+	public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, LobHandler lobHandler, ObjectMapper objectMapper) {
 		Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
+		Assert.notNull(lobHandler, "lobHandler cannot be null");
 		Assert.notNull(objectMapper, "objectMapper cannot be null");
 		this.jdbcOperations = jdbcOperations;
-		this.objectMapper = objectMapper;
-		this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper();
-		this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper();
+		this.lobHandler = lobHandler;
+		DefaultRegisteredClientRowMapper registeredClientRowMapper = new DefaultRegisteredClientRowMapper(objectMapper);
+		registeredClientRowMapper.setLobHandler(lobHandler);
+		this.registeredClientRowMapper = registeredClientRowMapper;
+		this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper(objectMapper);
 	}
 
 	/**
@@ -148,23 +190,27 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 		return !lst.isEmpty() ? lst.get(0) : null;
 	}
 
-	private class DefaultRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
+	public static class DefaultRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
+
+		private final ObjectMapper objectMapper;
 
-		private final LobHandler lobHandler = new DefaultLobHandler();
+		private LobHandler lobHandler = new DefaultLobHandler();
 
-		private Collection<String> parseList(String s) {
-			return s != null ? Arrays.asList(s.split("\\|")) : Collections.emptyList();
+		public DefaultRegisteredClientRowMapper(ObjectMapper objectMapper) {
+			this.objectMapper = objectMapper;
+		}
+
+		private Set<String> parseList(String s) {
+			return s != null ? StringUtils.commaDelimitedListToSet(s) : Collections.emptySet();
 		}
 
 		@Override
 		@SuppressWarnings("unchecked")
 		public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException {
-			Collection<String> scopes = parseList(rs.getString("scopes"));
-			List<AuthorizationGrantType> authGrantTypes = parseList(rs.getString("authorization_grant_types"))
-					.stream().map(AUTHORIZATION_GRANT_TYPE_MAP::get).collect(Collectors.toList());
-			List<ClientAuthenticationMethod> clientAuthMethods = parseList(rs.getString("client_authentication_methods"))
-					.stream().map(CLIENT_AUTHENTICATION_METHOD_MAP::get).collect(Collectors.toList());
-			Collection<String> redirectUris = parseList(rs.getString("redirect_uris"));
+			Set<String> scopes = parseList(rs.getString("scopes"));
+			Set<String> authGrantTypes = parseList(rs.getString("authorization_grant_types"));
+			Set<String> clientAuthMethods = parseList(rs.getString("client_authentication_methods"));
+			Set<String> redirectUris = parseList(rs.getString("redirect_uris"));
 			Timestamp clientIssuedAt = rs.getTimestamp("client_id_issued_at");
 			Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at");
 			byte[] clientSecretBytes = this.lobHandler.getBlobAsBytes(rs, "client_secret");
@@ -176,8 +222,10 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 					.clientSecret(clientSecret)
 					.clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null)
 					.clientName(rs.getString("client_name"))
-					.clientAuthenticationMethods(coll -> coll.addAll(clientAuthMethods))
-					.authorizationGrantTypes(coll -> coll.addAll(authGrantTypes))
+					.authorizationGrantTypes(coll -> authGrantTypes.forEach(authGrantType ->
+							coll.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType))))
+					.clientAuthenticationMethods(coll -> clientAuthMethods.forEach(clientAuthMethod ->
+							coll.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod))))
 					.redirectUris(coll -> coll.addAll(redirectUris))
 					.scopes(coll -> coll.addAll(scopes));
 
@@ -189,8 +237,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 			try {
 				String tokenSettingsJson = rs.getString("token_settings");
 				if (tokenSettingsJson != null) {
-
-					Map<String, Object> m = JdbcRegisteredClientRepository.this.objectMapper.readValue(tokenSettingsJson, Map.class);
+					Map<String, Object> m = this.objectMapper.readValue(tokenSettingsJson, Map.class);
 
 					Number accessTokenTTL = (Number) m.get("access_token_ttl");
 					if (accessTokenTTL != null) {
@@ -210,8 +257,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 
 				String clientSettingsJson = rs.getString("client_settings");
 				if (clientSettingsJson != null) {
-
-					Map<String, Object> m = JdbcRegisteredClientRepository.this.objectMapper.readValue(clientSettingsJson, Map.class);
+					Map<String, Object> m = this.objectMapper.readValue(clientSettingsJson, Map.class);
 
 					Boolean requireProofKey = (Boolean) m.get("require_proof_key");
 					if (requireProofKey != null) {
@@ -223,17 +269,28 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 						cs.requireUserConsent(requireUserConsent);
 					}
 				}
-
-
 			} catch (JsonProcessingException e) {
 				throw new IllegalArgumentException(e.getMessage(), e);
 			}
 
 			return rc;
 		}
+
+		public final void setLobHandler(LobHandler lobHandler) {
+			Assert.notNull(lobHandler, "lobHandler cannot be null");
+			this.lobHandler = lobHandler;
+		}
+
 	}
 
-	private class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
+	public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
+
+		private final ObjectMapper objectMapper;
+
+		private DefaultRegisteredClientParametersMapper(ObjectMapper objectMapper) {
+			this.objectMapper = objectMapper;
+		}
+
 		@Override
 		public List<SqlParameterValue> apply(RegisteredClient registeredClient) {
 			try {
@@ -256,13 +313,13 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 				Map<String, Object> clientSettings = new HashMap<>();
 				clientSettings.put("require_proof_key", registeredClient.getClientSettings().requireProofKey());
 				clientSettings.put("require_user_consent", registeredClient.getClientSettings().requireUserConsent());
-				String clientSettingsJson = JdbcRegisteredClientRepository.this.objectMapper.writeValueAsString(clientSettings);
+				String clientSettingsJson = this.objectMapper.writeValueAsString(clientSettings);
 
 				Map<String, Object> tokenSettings = new HashMap<>();
 				tokenSettings.put("access_token_ttl", registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis());
 				tokenSettings.put("reuse_refresh_tokens", registeredClient.getTokenSettings().reuseRefreshTokens());
 				tokenSettings.put("refresh_token_ttl", registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis());
-				String tokenSettingsJson = JdbcRegisteredClientRepository.this.objectMapper.writeValueAsString(tokenSettings);
+				String tokenSettingsJson = this.objectMapper.writeValueAsString(tokenSettings);
 
 				return Arrays.asList(
 						new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
@@ -271,16 +328,17 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 						new SqlParameterValue(Types.BLOB, registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8)),
 						new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
 						new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
-						new SqlParameterValue(Types.VARCHAR, String.join("|", clientAuthenticationMethodNames)),
-						new SqlParameterValue(Types.VARCHAR, String.join("|", authorizationGrantTypeNames)),
-						new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getRedirectUris())),
-						new SqlParameterValue(Types.VARCHAR, String.join("|", registeredClient.getScopes())),
+						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)),
+						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(authorizationGrantTypeNames)),
+						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())),
+						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())),
 						new SqlParameterValue(Types.VARCHAR, clientSettingsJson),
 						new SqlParameterValue(Types.VARCHAR, tokenSettingsJson));
 			} catch (JsonProcessingException e) {
 				throw new IllegalArgumentException(e.getMessage(), e);
 			}
 		}
+
 	}
 
 	private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {

+ 0 - 0
oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2_registered_client.sql → oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql


+ 19 - 8
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@@ -15,22 +15,24 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
+import java.io.InputStream;
+import java.nio.charset.Charset;
+import java.time.Duration;
+import java.time.Instant;
+
 import com.fasterxml.jackson.databind.ObjectMapper;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
+
 import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.datasource.DriverManagerDataSource;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.util.StreamUtils;
 
-import java.io.InputStream;
-import java.nio.charset.Charset;
-import java.time.Duration;
-import java.time.Instant;
-
-import static org.assertj.core.api.Assertions.*;
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 
 /**
  * JDBC-backed registered client repository tests
@@ -40,7 +42,7 @@ import static org.assertj.core.api.Assertions.*;
  */
 public class JdbcRegisteredClientRepositoryTests {
 
-	private final String SCRIPT = "/org/springframework/security/oauth2/server/authorization/client/oauth2_registered_client.sql";
+	private final String SCRIPT = "/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql";
 
 	private DriverManagerDataSource dataSource;
 
@@ -71,7 +73,7 @@ public class JdbcRegisteredClientRepositoryTests {
 			}
 		}
 
-		this.clients = new JdbcRegisteredClientRepository(this.jdbc, new ObjectMapper());
+		this.clients = new JdbcRegisteredClientRepository(this.jdbc);
 		this.registration = TestRegisteredClients.registeredClient().build();
 
 		this.clients.save(this.registration);
@@ -101,6 +103,15 @@ public class JdbcRegisteredClientRepositoryTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void whenLobHandlerNullThenThrow() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new JdbcRegisteredClientRepository(this.jdbc, null, new ObjectMapper()))
+				.withMessage("lobHandler cannot be null");
+		// @formatter:on
+	}
+
 	@Test
 	public void whenSetNullRegisteredClientRowMapperThenThrow() {
 		// @formatter:off