Przeglądaj źródła

Fix entity-id ignored in RelyingPartyRegistration XML config

Closes gh-11898
Marcus Da Coregio 2 lat temu
rodzic
commit
1c3ce1e401

+ 37 - 18
config/src/main/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParser.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -208,30 +208,49 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean
 			ParserContext parserContext) {
 		String registrationId = relyingPartyRegistrationElt.getAttribute(ATT_REGISTRATION_ID);
 		String metadataLocation = relyingPartyRegistrationElt.getAttribute(ATT_METADATA_LOCATION);
+		RelyingPartyRegistration.Builder builder;
+		if (StringUtils.hasText(metadataLocation)) {
+			builder = RelyingPartyRegistrations.fromMetadataLocation(metadataLocation).registrationId(registrationId);
+		}
+		else {
+			builder = RelyingPartyRegistration.withRegistrationId(registrationId)
+					.assertingPartyDetails((apBuilder) -> buildAssertingParty(relyingPartyRegistrationElt,
+							assertingParties, apBuilder, parserContext));
+		}
+		addRemainingProperties(relyingPartyRegistrationElt, builder);
+		return builder;
+	}
+
+	private static void addRemainingProperties(Element relyingPartyRegistrationElt,
+			RelyingPartyRegistration.Builder builder) {
+		String entityId = relyingPartyRegistrationElt.getAttribute(ATT_ENTITY_ID);
 		String singleLogoutServiceLocation = relyingPartyRegistrationElt
 				.getAttribute(ATT_SINGLE_LOGOUT_SERVICE_LOCATION);
 		String singleLogoutServiceResponseLocation = relyingPartyRegistrationElt
 				.getAttribute(ATT_SINGLE_LOGOUT_SERVICE_RESPONSE_LOCATION);
 		Saml2MessageBinding singleLogoutServiceBinding = getSingleLogoutServiceBinding(relyingPartyRegistrationElt);
-		if (StringUtils.hasText(metadataLocation)) {
-			return RelyingPartyRegistrations.fromMetadataLocation(metadataLocation).registrationId(registrationId)
-					.singleLogoutServiceLocation(singleLogoutServiceLocation)
-					.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation)
-					.singleLogoutServiceBinding(singleLogoutServiceBinding);
-		}
-		String entityId = relyingPartyRegistrationElt.getAttribute(ATT_ENTITY_ID);
 		String assertionConsumerServiceLocation = relyingPartyRegistrationElt
 				.getAttribute(ATT_ASSERTION_CONSUMER_SERVICE_LOCATION);
 		Saml2MessageBinding assertionConsumerServiceBinding = getAssertionConsumerServiceBinding(
 				relyingPartyRegistrationElt);
