Selaa lähdekoodia

Add DefaultRelyingPartyRegistrationResolver

Closes gh-8887
Josh Cummings 5 vuotta sitten
vanhempi
commit
015281ff53

+ 135 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolver.java

@@ -0,0 +1,135 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.function.Function;
+import javax.servlet.http.HttpServletRequest;
+
+import org.springframework.core.convert.converter.Converter;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
+import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.util.Assert;
+import org.springframework.util.StringUtils;
+import org.springframework.web.util.UriComponents;
+import org.springframework.web.util.UriComponentsBuilder;
+
+import static org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration.withRelyingPartyRegistration;
+import static org.springframework.security.web.util.UrlUtils.buildFullRequestUrl;
+import static org.springframework.web.util.UriComponentsBuilder.fromHttpUrl;
+
+/**
+ * A {@link Converter} that resolves a {@link RelyingPartyRegistration} by extracting the
+ * registration id from the request, querying a {@link RelyingPartyRegistrationRepository},
+ * and resolving any template values.
+ *
+ * @since 5.4
+ * @author Josh Cummings
+ */
+public final class DefaultRelyingPartyRegistrationResolver
+		implements Converter<HttpServletRequest, RelyingPartyRegistration> {
+
+	private static final char PATH_DELIMITER = '/';
+
+	private final RelyingPartyRegistrationRepository relyingPartyRegistrationRepository;
+	private final Converter<HttpServletRequest, String> registrationIdResolver = new RegistrationIdResolver();
+
+	public DefaultRelyingPartyRegistrationResolver
+			(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository) {
+
+		Assert.notNull(relyingPartyRegistrationRepository, "relyingPartyRegistrationRepository cannot be null");
+		this.relyingPartyRegistrationRepository = relyingPartyRegistrationRepository;
+	}
+
+	@Override
+	public RelyingPartyRegistration convert(HttpServletRequest request) {
+		String registrationId = this.registrationIdResolver.convert(request);
+		if (registrationId == null) {
+			return null;
+		}
+		RelyingPartyRegistration relyingPartyRegistration =
+				this.relyingPartyRegistrationRepository.findByRegistrationId(registrationId);
+		if (relyingPartyRegistration == null) {
+			return null;
+		}
+
+		String applicationUri = getApplicationUri(request);
+		Function<String, String> templateResolver = templateResolver(applicationUri, relyingPartyRegistration);
+		String relyingPartyEntityId = templateResolver.apply(relyingPartyRegistration.getEntityId());
+		String assertionConsumerServiceLocation = templateResolver.apply(
+				relyingPartyRegistration.getAssertionConsumerServiceLocation());
+		return withRelyingPartyRegistration(relyingPartyRegistration)
+				.entityId(relyingPartyEntityId)
+				.assertionConsumerServiceLocation(assertionConsumerServiceLocation)
+				.build();
+	}
+
+	private Function<String, String> templateResolver(String applicationUri, RelyingPartyRegistration relyingParty) {
+		return template -> resolveUrlTemplate(template, applicationUri, relyingParty);
+	}
+
+	private static String resolveUrlTemplate(String template, String baseUrl, RelyingPartyRegistration relyingParty) {
+		String entityId = relyingParty.getAssertingPartyDetails().getEntityId();
+		String registrationId = relyingParty.getRegistrationId();
+		Map<String, String> uriVariables = new HashMap<>();
+		UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(baseUrl)
+				.replaceQuery(null)
+				.fragment(null)
+				.build();
+		String scheme = uriComponents.getScheme();
+		uriVariables.put("baseScheme", scheme == null ? "" : scheme);
+		String host = uriComponents.getHost();
+		uriVariables.put("baseHost", host == null ? "" : host);
+		// following logic is based on HierarchicalUriComponents#toUriString()
+		int port = uriComponents.getPort();
+		uriVariables.put("basePort", port == -1 ? "" : ":" + port);
+		String path = uriComponents.getPath();
+		if (StringUtils.hasLength(path) && path.charAt(0) != PATH_DELIMITER) {
+			path = PATH_DELIMITER + path;
+		}
+		uriVariables.put("basePath", path == null ? "" : path);
+		uriVariables.put("baseUrl", uriComponents.toUriString());
+		uriVariables.put("entityId", StringUtils.hasText(entityId) ? entityId : "");
+		uriVariables.put("registrationId", StringUtils.hasText(registrationId) ? registrationId : "");
+
+		return UriComponentsBuilder.fromUriString(template)
+				.buildAndExpand(uriVariables)
+				.toUriString();
+	}
+
+	private static String getApplicationUri(HttpServletRequest request) {
+		UriComponents uriComponents = fromHttpUrl(buildFullRequestUrl(request))
+				.replacePath(request.getContextPath())
+				.replaceQuery(null)
+				.fragment(null)
+				.build();
+		return uriComponents.toUriString();
+	}
+
+	private static class RegistrationIdResolver implements Converter<HttpServletRequest, String> {
+		private final RequestMatcher requestMatcher = new AntPathRequestMatcher("/**/{registrationId}");
+
+		@Override
+		public String convert(HttpServletRequest request) {
+			RequestMatcher.MatchResult result = this.requestMatcher.matcher(request);
+			return result.getVariables().get("registrationId");
+		}
+	}
+}

+ 74 - 0
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/DefaultRelyingPartyRegistrationResolverTests.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.saml2.provider.service.web;
+
+import org.junit.Test;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.saml2.provider.service.registration.InMemoryRelyingPartyRegistrationRepository;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.relyingPartyRegistration;
+
+/**
+ * Tests for {@link DefaultRelyingPartyRegistrationResolver}
+ */
+public class DefaultRelyingPartyRegistrationResolverTests {
+	private final RelyingPartyRegistration registration = relyingPartyRegistration().build();
+	private final RelyingPartyRegistrationRepository repository =
+			new InMemoryRelyingPartyRegistrationRepository(this.registration);
+	private final DefaultRelyingPartyRegistrationResolver resolver =
+			new DefaultRelyingPartyRegistrationResolver(this.repository);
+
+	@Test
+	public void resolveWhenRequestContainsRegistrationIdThenResolves() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setPathInfo("/some/path/" + this.registration.getRegistrationId());
+		RelyingPartyRegistration registration = this.resolver.convert(request);
+		assertThat(registration).isNotNull();
+		assertThat(registration.getRegistrationId())
+				.isEqualTo(this.registration.getRegistrationId());
+		assertThat(registration.getEntityId())
+				.isEqualTo("http://localhost/saml2/service-provider-metadata/" + this.registration.getRegistrationId());
+		assertThat(registration.getAssertionConsumerServiceLocation())
+				.isEqualTo("http://localhost/login/saml2/sso/" + this.registration.getRegistrationId());
+	}
+
+	@Test
+	public void resolveWhenRequestContainsInvalidRegistrationIdThenNull() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		request.setPathInfo("/some/path/not-" + this.registration.getRegistrationId());
+		RelyingPartyRegistration registration = this.resolver.convert(request);
+		assertThat(registration).isNull();
+	}
+
+	@Test
+	public void resolveWhenRequestIsMissingRegistrationIdThenNull() {
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		RelyingPartyRegistration registration = this.resolver.convert(request);
+		assertThat(registration).isNull();
+	}
+
+	@Test
+	public void constructorWhenNullRelyingPartyRegistrationThenIllegalArgument() {
+		assertThatCode(() -> new DefaultRelyingPartyRegistrationResolver(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+}