Bladeren bron

Add JdbcAssertingPartyMetadataRepository#save

Issue gh-16012

Co-Authored-By: chao.wang <chao.wang@zatech.com>
Josh Cummings 2 maanden geleden
bovenliggende
commit
2bd05128ec

+ 78 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java

@@ -19,16 +19,20 @@ package org.springframework.security.saml2.provider.service.registration;
 import java.sql.ResultSet;
 import java.sql.SQLException;
 import java.sql.Types;
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.List;
+import java.util.function.Function;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 import org.springframework.core.log.LogMessage;
 import org.springframework.core.serializer.DefaultDeserializer;
+import org.springframework.core.serializer.DefaultSerializer;
 import org.springframework.core.serializer.Deserializer;
+import org.springframework.core.serializer.Serializer;
 import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.PreparedStatementSetter;
@@ -37,6 +41,7 @@ import org.springframework.jdbc.core.SqlParameterValue;
 import org.springframework.security.saml2.core.Saml2X509Credential;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
 import org.springframework.util.Assert;
+import org.springframework.util.function.ThrowingFunction;
 
 /**
  * A JDBC implementation of {@link AssertingPartyMetadataRepository}.
@@ -51,6 +56,8 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 	private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
 			ResultSet::getBytes);
 
+	private final AssertingPartyMetadataParametersMapper assertingPartyMetadataParametersMapper = new AssertingPartyMetadataParametersMapper();
+
 	// @formatter:off
 	static final String COLUMN_NAMES = "entity_id, "
 			+ "singlesignon_url, "
@@ -77,6 +84,25 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 			+ " FROM " + TABLE_NAME;
 	// @formatter:on
 
+	// @formatter:off
+	private static final String SAVE_CREDENTIAL_RECORD_SQL = "INSERT INTO " + TABLE_NAME
+			+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
+	// @formatter:on
+
+	// @formatter:off
+	private static final String UPDATE_CREDENTIAL_RECORD_SQL = "UPDATE " + TABLE_NAME
+			+ " SET singlesignon_url = ?, "
+			+ "singlesignon_binding = ?, "
+			+ "singlesignon_sign_request = ?, "
+			+ "signing_algorithms = ?, "
+			+ "verification_credentials = ?, "
+			+ "encryption_credentials = ?, "
+			+ "singlelogout_url = ?, "
+			+ "singlelogout_response_url = ?, "
+			+ "singlelogout_binding = ?"
+			+ " WHERE " + ENTITY_ID_FILTER;
+	// @formatter:on
+
 	/**
 	 * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
 	 * parameters.
@@ -116,6 +142,30 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 		return result.iterator();
 	}
 
+	/**
+	 * Persist this {@link AssertingPartyMetadata}
+	 * @param metadata the metadata to persist
+	 */
+	public void save(AssertingPartyMetadata metadata) {
+		Assert.notNull(metadata, "metadata cannot be null");
+		int rows = updateCredentialRecord(metadata);
+		if (rows == 0) {
+			insertCredentialRecord(metadata);
+		}
+	}
+
+	private void insertCredentialRecord(AssertingPartyMetadata metadata) {
+		List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
+		this.jdbcOperations.update(SAVE_CREDENTIAL_RECORD_SQL, parameters.toArray());
+	}
+
+	private int updateCredentialRecord(AssertingPartyMetadata metadata) {
+		List<SqlParameterValue> parameters = this.assertingPartyMetadataParametersMapper.apply(metadata);
+		SqlParameterValue credentialId = parameters.remove(0);
+		parameters.add(credentialId);
+		return this.jdbcOperations.update(UPDATE_CREDENTIAL_RECORD_SQL, parameters.toArray());
+	}
+
 	/**
 	 * The default {@link RowMapper} that maps the current row in
 	 * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
@@ -181,6 +231,34 @@ public final class JdbcAssertingPartyMetadataRepository implements AssertingPart
 
 	}
 
+	private static class AssertingPartyMetadataParametersMapper
+			implements Function<AssertingPartyMetadata, List<SqlParameterValue>> {
+
+		private final Serializer<Object> serializer = new DefaultSerializer();
+
+		@Override
+		public List<SqlParameterValue> apply(AssertingPartyMetadata record) {
+			List<SqlParameterValue> parameters = new ArrayList<>();
+
+			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getEntityId()));
+			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceLocation()));
+			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleSignOnServiceBinding().getUrn()));
+			parameters.add(new SqlParameterValue(Types.BOOLEAN, record.getWantAuthnRequestsSigned()));
+			ThrowingFunction<List<String>, byte[]> algorithms = this.serializer::serializeToByteArray;
+			parameters.add(new SqlParameterValue(Types.BLOB, algorithms.apply(record.getSigningAlgorithms())));
+			ThrowingFunction<Collection<Saml2X509Credential>, byte[]> credentials = this.serializer::serializeToByteArray;
+			parameters
+				.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getVerificationX509Credentials())));
+			parameters.add(new SqlParameterValue(Types.BLOB, credentials.apply(record.getEncryptionX509Credentials())));
+			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceLocation()));
+			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceResponseLocation()));
+			parameters.add(new SqlParameterValue(Types.VARCHAR, record.getSingleLogoutServiceBinding().getUrn()));
+
+			return parameters;
+		}
+
+	}
+
 	private interface GetBytes {
 
 		byte[] getBytes(ResultSet rs, String columnName) throws SQLException;

+ 39 - 81
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java

@@ -16,27 +16,17 @@
 
 package org.springframework.security.saml2.provider.service.registration;
 
-import java.io.IOException;
-import java.io.InputStream;
-import java.security.cert.CertificateFactory;
-import java.security.cert.X509Certificate;
-import java.util.Collection;
 import java.util.Iterator;
-import java.util.List;
 
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
-import org.springframework.core.io.ClassPathResource;
-import org.springframework.core.serializer.DefaultSerializer;
-import org.springframework.core.serializer.Serializer;
 import org.springframework.jdbc.core.JdbcOperations;
 import org.springframework.jdbc.core.JdbcTemplate;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
 import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
-import org.springframework.security.saml2.core.Saml2X509Credential;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
@@ -48,41 +38,21 @@ class JdbcAssertingPartyMetadataRepositoryTests {
 
 	private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql";
 
-	private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata ("
-			+ JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
-
-	private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php";
-
-	private static final String SINGLE_SIGNON_URL = "https://localhost/SSO";
-
-	private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
-
-	private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false;
-
-	private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO";
-
-	private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response";
-
-	private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
-
-	private static final List<String> SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512");
-
-	private X509Certificate certificate;
-
 	private EmbeddedDatabase db;
 
 	private JdbcAssertingPartyMetadataRepository repository;
 
 	private JdbcOperations jdbcOperations;
 
-	private final Serializer<Object> serializer = new DefaultSerializer();
+	private final AssertingPartyMetadata metadata = TestRelyingPartyRegistrations.full()
+		.build()
+		.getAssertingPartyMetadata();
 
 	@BeforeEach
 	void setUp() {
 		this.db = createDb();
 		this.jdbcOperations = new JdbcTemplate(this.db);
 		this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations);
-		this.certificate = loadCertificate("rsa.crt");
 	}
 
 	@AfterEach
@@ -109,26 +79,12 @@ class JdbcAssertingPartyMetadataRepositoryTests {
 	}
 
 	@Test
-	void findByEntityId() throws IOException {
-		this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
-				SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
-				this.serializer.serializeToByteArray(asCredentials(this.certificate)),
-				this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
-				SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
+	void findByEntityId() {
+		this.repository.save(this.metadata);
 
-		AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID);
+		AssertingPartyMetadata found = this.repository.findByEntityId(this.metadata.getEntityId());
 
-		assertThat(found).isNotNull();
-		assertThat(found.getEntityId()).isEqualTo(ENTITY_ID);
-		assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL);
-		assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING);
-		assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST);
-		assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL);
-		assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL);
-		assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING);
-		assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0));
-		assertThat(found.getVerificationX509Credentials()).hasSize(1);
-		assertThat(found.getEncryptionX509Credentials()).hasSize(1);
+		assertAssertingPartyEquals(found, this.metadata);
 	}
 
 	@Test
@@ -138,28 +94,30 @@ class JdbcAssertingPartyMetadataRepositoryTests {
 	}
 
 	@Test
-	void iterator() throws IOException {
-		this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
-				SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
-				this.serializer.serializeToByteArray(asCredentials(this.certificate)),
-				this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
-				SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
-
-		this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL,
-				SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST,
-				this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
-				this.serializer.serializeToByteArray(asCredentials(this.certificate)),
-				this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
-				SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
+	void iterator() {
+		AssertingPartyMetadata second = RelyingPartyRegistration.withAssertingPartyMetadata(this.metadata)
+			.assertingPartyMetadata((a) -> a.entityId("https://example.org/idp"))
+			.build()
+			.getAssertingPartyMetadata();
+		this.repository.save(this.metadata);
+		this.repository.save(second);
 
 		Iterator<AssertingPartyMetadata> iterator = this.repository.iterator();
-		AssertingPartyMetadata first = iterator.next();
-		assertThat(first).isNotNull();
-		AssertingPartyMetadata second = iterator.next();
-		assertThat(second).isNotNull();
+
+		assertAssertingPartyEquals(iterator.next(), this.metadata);
+		assertAssertingPartyEquals(iterator.next(), second);
 		assertThat(iterator.hasNext()).isFalse();
 	}
 
+	@Test
+	void saveWhenExistingThenUpdates() {
+		this.repository.save(this.metadata);
+		boolean existing = this.metadata.getWantAuthnRequestsSigned();
+		this.repository.save(this.metadata.mutate().wantAuthnRequestsSigned(!existing).build());
+		boolean updated = this.repository.findByEntityId(this.metadata.getEntityId()).getWantAuthnRequestsSigned();
+		assertThat(existing).isNotEqualTo(updated);
+	}
+
 	private static EmbeddedDatabase createDb() {
 		return createDb(SCHEMA_SQL_RESOURCE);
 	}
@@ -175,19 +133,19 @@ class JdbcAssertingPartyMetadataRepositoryTests {
 		// @formatter:on
 	}
 
-	private X509Certificate loadCertificate(String path) {
-		try (InputStream is = new ClassPathResource(path).getInputStream()) {
-			CertificateFactory factory = CertificateFactory.getInstance("X.509");
-			return (X509Certificate) factory.generateCertificate(is);
-		}
-		catch (Exception ex) {
-			throw new RuntimeException("Error loading certificate from " + path, ex);
-		}
-	}
-
-	private Collection<Saml2X509Credential> asCredentials(X509Certificate certificate) {
-		return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
-				Saml2X509Credential.Saml2X509CredentialType.VERIFICATION));
+	private void assertAssertingPartyEquals(AssertingPartyMetadata found, AssertingPartyMetadata expected) {
+		assertThat(found).isNotNull();
+		assertThat(found.getEntityId()).isEqualTo(expected.getEntityId());
+		assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(expected.getSingleSignOnServiceLocation());
+		assertThat(found.getSingleSignOnServiceBinding()).isEqualTo(expected.getSingleSignOnServiceBinding());
+		assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(expected.getWantAuthnRequestsSigned());
+		assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(expected.getSingleLogoutServiceLocation());
+		assertThat(found.getSingleLogoutServiceResponseLocation())
+			.isEqualTo(expected.getSingleLogoutServiceResponseLocation());
+		assertThat(found.getSingleLogoutServiceBinding()).isEqualTo(expected.getSingleLogoutServiceBinding());
+		assertThat(found.getSigningAlgorithms()).containsAll(expected.getSigningAlgorithms());
+		assertThat(found.getVerificationX509Credentials()).containsAll(expected.getVerificationX509Credentials());
+		assertThat(found.getEncryptionX509Credentials()).containsAll(expected.getEncryptionX509Credentials());
 	}
 
 }