|
@@ -20,95 +20,125 @@ import org.junit.Before;
|
|
import org.junit.Test;
|
|
import org.junit.Test;
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
|
+import org.springframework.security.saml2.provider.service.metadata.Saml2MetadataResolver;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
-import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
|
|
|
|
|
|
+import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
|
|
|
|
import javax.servlet.FilterChain;
|
|
import javax.servlet.FilterChain;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
|
+import static org.assertj.core.api.Assertions.assertThatCode;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
import static org.mockito.Mockito.when;
|
|
import static org.mockito.Mockito.when;
|
|
|
|
+import static org.springframework.security.saml2.core.TestSaml2X509Credentials.relyingPartyVerifyingCredential;
|
|
|
|
+import static org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations.noCredentials;
|
|
|
|
|
|
|
|
+/**
|
|
|
|
+ * Tests for {@link Saml2MetadataFilter}
|
|
|
|
+ */
|
|
public class Saml2MetadataFilterTest {
|
|
public class Saml2MetadataFilterTest {
|
|
|
|
|
|
RelyingPartyRegistrationRepository repository;
|
|
RelyingPartyRegistrationRepository repository;
|
|
- Saml2MetadataResolver saml2MetadataResolver;
|
|
|
|
|
|
+ Saml2MetadataResolver resolver;
|
|
Saml2MetadataFilter filter;
|
|
Saml2MetadataFilter filter;
|
|
MockHttpServletRequest request;
|
|
MockHttpServletRequest request;
|
|
MockHttpServletResponse response;
|
|
MockHttpServletResponse response;
|
|
- FilterChain filterChain;
|
|
|
|
|
|
+ FilterChain chain;
|
|
|
|
|
|
@Before
|
|
@Before
|
|
public void setup() {
|
|
public void setup() {
|
|
- repository = mock(RelyingPartyRegistrationRepository.class);
|
|
|
|
- saml2MetadataResolver = mock(Saml2MetadataResolver.class);
|
|
|
|
- filter = new Saml2MetadataFilter(repository, saml2MetadataResolver);
|
|
|
|
- request = new MockHttpServletRequest();
|
|
|
|
- response = new MockHttpServletResponse();
|
|
|
|
- filterChain = mock(FilterChain.class);
|
|
|
|
|
|
+ this.repository = mock(RelyingPartyRegistrationRepository.class);
|
|
|
|
+ this.resolver = mock(Saml2MetadataResolver.class);
|
|
|
|
+ this.filter = new Saml2MetadataFilter(
|
|
|
|
+ new DefaultRelyingPartyRegistrationResolver(this.repository), this.resolver);
|
|
|
|
+ this.request = new MockHttpServletRequest();
|
|
|
|
+ this.response = new MockHttpServletResponse();
|
|
|
|
+ this.chain = mock(FilterChain.class);
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
- public void shouldReturnValueWhenMatcherSucceed() throws Exception {
|
|
|
|
|
|
+ public void doFilterWhenMatcherSucceedsThenResolverInvoked() throws Exception {
|
|
// given
|
|
// given
|
|
- request.setPathInfo("/saml2/service-provider-metadata/registration-id");
|
|
|
|
|
|
+ this.request.setPathInfo("/saml2/service-provider-metadata/registration-id");
|
|
|
|
|
|
// when
|
|
// when
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
+ this.filter.doFilter(this.request, this.response, this.chain);
|
|
|
|
|
|
// then
|
|
// then
|
|
- verifyNoInteractions(filterChain);
|
|
|
|
|
|
+ verifyNoInteractions(this.chain);
|
|
|
|
+ verify(this.repository).findByRegistrationId("registration-id");
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
- public void shouldProcessFilterChainIfMatcherFails() throws Exception {
|
|
|
|
|
|
+ public void doFilterWhenMatcherFailsThenProcessesFilterChain() throws Exception {
|
|
// given
|
|
// given
|
|
- request.setPathInfo("/saml2/authenticate/registration-id");
|
|
|
|
|
|
+ this.request.setPathInfo("/saml2/authenticate/registration-id");
|
|
|
|
|
|
// when
|
|
// when
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
+ this.filter.doFilter(this.request, this.response, this.chain);
|
|
|
|
|
|
// then
|
|
// then
|
|
- verify(filterChain).doFilter(request, response);
|
|
|
|
|
|
+ verify(this.chain).doFilter(this.request, this.response);
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
- public void shouldReturn401IfNoRegistrationIsFound() throws Exception {
|
|
|
|
|
|
+ public void doFilterWhenNoRelyingPartyRegistrationThenUnauthorized() throws Exception {
|
|
// given
|
|
// given
|
|
- request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration");
|
|
|
|
- when(repository.findByRegistrationId("invalidRegistration")).thenReturn(null);
|
|
|
|
|
|
+ this.request.setPathInfo("/saml2/service-provider-metadata/invalidRegistration");
|
|
|
|
+ when(this.repository.findByRegistrationId("invalidRegistration")).thenReturn(null);
|
|
|
|
|
|
// when
|
|
// when
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
+ this.filter.doFilter(this.request, this.response, this.chain);
|
|
|
|
|
|
// then
|
|
// then
|
|
- verifyNoInteractions(filterChain);
|
|
|
|
- assertThat(response.getStatus()).isEqualTo(401);
|
|
|
|
|
|
+ verifyNoInteractions(this.chain);
|
|
|
|
+ assertThat(this.response.getStatus()).isEqualTo(401);
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
- public void shouldInvokeMetadataGenerationIfRegistrationIsFound() throws Exception {
|
|
|
|
|
|
+ public void doFilterWhenRelyingPartyRegistrationFoundThenInvokesMetadataResolver() throws Exception {
|
|
// given
|
|
// given
|
|
- request.setPathInfo("/saml2/service-provider-metadata/validRegistration");
|
|
|
|
- RelyingPartyRegistration validRegistration = TestRelyingPartyRegistrations.relyingPartyRegistration().build();
|
|
|
|
- when(repository.findByRegistrationId("validRegistration")).thenReturn(validRegistration);
|
|
|
|
|
|
+ this.request.setPathInfo("/saml2/service-provider-metadata/validRegistration");
|
|
|
|
+ RelyingPartyRegistration validRegistration = noCredentials()
|
|
|
|
+ .assertingPartyDetails(party -> party
|
|
|
|
+ .verificationX509Credentials(c -> c.add(relyingPartyVerifyingCredential())))
|
|
|
|
+ .build();
|
|
|
|
|
|
String generatedMetadata = "<xml>test</xml>";
|
|
String generatedMetadata = "<xml>test</xml>";
|
|
- when(saml2MetadataResolver.resolveMetadata(request, validRegistration)).thenReturn(generatedMetadata);
|
|
|
|
|
|
+ when(this.resolver.resolve(validRegistration)).thenReturn(generatedMetadata);
|
|
|
|
|
|
- filter = new Saml2MetadataFilter(repository, saml2MetadataResolver);
|
|
|
|
|
|
+ this.filter = new Saml2MetadataFilter(request -> validRegistration, this.resolver);
|
|
|
|
|
|
// when
|
|
// when
|
|
- filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
+ this.filter.doFilter(this.request, this.response, this.chain);
|
|
|
|
|
|
// then
|
|
// then
|
|
- verifyNoInteractions(filterChain);
|
|
|
|
- assertThat(response.getStatus()).isEqualTo(200);
|
|
|
|
- assertThat(response.getContentAsString()).isEqualTo(generatedMetadata);
|
|
|
|
- verify(saml2MetadataResolver).resolveMetadata(request, validRegistration);
|
|
|
|
|
|
+ verifyNoInteractions(this.chain);
|
|
|
|
+ assertThat(this.response.getStatus()).isEqualTo(200);
|
|
|
|
+ assertThat(this.response.getContentAsString()).isEqualTo(generatedMetadata);
|
|
|
|
+ verify(this.resolver).resolve(validRegistration);
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ public void doFilterWhenCustomRequestMatcherThenUses() throws Exception {
|
|
|
|
+ // given
|
|
|
|
+ this.request.setPathInfo("/path");
|
|
|
|
+ this.filter.setRequestMatcher(new AntPathRequestMatcher("/path"));
|
|
|
|
+
|
|
|
|
+ // when
|
|
|
|
+ this.filter.doFilter(this.request, this.response, this.chain);
|
|
|
|
+
|
|
|
|
+ // then
|
|
|
|
+ verifyNoInteractions(this.chain);
|
|
|
|
+ verify(this.repository).findByRegistrationId("path");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void setRequestMatcherWhenNullThenIllegalArgument() {
|
|
|
|
+ assertThatCode(() -> this.filter.setRequestMatcher(null))
|
|
|
|
+ .isInstanceOf(IllegalArgumentException.class);
|
|
|
|
+ }
|
|
}
|
|
}
|