소스 검색

Polish gh-291

Steve Riesenberg 4 년 전
부모
커밋
3318874da1

+ 139 - 122
oauth2-authorization-server/src/main/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepository.java

@@ -30,7 +30,7 @@ import java.util.Map;
 import java.util.Set;
 import java.util.function.Function;
 
-import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
 
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
@@ -88,21 +88,10 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 	 * @param jdbcOperations the JDBC operations
 	 */
 	public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations) {
-		this(jdbcOperations, new ObjectMapper());
-	}
-
-	/**
-	 * Constructs a {@code JdbcRegisteredClientRepository} using the provided parameters.
-	 *
-	 * @param jdbcOperations the JDBC operations
-	 * @param objectMapper the object mapper
-	 */
-	public JdbcRegisteredClientRepository(JdbcOperations jdbcOperations, ObjectMapper objectMapper) {
 		Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
-		Assert.notNull(objectMapper, "objectMapper cannot be null");
 		this.jdbcOperations = jdbcOperations;
-		this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper(objectMapper);
-		this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper(objectMapper);
+		this.registeredClientRowMapper = new DefaultRegisteredClientRowMapper();
+		this.registeredClientParametersMapper = new DefaultRegisteredClientParametersMapper();
 	}
 
 	/**
@@ -110,7 +99,7 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 	 *
 	 * @param registeredClientRowMapper mapper implementation
 	 */
