Browse Source

Saml2AuthenticationRequestFilter Tests

To confirm behavior still works as expected after making related changes.

Issue gh-8359
Josh Cummings 5 years ago
parent
commit
887cb99926

+ 67 - 4
saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/servlet/filter/Saml2WebSsoAuthenticationRequestFilterTests.java

@@ -16,22 +16,29 @@
 
 package org.springframework.security.saml2.provider.service.servlet.filter;
 
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import javax.servlet.ServletException;
+
 import org.junit.Before;
 import org.junit.Test;
+
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.saml2.provider.service.authentication.Saml2AuthenticationRequestFactory;
+import org.springframework.security.saml2.provider.service.authentication.Saml2PostAuthenticationRequest;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
 import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistrationRepository;
 import org.springframework.web.util.HtmlUtils;
 import org.springframework.web.util.UriUtils;
 
-import javax.servlet.ServletException;
-import java.io.IOException;
-import java.nio.charset.StandardCharsets;
-
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatCode;
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding.POST;
 import static org.springframework.security.saml2.provider.service.servlet.filter.TestSaml2SigningCredentials.signingCredential;
@@ -41,6 +48,7 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 	private static final String IDP_SSO_URL = "https://sso-url.example.com/IDP/SSO";
 	private Saml2WebSsoAuthenticationRequestFilter filter;
 	private RelyingPartyRegistrationRepository repository = mock(RelyingPartyRegistrationRepository.class);
+	private Saml2AuthenticationRequestFactory factory = mock(Saml2AuthenticationRequestFactory.class);
 	private MockHttpServletRequest request;
 	private MockHttpServletResponse response;
 	private MockFilterChain filterChain;
@@ -147,4 +155,59 @@ public class Saml2WebSsoAuthenticationRequestFilterTests {
 				.contains("value=\""+relayStateEncoded+"\"");
 	}
 
+	@Test
+	public void doFilterWhenSetAuthenticationRequestFactoryThenUses() throws Exception {
+		RelyingPartyRegistration relyingParty = this.rpBuilder
+				.providerDetails(c -> c.binding(POST))
+				.build();
+		Saml2PostAuthenticationRequest authenticationRequest = mock(Saml2PostAuthenticationRequest.class);
+		when(authenticationRequest.getAuthenticationRequestUri()).thenReturn("uri");
+		when(authenticationRequest.getRelayState()).thenReturn("relay");
+		when(authenticationRequest.getSamlRequest()).thenReturn("saml");
+		when(this.repository.findByRegistrationId("registration-id")).thenReturn(relyingParty);
+		when(this.factory.createPostAuthenticationRequest(any()))
+				.thenReturn(authenticationRequest);
+
+		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
+				(this.repository);
+		filter.setAuthenticationRequestFactory(this.factory);
+		filter.doFilterInternal(this.request, this.response, this.filterChain);
+		assertThat(this.response.getContentAsString())
+				.contains("<form action=\"uri\" method=\"post\">")
+				.contains("<input type=\"hidden\" name=\"SAMLRequest\" value=\"saml\"")
+				.contains("<input type=\"hidden\" name=\"RelayState\" value=\"relay\"");
+		verify(this.factory).createPostAuthenticationRequest(any());
+	}
+
+	@Test
+	public void setRequestMatcherWhenNullThenException() {
+		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
+				(this.repository);
+		assertThatCode(() -> filter.setRedirectMatcher(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void setAuthenticationRequestFactoryWhenNullThenException() {
+		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter(this.repository);
+		assertThatCode(() -> filter.setAuthenticationRequestFactory(null))
+				.isInstanceOf(IllegalArgumentException.class);
+	}
+
+	@Test
+	public void doFilterWhenRequestMatcherFailsThenSkipsFilter() throws Exception {
+		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
+				(this.repository);
+		filter.setRedirectMatcher(request -> false);
+		filter.doFilter(this.request, this.response, this.filterChain);
+		verifyNoInteractions(this.repository);
+	}
+
+	@Test
+	public void doFilterWhenRelyingPartyRegistrationNotFoundThenUnauthorized() throws Exception {
+		Saml2WebSsoAuthenticationRequestFilter filter = new Saml2WebSsoAuthenticationRequestFilter
+				(this.repository);
+		filter.doFilter(this.request, this.response, this.filterChain);
+		assertThat(this.response.getStatus()).isEqualTo(401);
+	}
 }