Browse Source

Extract Placeholder Resolution

Closes gh-12842
Josh Cummings 2 years ago
parent
commit
37b893a0f5

+ 10 - 61
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java

@@ -16,10 +16,6 @@
 
 package org.springframework.security.saml2.provider.service.web;
 
-import java.util.HashMap;
-import java.util.Map;
-import java.util.function.Function;
-
 import jakarta.servlet.http.HttpServletRequest;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -27,13 +23,10 @@ import org.apache.commons.logging.LogFactory;
 import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
-import org.springframework.security.web.util.UrlUtils;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
-import org.springframework.util.StringUtils;
-import org.springframework.web.util.UriComponents;
-import org.springframework.web.util.UriComponentsBuilder;
 
 /**
  * A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
@@ -48,8 +41,6 @@ public final class DefaultRelyingPartyRegistrationResolver
 
 	private Log logger = LogFactory.getLog(getClass());
 
-	private static final char PATH_DELIMITER = '/';
-
 	private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
 
 	private final RequestMatcher registrationRequestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
@@ -87,61 +78,19 @@ public final class DefaultRelyingPartyRegistrationResolver
 			}
 			return null;
 		}
-		RelyingPartyRegistration relyingPartyRegistration = this.relyingPartyRegistrationRepository
+		RelyingPartyRegistration registration = this.relyingPartyRegistrationRepository
 				.findByRegistrationId(relyingPartyRegistrationId);
-		if (relyingPartyRegistration == null) {
+		if (registration == null) {
 			return null;
 		}
-		String applicationUri = getApplicationUri(request);
-		Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration);
-		String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
-		String assertionConsumerServiceLocation = templateResolver
-				.apply(relyingPartyRegistration.getAssertionConsumerServiceLocation());
-		String singleLogoutServiceLocation = templateResolver
-				.apply(relyingPartyRegistration.getSingleLogoutServiceLocation());
-		String singleLogoutServiceResponseLocation = templateResolver
-				.apply(relyingPartyRegistration.getSingleLogoutServiceResponseLocation());
-		return relyingPartyRegistration.mutate().entityId(relyingPartyEntityId)
-				.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
-				.singleLogoutServiceLocation(singleLogoutServiceLocation)
-				.singleLogoutServiceResponseLocation(singleLogoutServiceResponseLocation).build();
-	}
-
-	private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
-		return (template) -> resolveUrlTemplate(template, applicationUri, relyingParty);
-	}
-
-	private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
-		if (template == null) {
-			return null;
-		}
-		String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
-		String registrationId = relyingParty.getRegistrationId();
-		Map<String, String> uriVariables = new HashMap<>();
-		UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null)
+		UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
+		return registration.mutate().entityId(uriResolver.resolve(registration.getEntityId()))
+				.assertionConsumerServiceLocation(
+						uriResolver.resolve(registration.getAssertionConsumerServiceLocation()))
+				.singleLogoutServiceLocation(uriResolver.resolve(registration.getSingleLogoutServiceLocation()))
+				.singleLogoutServiceResponseLocation(
+						uriResolver.resolve(registration.getSingleLogoutServiceResponseLocation()))
 				.build();
-		String scheme = uriComponents.getScheme();
-		uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
-		String host = uriComponents.getHost();
-		uriVariables.put("baseHost", (host != null) ? host : "");
-		// following logic is based on HierarchicalUriComponents#toUriString()
-		int port = uriComponents.getPort();
-		uriVariables.put("basePort", (port == -1) ? "" : ":" + port);
-		String path = uriComponents.getPath();
-		if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
-			path = PATH_DELIMITER + path;
-		}
-		uriVariables.put("basePath", (path != null) ? path : "");
-		uriVariables.put("baseUrl", uriComponents.toUriString());
-		uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
-		uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
-		return UriComponentsBuilder.fromUriString(template).buildAndExpand(uriVariables).toUriString();
-	}
-
-	private static String getApplicationUri(HttpServletRequest request) {
-		UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
-				.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build();
-		return uriComponents.toUriString();
 	}
 
 }

+ 129 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolvers.java

@@ -0,0 +1,129 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import jakarta.servlet.http.HttpServletRequest;
+
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponents;
+import org.springframework.web.util.UriComponentsBuilder;
+
+/**
+ * A factory for creating placeholder resolvers for {@link RelyingPartyRegistration}
+ * templates. Supports {@code baseUrl}, {@code baseScheme}, {@code baseHost},
+ * {@code basePort}, {@code basePath}, {@code registrationId},
+ * {@code relyingPartyEntityId}, and {@code assertingPartyEntityId}
+ *
+ * @author Josh Cummings
+ * @since 6.1
+ */
+public final class RelyingPartyRegistrationPlaceholderResolvers {
+
+	private static final char PATH_DELIMITER = '/';
+
+	private RelyingPartyRegistrationPlaceholderResolvers() {
+
+	}
+
+	/**
+	 * Create a resolver based on the given {@link HttpServletRequest}. Given the request,
+	 * placeholders {@code baseUrl}, {@code baseScheme}, {@code baseHost},
+	 * {@code basePort}, and {@code basePath} are resolved.
+	 * @param request the HTTP request
+	 * @return a resolver that can resolve {@code baseUrl}, {@code baseScheme},
+	 * {@code baseHost}, {@code basePort}, and {@code basePath} placeholders
+	 */
+	public static UriResolver uriResolver(HttpServletRequest request) {
+		return new UriResolver(uriVariables(request));
+	}
+
+	/**
+	 * Create a resolver based on the given {@link HttpServletRequest}. Given the request,
+	 * placeholders {@code baseUrl}, {@code baseScheme}, {@code baseHost},
+	 * {@code basePort}, {@code basePath}, {@code registrationId},
+	 * {@code assertingPartyEntityId}, and {@code relyingPartyEntityId} are resolved.
+	 * @param request the HTTP request
+	 * @return a resolver that can resolve {@code baseUrl}, {@code baseScheme},
+	 * {@code baseHost}, {@code basePort}, {@code basePath}, {@code registrationId},
+	 * {@code relyingPartyEntityId}, and {@code assertingPartyEntityId} placeholders
+	 */
+	public static UriResolver uriResolver(HttpServletRequest request, RelyingPartyRegistration registration) {
+		String relyingPartyEntityId = registration.getEntityId();
+		String assertingPartyEntityId = registration.getAssertingPartyDetails().getEntityId();
+		String registrationId = registration.getRegistrationId();
+		Map<String, String> uriVariables = uriVariables(request);
+		uriVariables.put("relyingPartyEntityId", StringUtils.hasText(relyingPartyEntityId) ? relyingPartyEntityId : "");
+		uriVariables.put("assertingPartyEntityId",
+				StringUtils.hasText(assertingPartyEntityId) ? assertingPartyEntityId : "");
+		uriVariables.put("entityId", StringUtils.hasText(assertingPartyEntityId) ? assertingPartyEntityId : "");
+		uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
+		return new UriResolver(uriVariables);
+	}
+
+	private static Map<String, String> uriVariables(HttpServletRequest request) {
+		String baseUrl = getApplicationUri(request);
+		Map<String, String> uriVariables = new HashMap<>();
+		UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl).replaceQuery(null).fragment(null)
+				.build();
+		String scheme = uriComponents.getScheme();
+		uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
+		String host = uriComponents.getHost();
+		uriVariables.put("baseHost", (host != null) ? host : "");
+		// following logic is based on HierarchicalUriComponents#toUriString()
+		int port = uriComponents.getPort();
+		uriVariables.put("basePort", (port == -1) ? "" : ":" + port);
+		String path = uriComponents.getPath();
+		if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
+			path = PATH_DELIMITER + path;
+		}
+		uriVariables.put("basePath", (path != null) ? path : "");
+		uriVariables.put("baseUrl", uriComponents.toUriString());
+		return uriVariables;
+	}
+
+	private static String getApplicationUri(HttpServletRequest request) {
+		UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
+				.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build();
+		return uriComponents.toUriString();
+	}
+
+	/**
+	 * A class for resolving {@link RelyingPartyRegistration} URIs
+	 */
+	public static final class UriResolver {
+
+		private final Map<String, String> uriVariables;
+
+		private UriResolver(Map<String, String> uriVariables) {
+			this.uriVariables = uriVariables;
+		}
+
+		public String resolve(String uri) {
+			if (uri == null) {
+				return null;
+			}
+			return UriComponentsBuilder.fromUriString(uri).buildAndExpand(this.uriVariables).toUriString();
+		}
+
+	}
+
+}