-	public void setRegisteredClientRowMapper(RowMapper<RegisteredClient> registeredClientRowMapper) {
+	public final void setRegisteredClientRowMapper(RowMapper<RegisteredClient> registeredClientRowMapper) {
 		Assert.notNull(registeredClientRowMapper, "registeredClientRowMapper cannot be null");
 		this.registeredClientRowMapper = registeredClientRowMapper;
 	}
@@ -120,18 +109,30 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 	 *
 	 * @param registeredClientParametersMapper mapper implementation
 	 */
-	public void setRegisteredClientParametersMapper(Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper) {
+	public final void setRegisteredClientParametersMapper(Function<RegisteredClient, List<SqlParameterValue>> registeredClientParametersMapper) {
 		Assert.notNull(registeredClientParametersMapper, "registeredClientParameterMapper cannot be null");
 		this.registeredClientParametersMapper = registeredClientParametersMapper;
 	}
 
+	protected final JdbcOperations getJdbcOperations() {
+		return this.jdbcOperations;
+	}
+
+	protected final RowMapper<RegisteredClient> getRegisteredClientRowMapper() {
+		return this.registeredClientRowMapper;
+	}
+
+	protected final Function<RegisteredClient, List<SqlParameterValue>> getRegisteredClientParametersMapper() {
+		return this.registeredClientParametersMapper;
+	}
+
 	@Override
 	public void save(RegisteredClient registeredClient) {
 		Assert.notNull(registeredClient, "registeredClient cannot be null");
-		RegisteredClient foundClient = this.findBy("id = ? OR client_id = ?",
+		RegisteredClient foundClient = findBy("id = ? OR client_id = ?",
 				registeredClient.getId(), registeredClient.getClientId());
 
-		if (null != foundClient) {
+		if (foundClient != null) {
 			Assert.isTrue(!foundClient.getId().equals(registeredClient.getId()),
 					"Registered client must be unique. Found duplicate identifier: " + registeredClient.getId());
 			Assert.isTrue(!foundClient.getClientId().equals(registeredClient.getClientId()),
@@ -155,29 +156,20 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 		return findBy("client_id = ?", clientId);
 	}
 
-	private RegisteredClient findBy(String condStr, Object...args) {
-		List<RegisteredClient> lst = this.jdbcOperations.query(
+	private RegisteredClient findBy(String condStr, Object... args) {
+		List<RegisteredClient> result = this.jdbcOperations.query(
 				LOAD_REGISTERED_CLIENT_SQL + condStr,
-				registeredClientRowMapper, args);
-		return !lst.isEmpty() ? lst.get(0) : null;
+				this.registeredClientRowMapper, args);
+		return !result.isEmpty() ? result.get(0) : null;
 	}
 
 	public static class DefaultRegisteredClientRowMapper implements RowMapper<RegisteredClient> {
 
-		private final ObjectMapper objectMapper;
-
-		public DefaultRegisteredClientRowMapper(ObjectMapper objectMapper) {
-			this.objectMapper = objectMapper;
-		}
-
-		private Set<String> parseList(String s) {
-			return s != null ? StringUtils.commaDelimitedListToSet(s) : Collections.emptySet();
-		}
+		private ObjectMapper objectMapper = new ObjectMapper();
 
 		@Override
-		@SuppressWarnings("unchecked")
 		public RegisteredClient mapRow(ResultSet rs, int rowNum) throws SQLException {
-			Set<String> scopes = parseList(rs.getString("scopes"));
+			Set<String> clientScopes = parseList(rs.getString("scopes"));
 			Set<String> authGrantTypes = parseList(rs.getString("authorization_grant_types"));
 			Set<String> clientAuthMethods = parseList(rs.getString("client_authentication_methods"));
 			Set<String> redirectUris = parseList(rs.getString("redirect_uris"));
@@ -191,115 +183,140 @@ public class JdbcRegisteredClientRepository implements RegisteredClientRepositor
 					.clientSecret(clientSecret)
 					.clientSecretExpiresAt(clientSecretExpiresAt != null ? clientSecretExpiresAt.toInstant() : null)
 					.clientName(rs.getString("client_name"))
-					.authorizationGrantTypes(coll -> authGrantTypes.forEach(authGrantType ->
-							coll.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType))))
-					.clientAuthenticationMethods(coll -> clientAuthMethods.forEach(clientAuthMethod ->
-							coll.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod))))
-					.redirectUris(coll -> coll.addAll(redirectUris))
-					.scopes(coll -> coll.addAll(scopes));
-
-			RegisteredClient rc = builder.build();
+					.authorizationGrantTypes((grantTypes) -> authGrantTypes.forEach(authGrantType ->
+							grantTypes.add(AUTHORIZATION_GRANT_TYPE_MAP.get(authGrantType))))
+					.clientAuthenticationMethods((authenticationMethods) -> clientAuthMethods.forEach(clientAuthMethod ->
+							authenticationMethods.add(CLIENT_AUTHENTICATION_METHOD_MAP.get(clientAuthMethod))))
+					.redirectUris((uris) -> uris.addAll(redirectUris))
+					.scopes((scopes) -> scopes.addAll(clientScopes));
+
+			RegisteredClient registeredClient = builder.build();
+
+			String tokenSettingsJson = rs.getString("token_settings");
+			if (tokenSettingsJson != null) {
+				Map<String, Object> settings = parseMap(tokenSettingsJson);
+				TokenSettings tokenSettings = registeredClient.getTokenSettings();
+
+				Number accessTokenTTL = (Number) settings.get("access_token_ttl");
+				if (accessTokenTTL != null) {
+					tokenSettings.accessTokenTimeToLive(Duration.ofMillis(accessTokenTTL.longValue()));
+				}
 
-			TokenSettings ts = rc.getTokenSettings();
-			ClientSettings cs = rc.getClientSettings();
+				Number refreshTokenTTL = (Number) settings.get("refresh_token_ttl");
+				if (refreshTokenTTL != null) {
+					tokenSettings.refreshTokenTimeToLive(Duration.ofMillis(refreshTokenTTL.longValue()));
+				}
 
-			try {
-				String tokenSettingsJson = rs.getString("token_settings");
-				if (tokenSettingsJson != null) {
-					Map<String, Object> m = this.objectMapper.readValue(tokenSettingsJson, Map.class);
-
-					Number accessTokenTTL = (Number) m.get("access_token_ttl");
-					if (accessTokenTTL != null) {
-						ts.accessTokenTimeToLive(Duration.ofMillis(accessTokenTTL.longValue()));
-					}
-
-					Number refreshTokenTTL = (Number) m.get("refresh_token_ttl");
-					if (refreshTokenTTL != null) {
-						ts.refreshTokenTimeToLive(Duration.ofMillis(refreshTokenTTL.longValue()));
-					}
-
-					Boolean reuseRefreshTokens = (Boolean) m.get("reuse_refresh_tokens");
-					if (reuseRefreshTokens != null) {
-						ts.reuseRefreshTokens(reuseRefreshTokens);
-					}
+				Boolean reuseRefreshTokens = (Boolean) settings.get("reuse_refresh_tokens");
+				if (reuseRefreshTokens != null) {
+					tokenSettings.reuseRefreshTokens(reuseRefreshTokens);
 				}
+			}
 
-				String clientSettingsJson = rs.getString("client_settings");
-				if (clientSettingsJson != null) {
-					Map<String, Object> m = this.objectMapper.readValue(clientSettingsJson, Map.class);
+			String clientSettingsJson = rs.getString("client_settings");
+			if (clientSettingsJson != null) {
+				Map<String, Object> settings = parseMap(clientSettingsJson);
+				ClientSettings clientSettings = registeredClient.getClientSettings();
 
-					Boolean requireProofKey = (Boolean) m.get("require_proof_key");
-					if (requireProofKey != null) {
-						cs.requireProofKey(requireProofKey);
-					}
+				Boolean requireProofKey = (Boolean) settings.get("require_proof_key");
+				if (requireProofKey != null) {
+					clientSettings.requireProofKey(requireProofKey);
+				}
 
-					Boolean requireUserConsent = (Boolean) m.get("require_user_consent");
-					if (requireUserConsent != null) {
-						cs.requireUserConsent(requireUserConsent);
-					}
+				Boolean requireUserConsent = (Boolean) settings.get("require_user_consent");
+				if (requireUserConsent != null) {
+					clientSettings.requireUserConsent(requireUserConsent);
 				}
-			} catch (JsonProcessingException e) {
-				throw new IllegalArgumentException(e.getMessage(), e);
 			}
 
-			return rc;
+			return registeredClient;
 		}
 
-	}
+		public final void setObjectMapper(ObjectMapper objectMapper) {
+			Assert.notNull(objectMapper, "objectMapper cannot be null");
+			this.objectMapper = objectMapper;
+		}
 
-	public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
+		protected final ObjectMapper getObjectMapper() {
+			return this.objectMapper;
+		}
 
-		private final ObjectMapper objectMapper;
+		private Set<String> parseList(String s) {
+			return s != null ? StringUtils.commaDelimitedListToSet(s) : Collections.emptySet();
+		}
 
-		private DefaultRegisteredClientParametersMapper(ObjectMapper objectMapper) {
-			this.objectMapper = objectMapper;
+		private Map<String, Object> parseMap(String data) {
+			try {
+				return this.objectMapper.readValue(data, new TypeReference<Map<String, Object>>() {});
+			} catch (Exception ex) {
+				throw new IllegalArgumentException(ex.getMessage(), ex);
+			}
 		}
 
+	}
+
+	public static class DefaultRegisteredClientParametersMapper implements Function<RegisteredClient, List<SqlParameterValue>> {
+
+		private ObjectMapper objectMapper = new ObjectMapper();
+
 		@Override
 		public List<SqlParameterValue> apply(RegisteredClient registeredClient) {
-			try {
-				List<String> clientAuthenticationMethodNames = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size());
-				for (ClientAuthenticationMethod clientAuthenticationMethod : registeredClient.getClientAuthenticationMethods()) {
-					clientAuthenticationMethodNames.add(clientAuthenticationMethod.getValue());
-				}
+			List<String> clientAuthenticationMethodNames = new ArrayList<>(registeredClient.getClientAuthenticationMethods().size());
+			for (ClientAuthenticationMethod clientAuthenticationMethod : registeredClient.getClientAuthenticationMethods()) {
+				clientAuthenticationMethodNames.add(clientAuthenticationMethod.getValue());
+			}
 
-				List<String> authorizationGrantTypeNames = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size());
-				for (AuthorizationGrantType authorizationGrantType : registeredClient.getAuthorizationGrantTypes()) {
-					authorizationGrantTypeNames.add(authorizationGrantType.getValue());
-				}
+			List<String> authorizationGrantTypeNames = new ArrayList<>(registeredClient.getAuthorizationGrantTypes().size());
+			for (AuthorizationGrantType authorizationGrantType : registeredClient.getAuthorizationGrantTypes()) {
+				authorizationGrantTypeNames.add(authorizationGrantType.getValue());
+			}
+
+			Instant issuedAt = registeredClient.getClientIdIssuedAt() != null ?
+					registeredClient.getClientIdIssuedAt() : Instant.now();
+
+			Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ?
+					Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null;
+
+			Map<String, Object> clientSettings = new HashMap<>();
+			clientSettings.put("require_proof_key", registeredClient.getClientSettings().requireProofKey());
+			clientSettings.put("require_user_consent", registeredClient.getClientSettings().requireUserConsent());
+			String clientSettingsJson = writeMap(clientSettings);
+
+			Map<String, Object> tokenSettings = new HashMap<>();
+			tokenSettings.put("access_token_ttl", registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis());
+			tokenSettings.put("reuse_refresh_tokens", registeredClient.getTokenSettings().reuseRefreshTokens());
+			tokenSettings.put("refresh_token_ttl", registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis());
+			String tokenSettingsJson = writeMap(tokenSettings);
+
+			return Arrays.asList(
+					new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
+					new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
+					new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
+					new SqlParameterValue(Types.VARCHAR, registeredClient.getClientSecret()),
+					new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
+					new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
+					new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)),
+					new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(authorizationGrantTypeNames)),
+					new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())),
+					new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())),
+					new SqlParameterValue(Types.VARCHAR, clientSettingsJson),
+					new SqlParameterValue(Types.VARCHAR, tokenSettingsJson));
+		}
 
