Переглянути джерело

Polish JdbcAssertingPartyMetadataRepository

- Remove GetBytes since it's not used yet
- Remove customizable RowMapper since this can be added
later
- Change signing_algorithms to be a String since the conversion
strategy is simple
- Standardize test names
- Simplify conversion of credentials using ThrowingFunction
- Change column names to match RelyingPartyRegistration
field names

Issue gh-16012
Josh Cummings 2 місяців тому
батько
коміт
e8f920e0ee

+ 39 - 90
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java

@@ -20,15 +20,13 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Types;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
+import java.util.Collections;
 import java.util.Iterator;
 import java.util.List;
 import java.util.function.Function;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-import org.springframework.core.log.LogMessage;
 import org.springframework.core.serializer.DefaultDeserializer;
 import org.springframework.core.serializer.DefaultSerializer;
 import org.springframework.core.serializer.Deserializer;
@@ -53,22 +51,22 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 
 	private final JdbcOperations jdbcOperations;
 
-	private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
-			ResultSet::getBytes);
+	private final RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper();
 
 	private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper();
 
 	// @formatter:off
-	static final String COLUMN_NAMES = "entity_id, "
-			+ "singlesignon_url, "
-			+ "singlesignon_binding, "
-			+ "singlesignon_sign_request, "
-			+ "signing_algorithms, "
-			+ "verification_credentials, "
-			+ "encryption_credentials, "
-			+ "singlelogout_url, "
-			+ "singlelogout_response_url, "
-			+ "singlelogout_binding";
+	static final String[] COLUMN_NAMES = { "entity_id",
+			"single_sign_on_service_location",
+			"single_sign_on_service_binding",
+			"want_authn_requests_signed",
+			"signing_algorithms",
+			"verification_credentials",
+			"encryption_credentials",
+			"single_logout_service_location",
+			"single_logout_service_response_location",
+			"single_logout_service_binding" };
+
 	// @formatter:on
 
 	private static final String TABLE_NAME = "saml2_asserting_party_metadata";
@@ -76,30 +74,23 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 	private static final String ENTITY_ID_FILTER = "entity_id = ?";
 
 	// @formatter:off
-	private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES
+	private static final String LOAD_BY_ID_SQL = "SELECT " + String.join(",", COLUMN_NAMES)
 			+ " FROM " + TABLE_NAME
 			+ " WHERE " + ENTITY_ID_FILTER;
 
-	private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES
+	private static final String LOAD_ALL_SQL = "SELECT " + String.join(",", COLUMN_NAMES)
 			+ " FROM " + TABLE_NAME;
 	// @formatter:on
 
 	// @formatter:off
 	private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME
-			+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
+			+ " (" + String.join(",", COLUMN_NAMES) + ") VALUES (" + String.join(",", Collections.nCopies(COLUMN_NAMES.length, "?")) + ")";
 	// @formatter:on
 
 	// @formatter:off
 	private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME
-			+ " SET singlesignon_url = ?, "
-			+ "singlesignon_binding = ?, "
-			+ "singlesignon_sign_request = ?, "
-			+ "signing_algorithms = ?, "
-			+ "verification_credentials = ?, "
-			+ "encryption_credentials = ?, "
-			+ "singlelogout_url = ?, "
-			+ "singlelogout_response_url = ?, "
-			+ "singlelogout_binding = ?"
+			+ " SET " + String.join(" = ?,", Arrays.copyOfRange(COLUMN_NAMES, 1, COLUMN_NAMES.length))
+			+ " = ?"
 			+ " WHERE " + ENTITY_ID_FILTER;
 	// @formatter:on
 
@@ -113,18 +104,6 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 		this.jdbcOperations = jdbcOperations;
 	}
 