+ 56 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/RelyingPartyRegistrationPlaceholderResolversTests.java

@@ -0,0 +1,56 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
+import org.springframework.security.saml2.provider.service.web.RelyingPartyRegistrationPlaceholderResolvers.UriResolver;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
+/**
+ * Tests for {@link RelyingPartyRegistrationPlaceholderResolvers}
+ */
+public class RelyingPartyRegistrationPlaceholderResolversTests {
+
+	@Test
+	void uriResolverGivenRequestCreatesResolver() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request);
+		String resolved = uriResolver.resolve("{baseUrl}/extension");
+		assertThat(resolved).isEqualTo("http://localhost/extension");
+		assertThatExceptionOfType(IllegalArgumentException.class)
+				.isThrownBy(() -> uriResolver.resolve("{baseUrl}/extension/{registrationId}"));
+	}
+
+	@Test
+	void uriResolverGivenRequestAndRegistrationCreatesResolver() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.relyingPartyRegistration()
+				.entityId("http://sp.example.org").build();
+		UriResolver uriResolver = RelyingPartyRegistrationPlaceholderResolvers.uriResolver(request, registration);
+		String resolved = uriResolver.resolve("{baseUrl}/extension/{registrationId}");
+		assertThat(resolved).isEqualTo("http://localhost/extension/simplesamlphp");
+		resolved = uriResolver.resolve("{relyingPartyEntityId}/extension");
+		assertThat(resolved).isEqualTo("http://sp.example.org/extension");
+	}
+
+}