-				Instant issuedAt = registeredClient.getClientIdIssuedAt() != null ?
-						registeredClient.getClientIdIssuedAt() : Instant.now();
-
-				Timestamp clientSecretExpiresAt = registeredClient.getClientSecretExpiresAt() != null ?
-						Timestamp.from(registeredClient.getClientSecretExpiresAt()) : null;
-
-				Map<String, Object> clientSettings = new HashMap<>();
-				clientSettings.put("require_proof_key", registeredClient.getClientSettings().requireProofKey());
-				clientSettings.put("require_user_consent", registeredClient.getClientSettings().requireUserConsent());
-				String clientSettingsJson = this.objectMapper.writeValueAsString(clientSettings);
-
-				Map<String, Object> tokenSettings = new HashMap<>();
-				tokenSettings.put("access_token_ttl", registeredClient.getTokenSettings().accessTokenTimeToLive().toMillis());
-				tokenSettings.put("reuse_refresh_tokens", registeredClient.getTokenSettings().reuseRefreshTokens());
-				tokenSettings.put("refresh_token_ttl", registeredClient.getTokenSettings().refreshTokenTimeToLive().toMillis());
-				String tokenSettingsJson = this.objectMapper.writeValueAsString(tokenSettings);
-
-				return Arrays.asList(
-						new SqlParameterValue(Types.VARCHAR, registeredClient.getId()),
-						new SqlParameterValue(Types.VARCHAR, registeredClient.getClientId()),
-						new SqlParameterValue(Types.TIMESTAMP, Timestamp.from(issuedAt)),
-						new SqlParameterValue(Types.VARCHAR, registeredClient.getClientSecret()),
-						new SqlParameterValue(Types.TIMESTAMP, clientSecretExpiresAt),
-						new SqlParameterValue(Types.VARCHAR, registeredClient.getClientName()),
-						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(clientAuthenticationMethodNames)),
-						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(authorizationGrantTypeNames)),
-						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getRedirectUris())),
-						new SqlParameterValue(Types.VARCHAR, StringUtils.collectionToCommaDelimitedString(registeredClient.getScopes())),
-						new SqlParameterValue(Types.VARCHAR, clientSettingsJson),
-						new SqlParameterValue(Types.VARCHAR, tokenSettingsJson));
-			} catch (JsonProcessingException e) {
-				throw new IllegalArgumentException(e.getMessage(), e);
+		public final void setObjectMapper(ObjectMapper objectMapper) {
+			Assert.notNull(objectMapper, "objectMapper cannot be null");
+			this.objectMapper = objectMapper;
+		}
+
+		protected final ObjectMapper getObjectMapper() {
+			return this.objectMapper;
+		}
+
+		private String writeMap(Map<String, Object> data) {
+			try {
+				return this.objectMapper.writeValueAsString(data);
+			} catch (Exception ex) {
+				throw new IllegalArgumentException(ex.getMessage(), ex);
 			}
 		}
 

+ 1 - 11
oauth2-authorization-server/src/test/java/org/springframework/security/oauth2/server/authorization/client/JdbcRegisteredClientRepositoryTests.java

@@ -20,7 +20,6 @@ import java.nio.charset.Charset;
 import java.time.Duration;
 import java.time.Instant;
 
-import com.fasterxml.jackson.databind.ObjectMapper;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
@@ -89,20 +88,11 @@ public class JdbcRegisteredClientRepositoryTests {
 	public void whenJdbcOperationsNullThenThrow() {
 		// @formatter:off
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> new JdbcRegisteredClientRepository(null, new ObjectMapper()))
+				.isThrownBy(() -> new JdbcRegisteredClientRepository(null))
 				.withMessage("jdbcOperations cannot be null");
 		// @formatter:on
 	}
 
-	@Test
-	public void whenObjectMapperNullThenThrow() {
-		// @formatter:off
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> new JdbcRegisteredClientRepository(this.jdbc, null))
-				.withMessage("objectMapper cannot be null");
-		// @formatter:on
-	}
-
 	@Test
 	public void whenSetNullRegisteredClientRowMapperThenThrow() {
 		// @formatter:off