2
0
Эх сурвалжийг харах

Add JdbcAssertingPartyMetadataRepository

Closes gh-16012

Signed-off-by: chao.wang <chao.wang@zatech.com>
chao.wang 3 сар өмнө
parent
commit
16fd24c002

+ 2 - 0
saml2/saml2-service-provider/spring-security-saml2-service-provider.gradle

@@ -106,6 +106,7 @@ dependencies {
 	provided 'jakarta.servlet:jakarta.servlet-api'
 
 	optional 'com.fasterxml.jackson.core:jackson-databind'
+	optional 'org.springframework:spring-jdbc'
 
 	testImplementation 'com.squareup.okhttp3:mockwebserver'
 	testImplementation "org.assertj:assertj-core"
@@ -118,6 +119,7 @@ dependencies {
 	testImplementation "org.springframework:spring-test"
 
 	testRuntimeOnly 'org.junit.platform:junit-platform-launcher'
+	testRuntimeOnly 'org.hsqldb:hsqldb'
 }
 
 jar {

+ 190 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepository.java

@@ -0,0 +1,190 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.registration;
+
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Types;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import org.springframework.core.log.LogMessage;
+import org.springframework.core.serializer.DefaultDeserializer;
+import org.springframework.core.serializer.Deserializer;
+import org.springframework.jdbc.core.ArgumentPreparedStatementSetter;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.PreparedStatementSetter;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.SqlParameterValue;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.AssertingPartyDetails;
+import org.springframework.util.Assert;
+
+/**
+ * A JDBC implementation of {@link AssertingPartyMetadataRepository}.
+ *
+ * @author Cathy Wang
+ * @since 7.0
+ */
+public final class JdbcAssertingPartyMetadataRepository implements AssertingPartyMetadataRepository {
+
+	private final JdbcOperations jdbcOperations;
+
+	private RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper = new AssertingPartyMetadataRowMapper(
+			ResultSet::getBytes);
+
+	// @formatter:off
+	static final String COLUMN_NAMES = "entity_id, "
+			+ "singlesignon_url, "
+			+ "singlesignon_binding, "
+			+ "singlesignon_sign_request, "
+			+ "signing_algorithms, "
+			+ "verification_credentials, "
+			+ "encryption_credentials, "
+			+ "singlelogout_url, "
+			+ "singlelogout_response_url, "
+			+ "singlelogout_binding";
+	// @formatter:on
+
+	private static final String TABLE_NAME = "saml2_asserting_party_metadata";
+
+	private static final String ENTITY_ID_FILTER = "entity_id = ?";
+
+	// @formatter:off
+	private static final String LOAD_BY_ID_SQL = "SELECT " + COLUMN_NAMES
+			+ " FROM " + TABLE_NAME
+			+ " WHERE " + ENTITY_ID_FILTER;
+
+	private static final String LOAD_ALL_SQL = "SELECT " + COLUMN_NAMES
+			+ " FROM " + TABLE_NAME;
+	// @formatter:on
+
+	/**
+	 * Constructs a {@code JdbcRelyingPartyRegistrationRepository} using the provided
+	 * parameters.
+	 * @param jdbcOperations the JDBC operations
+	 */
+	public JdbcAssertingPartyMetadataRepository(JdbcOperations jdbcOperations) {
+		Assert.notNull(jdbcOperations, "jdbcOperations cannot be null");
+		this.jdbcOperations = jdbcOperations;
+	}
+
+	/**
+	 * Sets the {@link RowMapper} used for mapping the current row in
+	 * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}. The default is
+	 * {@link AssertingPartyMetadataRowMapper}.
+	 * @param assertingPartyMetadataRowMapper the {@link RowMapper} used for mapping the
+	 * current row in {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}
+	 */
+	public void setAssertingPartyMetadataRowMapper(RowMapper<AssertingPartyMetadata> assertingPartyMetadataRowMapper) {
+		Assert.notNull(assertingPartyMetadataRowMapper, "assertingPartyMetadataRowMapper cannot be null");
+		this.assertingPartyMetadataRowMapper = assertingPartyMetadataRowMapper;
+	}
+
+	@Override
+	public AssertingPartyMetadata findByEntityId(String entityId) {
+		Assert.hasText(entityId, "entityId cannot be empty");
+		SqlParameterValue[] parameters = new SqlParameterValue[] { new SqlParameterValue(Types.VARCHAR, entityId) };
+		PreparedStatementSetter pss = new ArgumentPreparedStatementSetter(parameters);
+		List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_BY_ID_SQL, pss,
+				this.assertingPartyMetadataRowMapper);
+		return !result.isEmpty() ? result.get(0) : null;
+	}
+
+	@Override
+	public Iterator<AssertingPartyMetadata> iterator() {
+		List<AssertingPartyMetadata> result = this.jdbcOperations.query(LOAD_ALL_SQL,
+				this.assertingPartyMetadataRowMapper);
+		return result.iterator();
+	}
+
+	/**
+	 * The default {@link RowMapper} that maps the current row in
+	 * {@code java.sql.ResultSet} to {@link AssertingPartyMetadata}.
+	 */
+	private static final class AssertingPartyMetadataRowMapper implements RowMapper<AssertingPartyMetadata> {
+
+		private final Log logger = LogFactory.getLog(AssertingPartyMetadataRowMapper.class);
+
+		private final Deserializer<Object> deserializer = new DefaultDeserializer();
+
+		private final GetBytes getBytes;
+
+		AssertingPartyMetadataRowMapper(GetBytes getBytes) {
+			this.getBytes = getBytes;
+		}
+
+		@Override
+		public AssertingPartyMetadata mapRow(ResultSet rs, int rowNum) throws SQLException {
+			String entityId = rs.getString("entity_id");
+			String singleSignOnUrl = rs.getString("singlesignon_url");
+			Saml2MessageBinding singleSignOnBinding = Saml2MessageBinding.from(rs.getString("singlesignon_binding"));
+			boolean singleSignOnSignRequest = rs.getBoolean("singlesignon_sign_request");
+			String singleLogoutUrl = rs.getString("singlelogout_url");
+			String singleLogoutResponseUrl = rs.getString("singlelogout_response_url");
+			Saml2MessageBinding singleLogoutBinding = Saml2MessageBinding.from(rs.getString("singlelogout_binding"));
+			byte[] signingAlgorithmsBytes = this.getBytes.getBytes(rs, "signing_algorithms");
+			byte[] verificationCredentialsBytes = this.getBytes.getBytes(rs, "verification_credentials");
+			byte[] encryptionCredentialsBytes = this.getBytes.getBytes(rs, "encryption_credentials");
+
+			AssertingPartyMetadata.Builder<?> builder = new AssertingPartyDetails.Builder();
+			try {
+				if (signingAlgorithmsBytes != null) {
+					List<String> signingAlgorithms = (List<String>) this.deserializer
+						.deserializeFromByteArray(signingAlgorithmsBytes);
+					builder.signingAlgorithms((algorithms) -> algorithms.addAll(signingAlgorithms));
+				}
+				if (verificationCredentialsBytes != null) {
+					Collection<Saml2X509Credential> verificationCredentials = (Collection<Saml2X509Credential>) this.deserializer
+						.deserializeFromByteArray(verificationCredentialsBytes);
+					builder.verificationX509Credentials((credentials) -> credentials.addAll(verificationCredentials));
+				}
+				if (encryptionCredentialsBytes != null) {
+					Collection<Saml2X509Credential> encryptionCredentials = (Collection<Saml2X509Credential>) this.deserializer
+						.deserializeFromByteArray(encryptionCredentialsBytes);
+					builder.encryptionX509Credentials((credentials) -> credentials.addAll(encryptionCredentials));
+				}
+			}
+			catch (Exception ex) {
+				this.logger.debug(LogMessage.format("Parsing serialized credentials for entity %s failed", entityId),
+						ex);
+				return null;
+			}
+
+			builder.entityId(entityId)
+				.wantAuthnRequestsSigned(singleSignOnSignRequest)
+				.singleSignOnServiceLocation(singleSignOnUrl)
+				.singleSignOnServiceBinding(singleSignOnBinding)
+				.singleLogoutServiceLocation(singleLogoutUrl)
+				.singleLogoutServiceBinding(singleLogoutBinding)
+				.singleLogoutServiceResponseLocation(singleLogoutResponseUrl);
+			return builder.build();
+		}
+
+	}
+
+	private interface GetBytes {
+
+		byte[] getBytes(ResultSet rs, String columnName) throws SQLException;
+
+	}
+
+}

+ 14 - 0
saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema-postgres.sql

@@ -0,0 +1,14 @@
+CREATE TABLE saml2_asserting_party_metadata
+(
+    entity_id                 VARCHAR(1000) NOT NULL,
+    singlesignon_url          VARCHAR(1000) NOT NULL,
+    singlesignon_binding      VARCHAR(100),
+    singlesignon_sign_request boolean,
+    signing_algorithms        BYTEA,
+    verification_credentials  BYTEA         NOT NULL,
+    encryption_credentials    BYTEA,
+    singlelogout_url          VARCHAR(1000),
+    singlelogout_response_url VARCHAR(1000),
+    singlelogout_binding      VARCHAR(100),
+    PRIMARY KEY (entity_id)
+);

+ 14 - 0
saml2/saml2-service-provider/src/main/resources/org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql

@@ -0,0 +1,14 @@
+CREATE TABLE saml2_asserting_party_metadata
+(
+    entity_id                 VARCHAR(1000) NOT NULL,
+    singlesignon_url          VARCHAR(1000) NOT NULL,
+    singlesignon_binding      VARCHAR(100),
+    singlesignon_sign_request boolean,
+    signing_algorithms        blob,
+    verification_credentials  blob          NOT NULL,
+    encryption_credentials    blob,
+    singlelogout_url          VARCHAR(1000),
+    singlelogout_response_url VARCHAR(1000),
+    singlelogout_binding      VARCHAR(100),
+    PRIMARY KEY (entity_id)
+);

+ 177 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/registration/JdbcAssertingPartyMetadataRepositoryTests.java

@@ -0,0 +1,177 @@
+package org.springframework.security.saml2.provider.service.registration;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.security.cert.CertificateFactory;
+import java.security.cert.X509Certificate;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.core.io.ClassPathResource;
+import org.springframework.core.serializer.DefaultSerializer;
+import org.springframework.core.serializer.Serializer;
+import org.springframework.jdbc.core.JdbcOperations;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabase;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseBuilder;
+import org.springframework.jdbc.datasource.embedded.EmbeddedDatabaseType;
+import org.springframework.security.saml2.core.Saml2X509Credential;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+
+/**
+ * Tests for {@link JdbcAssertingPartyMetadataRepository}
+ */
+class JdbcAssertingPartyMetadataRepositoryTests {
+
+	private static final String SCHEMA_SQL_RESOURCE = "org/springframework/security/saml2/saml2-asserting-party-metadata-schema.sql";
+
+	private static final String SAVE_SQL = "INSERT INTO saml2_asserting_party_metadata ("
+			+ JdbcAssertingPartyMetadataRepository.COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)";
+
+	private static final String ENTITY_ID = "https://localhost/simplesaml/saml2/idp/metadata.php";
+
+	private static final String SINGLE_SIGNON_URL = "https://localhost/SSO";
+
+	private static final String SINGLE_SIGNON_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
+
+	private static final boolean SINGLE_SIGNON_SIGN_REQUEST = false;
+
+	private static final String SINGLE_LOGOUT_URL = "https://localhost/SLO";
+
+	private static final String SINGLE_LOGOUT_RESPONSE_URL = "https://localhost/SLO/response";
+
+	private static final String SINGLE_LOGOUT_BINDING = Saml2MessageBinding.REDIRECT.getUrn();
+
+	private static final List<String> SIGNING_ALGORITHMS = List.of("http://www.w3.org/2001/04/xmldsig-more#rsa-sha512");
+
+	private X509Certificate certificate;
+
+	private EmbeddedDatabase db;
+
+	private JdbcAssertingPartyMetadataRepository repository;
+
+	private JdbcOperations jdbcOperations;
+
+	private final Serializer<Object> serializer = new DefaultSerializer();
+
+	@BeforeEach
+	public void setUp() throws Exception {
+		this.db = createDb();
+		this.jdbcOperations = new JdbcTemplate(this.db);
+		this.repository = new JdbcAssertingPartyMetadataRepository(this.jdbcOperations);
+		this.certificate = loadCertificate("rsa.crt");
+	}
+
+	@AfterEach
+	public void tearDown() {
+		this.db.shutdown();
+	}
+
+	@Test
+	void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new JdbcAssertingPartyMetadataRepository(null))
+				.withMessage("jdbcOperations cannot be null");
+		// @formatter:on
+	}
+
+	@Test
+	void findByEntityIdWhenEntityIdIsNullThenThrowIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.repository.findByEntityId(null))
+				.withMessage("entityId cannot be empty");
+		// @formatter:on
+	}
+
+	@Test
+	void findByEntityId() throws IOException {
+		this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
+				SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
+				this.serializer.serializeToByteArray(asCredentials(this.certificate)),
+				this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
+				SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
+
+		AssertingPartyMetadata found = this.repository.findByEntityId(ENTITY_ID);
+
+		assertThat(found).isNotNull();
+		assertThat(found.getEntityId()).isEqualTo(ENTITY_ID);
+		assertThat(found.getSingleSignOnServiceLocation()).isEqualTo(SINGLE_SIGNON_URL);
+		assertThat(found.getSingleSignOnServiceBinding().getUrn()).isEqualTo(SINGLE_SIGNON_BINDING);
+		assertThat(found.getWantAuthnRequestsSigned()).isEqualTo(SINGLE_SIGNON_SIGN_REQUEST);
+		assertThat(found.getSingleLogoutServiceLocation()).isEqualTo(SINGLE_LOGOUT_URL);
+		assertThat(found.getSingleLogoutServiceResponseLocation()).isEqualTo(SINGLE_LOGOUT_RESPONSE_URL);
+		assertThat(found.getSingleLogoutServiceBinding().getUrn()).isEqualTo(SINGLE_LOGOUT_BINDING);
+		assertThat(found.getSigningAlgorithms()).contains(SIGNING_ALGORITHMS.get(0));
+		assertThat(found.getVerificationX509Credentials()).hasSize(1);
+		assertThat(found.getEncryptionX509Credentials()).hasSize(1);
+	}
+
+	@Test
+	void findByEntityIdWhenNotExists() {
+		AssertingPartyMetadata found = this.repository.findByEntityId("non-existent-entity-id");
+		assertThat(found).isNull();
+	}
+
+	@Test
+	void iterator() throws IOException {
+		this.jdbcOperations.update(SAVE_SQL, ENTITY_ID, SINGLE_SIGNON_URL, SINGLE_SIGNON_BINDING,
+				SINGLE_SIGNON_SIGN_REQUEST, this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
+				this.serializer.serializeToByteArray(asCredentials(this.certificate)),
+				this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
+				SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
+
+		this.jdbcOperations.update(SAVE_SQL, "https://localhost/simplesaml2/saml2/idp/metadata.php", SINGLE_SIGNON_URL,
+				SINGLE_SIGNON_BINDING, SINGLE_SIGNON_SIGN_REQUEST,
+				this.serializer.serializeToByteArray(SIGNING_ALGORITHMS),
+				this.serializer.serializeToByteArray(asCredentials(this.certificate)),
+				this.serializer.serializeToByteArray(asCredentials(this.certificate)), SINGLE_LOGOUT_URL,
+				SINGLE_LOGOUT_RESPONSE_URL, SINGLE_LOGOUT_BINDING);
+
+		Iterator<AssertingPartyMetadata> iterator = this.repository.iterator();
+		AssertingPartyMetadata first = iterator.next();
+		assertThat(first).isNotNull();
+		AssertingPartyMetadata second = iterator.next();
+		assertThat(second).isNotNull();
+		assertThat(iterator.hasNext()).isFalse();
+	}
+
+	private static EmbeddedDatabase createDb() {
+		return createDb(SCHEMA_SQL_RESOURCE);
+	}
+
+	private static EmbeddedDatabase createDb(String schema) {
+		// @formatter:off
+		return new EmbeddedDatabaseBuilder()
+				.generateUniqueName(true)
+				.setType(EmbeddedDatabaseType.HSQL)
+				.setScriptEncoding("UTF-8")
+				.addScript(schema)
+				.build();
+		// @formatter:on
+	}
+
+	private X509Certificate loadCertificate(String path) {
+		try (InputStream is = new ClassPathResource(path).getInputStream()) {
+			CertificateFactory factory = CertificateFactory.getInstance("X.509");
+			return (X509Certificate) factory.generateCertificate(is);
+		}
+		catch (Exception ex) {
+			throw new RuntimeException("Error loading certificate from " + path, ex);
+		}
+	}
+
+	private Collection<Saml2X509Credential> asCredentials(X509Certificate certificate) {
+		return List.of(new Saml2X509Credential(certificate, Saml2X509Credential.Saml2X509CredentialType.ENCRYPTION,
+				Saml2X509Credential.Saml2X509CredentialType.VERIFICATION));
+	}
+
+}