-		return RelyingPartyRegistration.withRegistrationId(registrationId).entityId(entityId)
-				.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
-				.assertionConsumerServiceBinding(assertionConsumerServiceBinding)
-				.singleLogoutServiceLocation(singleLogoutServiceLocation)
-				.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation)
-				.singleLogoutServiceBinding(singleLogoutServiceBinding)
-				.assertingPartyDetails((builder) -> buildAssertingParty(relyingPartyRegistrationElt, assertingParties,
-						builder, parserContext));
+		if (StringUtils.hasText(entityId)) {
+			builder.entityId(entityId);
+		}
+		if (StringUtils.hasText(singleLogoutServiceLocation)) {
+			builder.singleLogoutServiceLocation(singleLogoutServiceLocation);
+		}
+		if (StringUtils.hasText(singleLogoutServiceResponseLocation)) {
+			builder.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation);
+		}
+		if (singleLogoutServiceBinding != null) {
+			builder.singleLogoutServiceBinding(singleLogoutServiceBinding);
+		}
+		if (StringUtils.hasText(assertionConsumerServiceLocation)) {
+			builder.assertionConsumerServiceLocation(assertionConsumerServiceLocation);
+		}
+		if (assertionConsumerServiceBinding != null) {
+			builder.assertionConsumerServiceBinding(assertionConsumerServiceBinding);
+		}
 	}
 
 	private static void buildAssertingParty(Element relyingPartyElt, Map<String, Map<String, Object>> assertingParties,
@@ -309,7 +328,7 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean
 		if (StringUtils.hasText(assertionConsumerServiceBinding)) {
 			return Saml2MessageBinding.valueOf(assertionConsumerServiceBinding);
 		}
-		return Saml2MessageBinding.REDIRECT;
+		return null;
 	}
 
 	private static Saml2MessageBinding getSingleLogoutServiceBinding(Element relyingPartyRegistrationElt) {
@@ -317,7 +336,7 @@ public final class RelyingPartyRegistrationsBeanDefinitionParser implements Bean
 		if (StringUtils.hasText(singleLogoutServiceBinding)) {
 			return Saml2MessageBinding.valueOf(singleLogoutServiceBinding);
 		}
-		return Saml2MessageBinding.POST;
+		return null;
 	}
 
 	private static Saml2X509Credential getSaml2VerificationCredential(String certificateLocation) {

+ 57 - 1
config/src/test/java/org/springframework/security/config/saml2/RelyingPartyRegistrationsBeanDefinitionParserTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -62,6 +62,27 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests {
 			"</b:beans>\n";
 	// @formatter:on
 
+	// @formatter:off
+	private static final String METADATA_LOCATION_OVERRIDE_PROPERTIES_XML_CONFIG = "<b:beans xmlns:b=\"http://www.springframework.org/schema/beans\"\n" +
+			"         xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" +
+			"         xmlns=\"http://www.springframework.org/schema/security\"\n" +
+			"         xsi:schemaLocation=\"\n" +
+			"\t\t\thttp://www.springframework.org/schema/security\n" +
+			"\t\t\thttps://www.springframework.org/schema/security/spring-security.xsd\n" +
+			"\t\t\thttp://www.springframework.org/schema/beans\n" +
+			"\t\t\thttps://www.springframework.org/schema/beans/spring-beans.xsd\">\n" +
+			"  \n" +
+			"  <relying-party-registrations>\n" +
+			"    <relying-party-registration registration-id=\"one\"\n" +
+			"                                entity-id=\"https://rp.example.org\"\n" +
+			"                                metadata-location=\"${metadata-location}\"\n" +
+			"                                assertion-consumer-service-location=\"https://rp.example.org/location\"\n" +
+			"                                assertion-consumer-service-binding=\"REDIRECT\"/>"  +
+			"  </relying-party-registrations>\n" +
+			"\n" +
+			"</b:beans>\n";
+	// @formatter:on
+
 	// @formatter:off
 	private static final String METADATA_RESPONSE = "<?xml version=\"1.0\"?>\n" +
 			"<md:EntityDescriptor xmlns:md=\"urn:oasis:names:tc:SAML:2.0:metadata\" xmlns:ds=\"http://www.w3.org/2000/09/xmldsig#\" entityID=\"https://simplesaml-for-spring-saml.apps.pcfone.io/saml2/idp/metadata.php\" ID=\"_e793a707d3e1a9ee6cbec7454fdad2c7cd793dd3703179a527b9620a6e9682af\"><ds:Signature>\n" +
@@ -143,6 +164,41 @@ public class RelyingPartyRegistrationsBeanDefinitionParserTests {
 				.containsExactly("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256");
 	}
 
+	@Test
+	public void parseWhenMetadataLocationConfiguredAndRegistrationHasPropertiesThenDoNotOverrideSpecifiedProperties()
+			throws Exception {
+		this.server = new MockWebServer();
+		this.server.start();
+		String serverUrl = this.server.url("/").toString();
+		this.server.enqueue(xmlResponse(METADATA_RESPONSE));
+		String metadataConfig = METADATA_LOCATION_OVERRIDE_PROPERTIES_XML_CONFIG.replace("${metadata-location}",
+				serverUrl);
+		this.spring.context(metadataConfig).autowire();
+		assertThat(this.relyingPartyRegistrationRepository)
+				.isInstanceOf(InMemoryRelyingPartyRegistrationRepository.class);
+		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
+				.findByRegistrationId("one");
+		RelyingPartyRegistration.AssertingPartyDetails assertingPartyDetails = relyingPartyRegistration
+				.getAssertingPartyDetails();
+		assertThat(relyingPartyRegistration).isNotNull();
+		assertThat(relyingPartyRegistration.getRegistrationId()).isEqualTo("one");
+		assertThat(relyingPartyRegistration.getEntityId()).isEqualTo("https://rp.example.org");
+		assertThat(relyingPartyRegistration.getAssertionConsumerServiceLocation())
+				.isEqualTo("https://rp.example.org/location");
+		assertThat(relyingPartyRegistration.getAssertionConsumerServiceBinding())
+				.isEqualTo(Saml2MessageBinding.REDIRECT);
+		assertThat(assertingPartyDetails.getEntityId())
+				.isEqualTo("https://simplesaml-for-spring-saml.apps.pcfone.io/saml2/idp/metadata.php");
+		assertThat(assertingPartyDetails.getWantAuthnRequestsSigned()).isFalse();
+		assertThat(assertingPartyDetails.getVerificationX509Credentials()).hasSize(1);
+		assertThat(assertingPartyDetails.getEncryptionX509Credentials()).hasSize(1);
+		assertThat(assertingPartyDetails.getSingleSignOnServiceLocation())
+				.isEqualTo("https://simplesaml-for-spring-saml.apps.pcfone.io/saml2/idp/SSOService.php");
+		assertThat(assertingPartyDetails.getSingleSignOnServiceBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
+		assertThat(assertingPartyDetails.getSigningAlgorithms())
+				.containsExactly("http://www.w3.org/2001/04/xmldsig-more#rsa-sha256");
+	}
+
 	@Test
 	public void parseWhenSingleRelyingPartyRegistrationThenAvailableInRepository() {
 		this.spring.configLocations(xml("SingleRegistration")).autowire();