-	/**
-	 * Sets the {@link RowMapper} used for mapping the current row in
-	 * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. The default is
-	 * {@link AssertingPartyMetadataRowMapper}.
-	 * @param assertingPartyMetadataRowMapper the {@link RowMapper} used for mapping the
-	 * current row in {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}
-	 */
-	public void setAssertingPartyMetadataRowMapper(RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper) {
-		Assert.notNull(assertingPartyMetadataRowMapper, "assertingPartyMetadataRowMapper cannot be null");
-		this.assertingPartyMetadataRowMapper = assertingPartyMetadataRowMapper;
-	}
-
 	@Override
 	public AssertingPartyMetadata findByEntityId(String entityId) {
 		Assert.hasText(entityId, "entityId cannot be empty");
@@ -172,52 +151,26 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 	 */
 	private static final class AssertingPartyMetadataRowMapper implements RowMapper<AssertingPartyMetadata> {
 
-		private final Log logger = LogFactory.getLog(AssertingPartyMetadataRowMapper.class);
-
 		private final Deserializer<Object> deserializer = new DefaultDeserializer();
 
-		private final GetBytes getBytes;
-
-		AssertingPartyMetadataRowMapper(GetBytes getBytes) {
-			this.getBytes = getBytes;
-		}
-
 		@Override
 		public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLException {
-			String entityId = rs.getString("entity_id");
-			String singleSignOnUrl = rs.getString("singlesignon_url");
-			Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString("singlesignon_binding"));
-			boolean singleSignOnSignRequest = rs.getBoolean("singlesignon_sign_request");
-			String singleLogoutUrl = rs.getString("singlelogout_url");
-			String singleLogoutResponseUrl = rs.getString("singlelogout_response_url");
-			Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding"));
-			byte[] signingAlgorithmsBytes = this.getBytes.getBytes(rs, "signing_algorithms");
-			byte[] verificationCredentialsBytes = this.getBytes.getBytes(rs, "verification_credentials");
-			byte[] encryptionCredentialsBytes = this.getBytes.getBytes(rs, "encryption_credentials");
-
+			String entityId = rs.getString(COLUMN_NAMES[0]);
+			String singleSignOnUrl = rs.getString(COLUMN_NAMES[1]);
+			Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString(COLUMN_NAMES[2]));
+			boolean singleSignOnSignRequest = rs.getBoolean(COLUMN_NAMES[3]);
+			List<String> algorithms = List.of(rs.getString(COLUMN_NAMES[4]).split(","));
+			byte[] verificationCredentialsBytes = rs.getBytes(COLUMN_NAMES[5]);
+			byte[] encryptionCredentialsBytes = rs.getBytes(COLUMN_NAMES[6]);
+			ThrowingFunction<byte[], Collection<Saml2X509Credential>> credentials = (
+					bytes) -> (Collection<Saml2X509Credential>) this.deserializer.deserializeFromByteArray(bytes);
 			AssertingPartyMetadata.Builder<?> builder = new AssertingPartyDetails.Builder();
-			try {
-				if (signingAlgorithmsBytes != null) {
-					List<String> signingAlgorithms = (List<String>) this.deserializer
-						.deserializeFromByteArray(signingAlgorithmsBytes);
-					builder.signingAlgorithms((algorithms) -> algorithms.addAll(signingAlgorithms));
-				}
-				if (verificationCredentialsBytes != null) {
-					Collection<Saml2X509Credential> verificationCredentials = (Collection<Saml2X509Credential>) this.deserializer
-						.deserializeFromByteArray(verificationCredentialsBytes);
-					builder.verificationX509Credentials((credentials) -> credentials.addAll(verificationCredentials));
-				}
-				if (encryptionCredentialsBytes != null) {
-					Collection<Saml2X509Credential> encryptionCredentials = (Collection<Saml2X509Credential>) this.deserializer
-						.deserializeFromByteArray(encryptionCredentialsBytes);
-					builder.encryptionX509Credentials((credentials) -> credentials.addAll(encryptionCredentials));
-				}
-			}
-			catch (Exception ex) {
-				this.logger.debug(LogMessage.format("Parsing serialized credentials for entity %s failed", entityId),
-						ex);
-				return null;
-			}
+			Collection<Saml2X509Credential> verificationCredentials = credentials.apply(verificationCredentialsBytes);
+			Collection<Saml2X509Credential> encryptionCredentials = (encryptionCredentialsBytes != null)
+					? credentials.apply(encryptionCredentialsBytes) : List.of();
+			String singleLogoutUrl = rs.getString(COLUMN_NAMES[7]);
+			String singleLogoutResponseUrl = rs.getString(COLUMN_NAMES[8]);
+			Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString(COLUMN_NAMES[9]));
 
 			builder.entityId(entityId)
 				.wantAuthnRequestsSigned(singleSignOnSignRequest)
@@ -225,7 +178,10 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 				.singleSignOnServiceBinding(singleSignOnBinding)
 				.singleLogoutServiceLocation(singleLogoutUrl)
 				.singleLogoutServiceBinding(singleLogoutBinding)
-				.singleLogoutServiceResponseLocation(singleLogoutResponseUrl);
+				.singleLogoutServiceResponseLocation(singleLogoutResponseUrl)
+				.signingAlgorithms((a) -> a.addAll(algorithms))
+				.verificationX509Credentials((c) -> c.addAll(verificationCredentials))
+				.encryptionX509Credentials((c) -> c.addAll(encryptionCredentials));
 			return builder.build();
 		}
 
@@ -244,8 +200,7 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation()));
 			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn()));
 			parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned()));
