|
@@ -16,22 +16,29 @@
|
|
|
|
|
|
package org.springframework.security.saml2.provider.service.servlet.filter;
|
|
|
|
|
|
+import java.io.IOException;
|
|
|
+import java.nio.charset.StandardCharsets;
|
|
|
+import javax.servlet.ServletException;
|
|
|
+
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
+
|
|
|
import org.springframework.mock.web.MockFilterChain;
|
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
|
|
|
+import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
|
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
|
|
|
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
|
|
|
import org.springframework.web.util.HtmlUtils;
|
|
|
import org.springframework.web.util.UriUtils;
|
|
|
|
|
|
-import javax.servlet.ServletException;
|
|
|
-import java.io.IOException;
|
|
|
-import java.nio.charset.StandardCharsets;
|
|
|
-
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
+import static org.assertj.core.api.Assertions.assertThatCode;
|
|
|
+import static org.mockito.ArgumentMatchers.any;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
+import static org.mockito.Mockito.verifyNoInteractions;
|
|
|
import static org.mockito.Mockito.when;
|
|
|
import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
|
|
|
import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential;
|
|
@@ -41,6 +48,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
|
|
|
private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
|
|
|
private Saml2WebSsoAuthenticationRequestFilter filter;
|
|
|
private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
|
|
|
+ private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
|
|
|
private MockHttpServletRequest request;
|
|
|
private MockHttpServletResponse response;
|
|
|
private MockFilterChain filterChain;
|
|
@@ -147,4 +155,59 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
|
|
|
.contains("value=\""+relayStateEncoded+"\"");
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenSetAuthenticationRequestFactoryThenUses() throws Exception {
|
|
|
+ RelyingPartyRegistration relyingParty = this.rpBuilder
|
|
|
+ .providerDetails(c -> c.binding(POST))
|
|
|
+ .build();
|
|
|
+ Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
|
|
|
+ when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
|
|
|
+ when(authenticationRequest.getRelayState()).thenReturn("relay");
|
|
|
+ when(authenticationRequest.getSamlRequest()).thenReturn("saml");
|
|
|
+ when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
|
|
|
+ when(this.factory.createPostAuthenticationRequest(any()))
|
|
|
+ .thenReturn(authenticationRequest);
|
|
|
+
|
|
|
+ Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
|
|
|
+ (this.repository);
|
|
|
+ filter.setAuthenticationRequestFactory(this.factory);
|
|
|
+ filter.doFilterInternal(this.request, this.response, this.filterChain);
|
|
|
+ assertThat(this.response.getContentAsString())
|
|
|
+ .contains("<form action=\"uri\" method=\"post\">")
|
|
|
+ .contains("<input type=\"hidden\" name=\"SAMLRequest\" value=\"saml\"")
|
|
|
+ .contains("<input type=\"hidden\" name=\"RelayState\" value=\"relay\"");
|
|
|
+ verify(this.factory).createPostAuthenticationRequest(any());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void setRequestMatcherWhenNullThenException() {
|
|
|
+ Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
|
|
|
+ (this.repository);
|
|
|
+ assertThatCode(() -> filter.setRedirectMatcher(null))
|
|
|
+ .isInstanceOf(IllegalArgumentException.class);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void setAuthenticationRequestFactoryWhenNullThenException() {
|
|
|
+ Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository);
|
|
|
+ assertThatCode(() -> filter.setAuthenticationRequestFactory(null))
|
|
|
+ .isInstanceOf(IllegalArgumentException.class);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenRequestMatcherFailsThenSkipsFilter() throws Exception {
|
|
|
+ Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
|
|
|
+ (this.repository);
|
|
|
+ filter.setRedirectMatcher(request -> false);
|
|
|
+ filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
+ verifyNoInteractions(this.repository);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void doFilterWhenRelyingPartyRegistrationNotFoundThenUnauthorized() throws Exception {
|
|
|
+ Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
|
|
|
+ (this.repository);
|
|
|
+ filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
+ assertThat(this.response.getStatus()).isEqualTo(401);
|
|
|
+ }
|
|
|
}
|