Browse Source

Add a new Saml2MetadataFilter constructor for RelyingPartyRegistrationRepository

Closes gh-11815
Mitja Kotnik 2 years ago
parent
commit
70249e536a

+ 6 - 0
saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilter.java

@@ -29,6 +29,7 @@ import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
 import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
+import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -62,6 +63,11 @@ public final class Saml2MetadataFilter extends OncePerRequestFilter {
 		this.saml2MetadataResolver = saml2MetadataResolver;
 	}
 
+	public Saml2MetadataFilter(RelyingPartyRegistrationRepository relyingPartyRegistrationRepository,
+			Saml2MetadataResolver saml2MetadataResolver) {
+		this(new DefaultRelyingPartyRegistrationResolver(relyingPartyRegistrationRepository), saml2MetadataResolver);
+	}
+
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
 			throws ServletException, IOException {

+ 14 - 3
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/web/Saml2MetadataFilterTests.java

@@ -64,9 +64,7 @@ public class Saml2MetadataFilterTests {
 	public void setup() {
 		this.repository = mock(RelyingPartyRegistrationRepository.class);
 		this.resolver = mock(Saml2MetadataResolver.class);
-		RelyingPartyRegistrationResolver relyingPartyRegistrationResolver = new DefaultRelyingPartyRegistrationResolver(
-				this.repository);
-		this.filter = new Saml2MetadataFilter(relyingPartyRegistrationResolver, this.resolver);
+		this.filter = new Saml2MetadataFilter(this.repository, this.resolver);
 		this.request = new MockHttpServletRequest();
 		this.response = new MockHttpServletResponse();
 		this.chain = mock(FilterChain.class);
@@ -152,6 +150,19 @@ public class Saml2MetadataFilterTests {
 		verify(this.repository).findByRegistrationId("registration-id");
 	}
 
+	@Test
+	public void doFilterWhenPathStartsWithOneThenServesMetadata() throws Exception {
+		RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
+		given(this.repository.findByRegistrationId("one")).willReturn(registration);
+		given(this.resolver.resolve(any())).willReturn("metadata");
+		this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("one"),
+				this.resolver);
+		this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
+		this.request.setPathInfo("/metadata");
+		this.filter.doFilter(this.request, this.response, new MockFilterChain());
+		verify(this.repository).findByRegistrationId("one");
+	}
+
 	// gh-12026
 	@Test
 	public void doFilterWhenCharacterEncodingThenEncodeSpecialCharactersCorrectly() throws Exception {