소스 검색

Fix NPE saving public client

Closes gh-326
Steve Riesenberg 4 년 전
부모
커밋
67e62a2f21

+ 8 - 72
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java

@@ -15,8 +15,6 @@
  */
 package org.springframework.security.oauth2.server.authorization.client;
 
-import java.nio.charset.StandardCharsets;
-import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Timestamp;
@@ -40,9 +38,6 @@ import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.PreparedStatementSetter;
 import org.springframework.jdbc.core.RowMapper;
 import org.springframework.jdbc.core.SqlParameterValue;
-import org.springframework.jdbc.support.lob.DefaultLobHandler;
-import org.springframework.jdbc.support.lob.LobCreator;
-import org.springframework.jdbc.support.lob.LobHandler;
 import org.springframework.security.oauth2.core.AuthorizationGrantType;
 import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
 import org.springframework.security.oauth2.server.authorization.config.ClientSettings;
@@ -87,8 +82,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 
 	private final JdbcOperations jdbcOperations;
 
-	private final LobHandler lobHandler;
-
 	/**
 	 * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
 	 *
@@ -105,25 +98,10 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 	 * @param objectMapper the object mapper
 	 */
 	public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) {
-		this(jdbcOperations, new DefaultLobHandler(), objectMapper);
-	}
-
-	/**
-	 * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
-	 *
-	 * @param jdbcOperations the JDBC operations
-	 * @param lobHandler the handler for large binary fields and large text fields
-	 * @param objectMapper the object mapper
-	 */
-	public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, LobHandler lobHandler, ObjectMapper objectMapper) {
 		Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
-		Assert.notNull(lobHandler, "lobHandler cannot be null");
 		Assert.notNull(objectMapper, "objectMapper cannot be null");
 		this.jdbcOperations = jdbcOperations;
-		this.lobHandler = lobHandler;
-		DefaultRegisteredClientRowMapper registeredClientRowMapper = new DefaultRegisteredClientRowMapper(objectMapper);
-		registeredClientRowMapper.setLobHandler(lobHandler);
-		this.registeredClientRowMapper = registeredClientRowMapper;
+		this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper(objectMapper);
 		this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper(objectMapper);
 	}
 
@@ -150,25 +128,19 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 	@Override
 	public void save(RegisteredClient registeredClient) {
 		Assert.notNull(registeredClient, "registeredClient cannot be null");
-		RegisteredClient foundClient = this.findBy("id = ? OR client_id = ? OR client_secret = ?",
-				registeredClient.getId(), registeredClient.getClientId(),
-				registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8));
+		RegisteredClient foundClient = this.findBy("id = ? OR client_id = ?",
+				registeredClient.getId(), registeredClient.getClientId());
 
 		if (null != foundClient) {
 			Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()),
 					"Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
 			Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()),
 					"Registered client must be unique. Found duplicate client identifier: " + registeredClient.getClientId());
-			Assert.isTrue(!foundClient.getClientSecret().equals(registeredClient.getClientSecret()),
-					"Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId());
 		}
 
 		List<SqlParameterValue> parameters = this.registeredClientParametersMapper.apply(registeredClient);
-
-		try (LobCreator lobCreator = this.lobHandler.getLobCreator()) {
-			PreparedStatementSetter pss = new LobCreatorArgumentPreparedStatementSetter(lobCreator, parameters.toArray());
-			jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss);
-		}
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters.toArray());
+		this.jdbcOperations.update(INSERT_REGISTERED_CLIENT_SQL, pss);
 	}
 
 	@Override
@@ -184,7 +156,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 	}
 
 	private RegisteredClient findBy(String condStr, Object...args) {
-		List<RegisteredClient> lst = jdbcOperations.query(
+		List<RegisteredClient> lst = this.jdbcOperations.query(
 				LOAD_REGISTERED_CLIENT_SQL + condStr,
 				registeredClientRowMapper, args);
 		return !lst.isEmpty() ? lst.get(0) : null;
@@ -194,8 +166,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 
 		private final ObjectMapper objectMapper;
 
-		private LobHandler lobHandler = new DefaultLobHandler();
-
 		public DefaultRegisteredClientRowMapper(ObjectMapper objectMapper) {
 			this.objectMapper = objectMapper;
 		}
@@ -213,8 +183,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 			Set<String> redirectUris = parseList(rs.getString("redirect_uris"));
 			Timestamp clientIssuedAt = rs.getTimestamp("client_id_issued_at");
 			Timestamp clientSecretExpiresAt = rs.getTimestamp("client_secret_expires_at");
-			byte[] clientSecretBytes = this.lobHandler.getBlobAsBytes(rs, "client_secret");
-			String clientSecret = clientSecretBytes != null ? new String(clientSecretBytes, StandardCharsets.UTF_8) : null;
+			String clientSecret = rs.getString("client_secret");
 			RegisteredClient.Builder builder = RegisteredClient
 					.withId(rs.getString("id"))
 					.clientId(rs.getString("client_id"))
@@ -276,11 +245,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 			return rc;
 		}
 
-		public final void setLobHandler(LobHandler lobHandler) {
-			Assert.notNull(lobHandler, "lobHandler cannot be null");
-			this.lobHandler = lobHandler;
-		}
-
 	}
 
 	public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
