|
@@ -33,6 +33,7 @@ import org.springframework.security.saml2.provider.service.registration.RelyingP
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
|
|
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
|
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
|
|
+import org.springframework.test.util.ReflectionTestUtils;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
|
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
|
|
@@ -137,7 +138,7 @@ public class Saml2MetadataFilterTests {
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
- public void doFilterWhenPathStartsWithRegistrationIdThenServesMetadata() throws Exception {
|
|
|
|
|
|
+ public void doFilterWhenResolverConstructorAndPathStartsWithRegistrationIdThenServesMetadata() throws Exception {
|
|
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
|
|
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
|
|
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
|
|
given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
|
|
given(this.resolver.resolve(any())).willReturn("metadata");
|
|
given(this.resolver.resolve(any())).willReturn("metadata");
|
|
@@ -151,16 +152,17 @@ public class Saml2MetadataFilterTests {
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
- public void doFilterWhenPathStartsWithOneThenServesMetadata() throws Exception {
|
|
|
|
|
|
+ public void doFilterWhenRelyingPartyRegistrationRepositoryConstructorAndPathStartsWithRegistrationIdThenServesMetadata()
|
|
|
|
+ throws Exception {
|
|
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
|
|
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.full().build();
|
|
- given(this.repository.findByRegistrationId("one")).willReturn(registration);
|
|
|
|
|
|
+ given(this.repository.findByRegistrationId("registration-id")).willReturn(registration);
|
|
given(this.resolver.resolve(any())).willReturn("metadata");
|
|
given(this.resolver.resolve(any())).willReturn("metadata");
|
|
- this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("one"),
|
|
|
|
|
|
+ this.filter = new Saml2MetadataFilter((id) -> this.repository.findByRegistrationId("registration-id"),
|
|
this.resolver);
|
|
this.resolver);
|
|
this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
|
|
this.filter.setRequestMatcher(new AntPathRequestMatcher("/metadata"));
|
|
this.request.setPathInfo("/metadata");
|
|
this.request.setPathInfo("/metadata");
|
|
this.filter.doFilter(this.request, this.response, new MockFilterChain());
|
|
this.filter.doFilter(this.request, this.response, new MockFilterChain());
|
|
- verify(this.repository).findByRegistrationId("one");
|
|
|
|
|
|
+ verify(this.repository).findByRegistrationId("registration-id");
|
|
}
|
|
}
|
|
|
|
|
|
// gh-12026
|
|
// gh-12026
|
|
@@ -196,4 +198,14 @@ public class Saml2MetadataFilterTests {
|
|
.withMessage("metadataFilename must contain a {registrationId} match variable");
|
|
.withMessage("metadataFilename must contain a {registrationId} match variable");
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ public void constructorWhenRelyingPartyRegistrationRepositoryThenUses() {
|
|
|
|
+ RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
|
|
|
|
+ this.filter = new Saml2MetadataFilter(repository, this.resolver);
|
|
|
|
+ DefaultRelyingPartyRegistrationResolver relyingPartyRegistrationResolver = (DefaultRelyingPartyRegistrationResolver) ReflectionTestUtils
|
|
|
|
+ .getField(this.filter, "relyingPartyRegistrationResolver");
|
|
|
|
+ relyingPartyRegistrationResolver.resolve(this.request, "one");
|
|
|
|
+ verify(repository).findByRegistrationId("one");
|
|
|
|
+ }
|
|
|
|
+
|
|
}
|
|
}
|