소스 검색

Merge branch '6.1.x'

Closes gh-13701
Josh Cummings 2 년 전
부모
커밋
3540dee259

+ 10 - 14
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolver.java

@@ -19,8 +19,6 @@ package org.springframework.security.saml2.provider.service.metadata;
 import java.io.UnsupportedEncodingException;
 import java.net.URLEncoder;
 import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.Collection;
 import java.util.Collections;
 import java.util.LinkedHashMap;
 import java.util.Map;
@@ -126,21 +124,19 @@ public final class RequestMatcherMetadataResponseResolver implements Saml2Metada
 			Iterable<RelyingPartyRegistration> registrations) {
 		Map<String, RelyingPartyRegistration> results = new LinkedHashMap<>();
 		for (RelyingPartyRegistration registration : registrations) {
-			results.put(registration.getEntityId(), registration);
-		}
-		Collection<RelyingPartyRegistration> resolved = new ArrayList<>();
-		for (RelyingPartyRegistration registration : results.values()) {
 			UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
 			String entityId = uriResolver.resolve(registration.getEntityId());
-			String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation());
-			String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
-			String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
-			resolved.add(registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation)
-					.singleLogoutServiceLocation(sloLocation).singleLogoutServiceResponseLocation(sloResponseLocation)
-					.build());
+			results.computeIfAbsent(entityId, (e) -> {
+				String ssoLocation = uriResolver.resolve(registration.getAssertionConsumerServiceLocation());
+				String sloLocation = uriResolver.resolve(registration.getSingleLogoutServiceLocation());
+				String sloResponseLocation = uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation());
+				return registration.mutate().entityId(entityId).assertionConsumerServiceLocation(ssoLocation)
+						.singleLogoutServiceLocation(sloLocation)
+						.singleLogoutServiceResponseLocation(sloResponseLocation).build();
+			});
 		}
-		String metadata = this.metadata.resolve(resolved);
-		String value = (resolved.size() == 1) ? resolved.iterator().next().getRegistrationId()
+		String metadata = this.metadata.resolve(results.values());
+		String value = (results.size() == 1) ? results.values().iterator().next().getRegistrationId()
 				: UUID.randomUUID().toString();
 		String fileName = this.filename.replace("{registrationId}", value);
 		try {

+ 20 - 2
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/metadata/RequestMatcherMetadataResponseResolverTests.java

@@ -20,6 +20,7 @@ import java.util.Collection;
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
@@ -101,6 +102,23 @@ public final class RequestMatcherMetadataResponseResolverTests {
 		assertThat(resolver.resolve(new MockHttpServletRequest())).isNull();
 	}
 
+	// gh-13700
+	@Test
+	void resolveWhenNoRegistrationIdThenResolvesEntityIds() {
+		RelyingPartyRegistration one = withEntityId("one");
+		RelyingPartyRegistration two = withEntityId("two");
+		RelyingPartyRegistrationRepository registrations = new InMemoryRelyingPartyRegistrationRepository(one, two);
+		RequestMatcherMetadataResponseResolver resolver = new RequestMatcherMetadataResponseResolver(registrations,
+				this.metadataFactory);
+		given(this.metadataFactory.resolve(any(Collection.class))).willReturn("metadata");
+		resolver.resolve(get("/saml2/metadata"));
+		ArgumentCaptor<Collection<RelyingPartyRegistration>> captor = ArgumentCaptor.forClass(Collection.class);
+		verify(this.metadataFactory).resolve(captor.capture());
+		Collection<RelyingPartyRegistration> resolved = captor.getValue();
+		assertThat(resolved).hasSize(2);
+		assertThat(resolved.iterator().next().getEntityId()).isEqualTo("one");
+	}
+
 	private MockHttpServletRequest get(String uri) {
 		MockHttpServletRequest request = new MockHttpServletRequest("GET", uri);
 		request.setServletPath(uri);
@@ -108,8 +126,8 @@ public final class RequestMatcherMetadataResponseResolverTests {
 	}
 
 	private RelyingPartyRegistration withEntityId(String entityId) {
-		return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId).entityId(entityId)
-				.build();
+		return TestRelyingPartyRegistrations.relyingPartyRegistration().registrationId(entityId)
+				.entityId("{registrationId}").build();
 	}
 
 }