-			ThrowingFunction<List<String>, byte[]> algorithms = this.serializer::serializeToByteArray;
-			parameters.add(new SqlParameterValue(Types.BLOB, algorithms.apply(record.getSigningAlgorithms())));
+			parameters.add(new SqlParameterValue(Types.BLOB, String.join(",", record.getSigningAlgorithms())));
 			ThrowingFunction<Collection<Saml2X509Credential>, byte[]> credentials = this.serializer::serializeToByteArray;
 			parameters
 				.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials())));
@@ -259,10 +214,4 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 
 	}
 
-	private interface GetBytes {
-
-		byte[] getBytes(ResultSet rs, String columnName) throws SQLException;
-
-	}
-
 }

+ 10 - 10
saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql

@@ -1,14 +1,14 @@
 CREATE TABLE saml2_asserting_party_metadata
 (
-    entity_id                 VARCHAR(1000) NOT NULL,
-    singlesignon_url          VARCHAR(1000) NOT NULL,
-    singlesignon_binding      VARCHAR(100),
-    singlesignon_sign_request boolean,
-    signing_algorithms        BYTEA,
-    verification_credentials  BYTEA         NOT NULL,
-    encryption_credentials    BYTEA,
-    singlelogout_url          VARCHAR(1000),
-    singlelogout_response_url VARCHAR(1000),
-    singlelogout_binding      VARCHAR(100),
+    entity_id                               VARCHAR(1000) NOT NULL,
+    single_sign_on_service_location         VARCHAR(1000) NOT NULL,
+    single_sign_on_service_binding          VARCHAR(100),
+    want_authn_requests_signed              boolean,
+    signing_algorithms                      BYTEA,
+    verification_credentials                BYTEA NOT NULL,
+    encryption_credentials                  BYTEA,
+    single_logout_service_location          VARCHAR(1000),
+    single_logout_service_response_location VARCHAR(1000),
+    single_logout_service_binding           VARCHAR(100),
     PRIMARY KEY (entity_id)
 );

+ 10 - 10
saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql

@@ -1,14 +1,14 @@
 CREATE TABLE saml2_asserting_party_metadata
 (
-    entity_id                 VARCHAR(1000) NOT NULL,
-    singlesignon_url          VARCHAR(1000) NOT NULL,
-    singlesignon_binding      VARCHAR(100),
-    singlesignon_sign_request boolean,
-    signing_algorithms        blob,
-    verification_credentials  blob          NOT NULL,
-    encryption_credentials    blob,
-    singlelogout_url          VARCHAR(1000),
-    singlelogout_response_url VARCHAR(1000),
-    singlelogout_binding      VARCHAR(100),
+    entity_id                               VARCHAR(1000) NOT NULL,
+    single_sign_on_service_location         VARCHAR(1000) NOT NULL,
+    single_sign_on_service_binding          VARCHAR(100),
+    want_authn_requests_signed              boolean,
+    signing_algorithms                      VARCHAR(256) NOT NULL,
+    verification_credentials                blob NOT NULL,
+    encryption_credentials                  blob,
+    single_logout_service_location          VARCHAR(1000),
+    single_logout_service_response_location VARCHAR(1000),
+    single_logout_service_binding           VARCHAR(100),
     PRIMARY KEY (entity_id)
 );

+ 4 - 7
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java

@@ -79,7 +79,7 @@ class JdbcAssertingPartyMetadataRepositoryTests {
 	}
 
 	@Test
-	void findByEntityId() {
+	void findByEntityIdWhenEntityPresentThenReturns() {
 		this.repository.save(this.metadata);
 
 		AssertingPartyMetadata found = this.repository.findByEntityId(this.metadata.getEntityId());
@@ -88,17 +88,14 @@ class JdbcAssertingPartyMetadataRepositoryTests {
 	}
 
 	@Test
-	void findByEntityIdWhenNotExists() {
+	void findByEntityIdWhenNotExistsThenNull() {
 		AssertingPartyMetadata found = this.repository.findByEntityId("non-existent-entity-id");
 		assertThat(found).isNull();
 	}
 
 	@Test
-	void iterator() {
-		AssertingPartyMetadata second = RelyingPartyRegistration.withAssertingPartyMetadata(this.metadata)
-			.assertingPartyMetadata((a) -> a.entityId("https://example.org/idp"))
-			.build()
-			.getAssertingPartyMetadata();
+	void iteratorWhenEnitiesExistThenContains() {
+		AssertingPartyMetadata second = this.metadata.mutate().entityId("https://example.org/idp").build();
 		this.repository.save(this.metadata);
 		this.repository.save(second);