+ 23 - 0
saml2/saml2-service-provider/src/test/resources/rsa.crt

@@ -0,0 +1,23 @@
+-----BEGIN CERTIFICATE-----
+MIID1zCCAr+gAwIBAgIUCzQeKBMTO0iHVW3iKmZC41haqCowDQYJKoZIhvcNAQEL
+BQAwezELMAkGA1UEBhMCWFgxEjAQBgNVBAgMCVN0YXRlTmFtZTERMA8GA1UEBwwI
+Q2l0eU5hbWUxFDASBgNVBAoMC0NvbXBhbnlOYW1lMRswGQYDVQQLDBJDb21wYW55
+U2VjdGlvbk5hbWUxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0yMzA5MjAwODI5MDNa
+Fw0zMzA5MTcwODI5MDNaMHsxCzAJBgNVBAYTAlhYMRIwEAYDVQQIDAlTdGF0ZU5h
+bWUxETAPBgNVBAcMCENpdHlOYW1lMRQwEgYDVQQKDAtDb21wYW55TmFtZTEbMBkG
+A1UECwwSQ29tcGFueVNlY3Rpb25OYW1lMRIwEAYDVQQDDAlsb2NhbGhvc3QwggEi
+MA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDUfi4aaCotJZX6OSDjv6fxCCfc
+ihSs91Z/mmN+yc1fsxVSs53SIbqUuo+Wzhv34kp8I/r03P9LWVTkFPbeDxAl75Oa
+PGggxK55US0Zfy9Hj1BwWIKV3330N61emID1GDEtFKL4yJbJdreQXnIXTBL2o76V
+nuV/tYozyZnb07IQ1WhUm5WDxgzM0yFudMynTczCBeZHfvharDtB8PFFhCZXW2/9
+TZVVfW4oOML8EAX3hvnvYBlFl/foxXekZSwq/odOkmWCZavT2+0sburHUlOnPGUh
+Qj4tHwpMRczp7VX4ptV1D2UrxsK/2B+s9FK2QSLKQ9JzAYJ6WxQjHcvET9jvAgMB
+AAGjUzBRMB0GA1UdDgQWBBQjDr/1E/01pfLPD8uWF7gbaYL0TTAfBgNVHSMEGDAW
+gBQjDr/1E/01pfLPD8uWF7gbaYL0TTAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3
+DQEBCwUAA4IBAQAGjUuec0+0XNMCRDKZslbImdCAVsKsEWk6NpnUViDFAxL+KQuC
+NW131UeHb9SCzMqRwrY4QI3nAwJQCmilL/hFM3ss4acn3WHu1yci/iKPUKeL1ec5
+kCFUmqX1NpTiVaytZ/9TKEr69SMVqNfQiuW5U1bIIYTqK8xo46WpM6YNNHO3eJK6
+NH0MW79Wx5ryi4i4C6afqYbVbx7tqcmy8CFeNxgZ0bFQ87SiwYXIj77b6sVYbu32
+doykBQgSHLcagWASPQ73m73CWUgo+7+EqSKIQqORbgmTLPmOUh99gFIx7jmjTyHm
+NBszx1ZVWuIv3mWmp626Kncyc+LLM9tvgymx
+-----END CERTIFICATE-----