@@ -325,7 +289,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 						new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
 						new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
 						new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
-						new SqlParameterValue(Types.BLOB, registeredClient.getClientSecret().getBytes(StandardCharsets.UTF_8)),
+						new SqlParameterValue(Types.VARCHAR, registeredClient.getClientSecret()),
 						new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
 						new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
 						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)),
@@ -341,34 +305,6 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 
 	}
 
-	private static final class LobCreatorArgumentPreparedStatementSetter extends ArgumentPreparedStatementSetter {
-
-		protected final LobCreator lobCreator;
-
-		private LobCreatorArgumentPreparedStatementSetter(LobCreator lobCreator, Object[] args) {
-			super(args);
-			this.lobCreator = lobCreator;
-		}
-
-		@Override
-		protected void doSetValue(PreparedStatement ps, int parameterPosition, Object argValue) throws SQLException {
-			if (argValue instanceof SqlParameterValue) {
-				SqlParameterValue paramValue = (SqlParameterValue) argValue;
-				if (paramValue.getSqlType() == Types.BLOB) {
-					if (paramValue.getValue() != null) {
-						Assert.isInstanceOf(byte[].class, paramValue.getValue(),
-								"Value of blob parameter must be byte[]");
-					}
-					byte[] valueBytes = (byte[]) paramValue.getValue();
-					this.lobCreator.setBlobAsBytes(ps, parameterPosition, valueBytes);
-					return;
-				}
-			}
-			super.doSetValue(ps, parameterPosition, argValue);
-		}
-
-	}
-
 	static {
 		Map<String, AuthorizationGrantType> am = new HashMap<>();
 		for (AuthorizationGrantType a : Arrays.asList(

+ 1 - 1
oauth2-authorization-server/src/main/resources/org/springframework/security/oauth2/server/authorization/client/oauth2-registered-client.sql

@@ -2,7 +2,7 @@ CREATE TABLE oauth2_registered_client (
     id varchar(100) NOT NULL,
     client_id varchar(100) NOT NULL,
     client_id_issued_at timestamp DEFAULT CURRENT_TIMESTAMP NOT NULL,
-    client_secret blob NOT NULL,
+    client_secret varchar(200) DEFAULT NULL,
     client_secret_expires_at timestamp DEFAULT NULL,
     client_name varchar(200),
     client_authentication_methods varchar(1000) NOT NULL,

+ 24 - 13
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@@ -103,15 +103,6 @@ public class JdbcRegisteredClientRepositoryTests {
 		// @formatter:on
 	}
 
-	@Test
-	public void whenLobHandlerNullThenThrow() {
-		// @formatter:off
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> new JdbcRegisteredClientRepository(this.jdbc, null, new ObjectMapper()))
-				.withMessage("lobHandler cannot be null");
-		// @formatter:on
-	}
-
 	@Test
 	public void whenSetNullRegisteredClientRowMapperThenThrow() {
 		// @formatter:off
@@ -198,12 +189,12 @@ public class JdbcRegisteredClientRepositoryTests {
 	}
 
 	@Test
-	public void saveWhenExistingClientSecretThenThrowIllegalArgumentException() {
+	public void saveWhenExistingClientSecretThenSuccess() {
 		RegisteredClient registeredClient = createRegisteredClient(
 				"client-2", "client-id-2", this.registration.getClientSecret());
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.clients.save(registeredClient))
-				.withMessage("Registered client must be unique. Found duplicate client secret for identifier: " + registeredClient.getId());
+		this.clients.save(registeredClient);
+		RegisteredClient savedClient = this.clients.findById(registeredClient.getId());
+		assertRegisteredClientIsEqualTo(savedClient, registeredClient);
 	}
 
 	@Test
@@ -222,6 +213,26 @@ public class JdbcRegisteredClientRepositoryTests {
 		assertRegisteredClientIsEqualTo(savedClient, registeredClient);
 	}
 
+	@Test
+	public void saveWhenPublicClientSavedAndFindByClientIdThenFound() {
+		RegisteredClient registeredClient = TestRegisteredClients.registeredPublicClient().build();
+		this.clients.save(registeredClient);
+		RegisteredClient savedClient = this.clients.findByClientId(registeredClient.getClientId());
+		assertRegisteredClientIsEqualTo(savedClient, registeredClient);
+	}
+
+	@Test
+	public void saveWhenMultiplePublicClientsSavedAndFindByIdThenFound() {
+		RegisteredClient registeredClient1 = TestRegisteredClients.registeredPublicClient()
+				.id("1").clientId("a").build();
+		RegisteredClient registeredClient2 = TestRegisteredClients.registeredPublicClient()
+				.id("2").clientId("b").build();
+		this.clients.save(registeredClient1);
+		this.clients.save(registeredClient2);
+		RegisteredClient savedClient = this.clients.findByClientId(registeredClient2.getClientId());
+		assertRegisteredClientIsEqualTo(savedClient, registeredClient2);
+	}
+
 	@Test
 	public void whenSaveRegistrationWithAllAttrsThenSaved() {
 		Instant issuedAt = Instant.now(), expiresAt = issuedAt.plus(Duration.ofDays(30));