|
@@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
|
|
|
import static org.fest.assertions.Assertions.assertThat;
|
|
|
import static org.mockito.Matchers.any;
|
|
|
import static org.mockito.Matchers.eq;
|
|
|
+import static org.mockito.Mockito.times;
|
|
|
import static org.mockito.Mockito.verify;
|
|
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
|
|
import static org.mockito.Mockito.when;
|
|
@@ -27,8 +28,11 @@ import java.util.Arrays;
|
|
|
|
|
|
import javax.servlet.FilterChain;
|
|
|
import javax.servlet.ServletException;
|
|
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
|
|
|
+import org.fest.assertions.GenericAssert;
|
|
|
+import org.fest.assertions.ObjectAssert;
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
import org.junit.runner.RunWith;
|
|
@@ -59,12 +63,12 @@ public class CsrfFilterTests {
|
|
|
private MockHttpServletResponse response;
|
|
|
private CsrfToken token;
|
|
|
|
|
|
-
|
|
|
private CsrfFilter filter;
|
|
|
|
|
|
@Before
|
|
|
public void setup() {
|
|
|
- token = new CsrfToken("headerName","paramName", "csrfTokenValue");
|
|
|
+ token = new DefaultCsrfToken("headerName", "paramName",
|
|
|
+ "csrfTokenValue");
|
|
|
resetRequestResponse();
|
|
|
filter = new CsrfFilter(tokenRepository);
|
|
|
filter.setRequireCsrfProtectionMatcher(requestMatcher);
|
|
@@ -81,171 +85,221 @@ public class CsrfFilterTests {
|
|
|
new CsrfFilter(null);
|
|
|
}
|
|
|
|
|
|
+ // SEC-2276
|
|
|
+ @Test
|
|
|
+ public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException,
|
|
|
+ IOException {
|
|
|
+ when(requestMatcher.matches(request)).thenReturn(false);
|
|
|
+ when(tokenRepository.generateToken(request)).thenReturn(token);
|
|
|
+
|
|
|
+ filter.doFilter(request, response, filterChain);
|
|
|
+ CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName());
|
|
|
+
|
|
|
+ // no CsrfToken should have been saved yet
|
|
|
+ verify(tokenRepository,times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ verify(filterChain).doFilter(request, response);
|
|
|
+
|
|
|
+ // access the token
|
|
|
+ attrToken.getToken();
|
|
|
+
|
|
|
+ // now the CsrfToken should have been saved
|
|
|
+ verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), any(HttpServletResponse.class));
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
- public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
|
|
+ public void doFilterAccessDeniedNoTokenPresent() throws ServletException,
|
|
|
+ IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
- verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
|
|
+ verify(deniedHandler).handle(eq(request), eq(response),
|
|
|
+ any(InvalidCsrfTokenException.class));
|
|
|
verifyZeroInteractions(filterChain);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
|
|
+ public void doFilterAccessDeniedIncorrectTokenPresent()
|
|
|
+ throws ServletException, IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
- request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
|
|
|
+ request.setParameter(token.getParameterName(), token.getToken()
|
|
|
+ + " INVALID");
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
- verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
|
|
+ verify(deniedHandler).handle(eq(request), eq(response),
|
|
|
+ any(InvalidCsrfTokenException.class));
|
|
|
verifyZeroInteractions(filterChain);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
|
|
+ public void doFilterAccessDeniedIncorrectTokenPresentHeader()
|
|
|
+ throws ServletException, IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
- request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
|
|
|
+ request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
- verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
|
|
+ verify(deniedHandler).handle(eq(request), eq(response),
|
|
|
+ any(InvalidCsrfTokenException.class));
|
|
|
verifyZeroInteractions(filterChain);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException {
|
|
|
+ public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
|
|
+ throws ServletException, IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
request.setParameter(token.getParameterName(), token.getToken());
|
|
|
- request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
|
|
|
+ request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
- verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
|
|
+ verify(deniedHandler).handle(eq(request), eq(response),
|
|
|
+ any(InvalidCsrfTokenException.class));
|
|
|
verifyZeroInteractions(filterChain);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
|
|
+ public void doFilterNotCsrfRequestExistingToken() throws ServletException,
|
|
|
+ IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(false);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
verify(filterChain).doFilter(request, response);
|
|
|
verifyZeroInteractions(deniedHandler);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
|
+ public void doFilterNotCsrfRequestGenerateToken() throws ServletException,
|
|
|
+ IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(false);
|
|
|
- when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
|
|
|
+ when(tokenRepository.generateToken(request))
|
|
|
+ .thenReturn(token);
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
verify(filterChain).doFilter(request, response);
|
|
|
verifyZeroInteractions(deniedHandler);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
|
|
+ public void doFilterIsCsrfRequestExistingTokenHeader()
|
|
|
+ throws ServletException, IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
request.addHeader(token.getHeaderName(), token.getToken());
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
verify(filterChain).doFilter(request, response);
|
|
|
verifyZeroInteractions(deniedHandler);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException {
|
|
|
+ public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
|
|
+ throws ServletException, IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
- request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
|
|
|
+ request.setParameter(token.getParameterName(), token.getToken()
|
|
|
+ + " INVALID");
|
|
|
request.addHeader(token.getHeaderName(), token.getToken());
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
verify(filterChain).doFilter(request, response);
|
|
|
verifyZeroInteractions(deniedHandler);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
|
|
+ public void doFilterIsCsrfRequestExistingToken() throws ServletException,
|
|
|
+ IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
request.setParameter(token.getParameterName(), token.getToken());
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
verify(filterChain).doFilter(request, response);
|
|
|
verifyZeroInteractions(deniedHandler);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
|
+ public void doFilterIsCsrfRequestGenerateToken() throws ServletException,
|
|
|
+ IOException {
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
|
- when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
|
|
|
+ when(tokenRepository.generateToken(request))
|
|
|
+ .thenReturn(token);
|
|
|
request.setParameter(token.getParameterName(), token.getToken());
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
verify(filterChain).doFilter(request, response);
|
|
|
verifyZeroInteractions(deniedHandler);
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
|
|
|
+ public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods()
|
|
|
+ throws ServletException, IOException {
|
|
|
filter = new CsrfFilter(tokenRepository);
|
|
|
filter.setAccessDeniedHandler(deniedHandler);
|
|
|
|
|
|
- for(String method : Arrays.asList("GET","TRACE", "OPTIONS", "HEAD")) {
|
|
|
+ for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
|
|
resetRequestResponse();
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
request.setMethod(method);
|
|
@@ -258,24 +312,28 @@ public class CsrfFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
|
|
|
+ public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods()
|
|
|
+ throws ServletException, IOException {
|
|
|
filter = new CsrfFilter(tokenRepository);
|
|
|
filter.setAccessDeniedHandler(deniedHandler);
|
|
|
|
|
|
- for(String method : Arrays.asList("POST","PUT", "PATCH", "DELETE", "INVALID")) {
|
|
|
+ for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE",
|
|
|
+ "INVALID")) {
|
|
|
resetRequestResponse();
|
|
|
when(tokenRepository.loadToken(request)).thenReturn(token);
|
|
|
request.setMethod(method);
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
|
|
|
+ verify(deniedHandler).handle(eq(request), eq(response),
|
|
|
+ any(InvalidCsrfTokenException.class));
|
|
|
verifyZeroInteractions(filterChain);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterDefaultAccessDenied() throws ServletException, IOException {
|
|
|
+ public void doFilterDefaultAccessDenied() throws ServletException,
|
|
|
+ IOException {
|
|
|
filter = new CsrfFilter(tokenRepository);
|
|
|
filter.setRequireCsrfProtectionMatcher(requestMatcher);
|
|
|
when(requestMatcher.matches(request)).thenReturn(true);
|
|
@@ -283,11 +341,13 @@ public class CsrfFilterTests {
|
|
|
|
|
|
filter.doFilter(request, response, filterChain);
|
|
|
|
|
|
- assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
|
|
|
- assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
|
|
|
- assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
|
|
|
+ assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
|
|
|
+ token);
|
|
|
+ assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
|
|
|
+ token);
|
|
|
|
|
|
- assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
|
|
|
+ assertThat(response.getStatus()).isEqualTo(
|
|
|
+ HttpServletResponse.SC_FORBIDDEN);
|
|
|
verifyZeroInteractions(filterChain);
|
|
|
}
|
|
|
|
|
@@ -300,4 +360,29 @@ public class CsrfFilterTests {
|
|
|
public void setAccessDeniedHandlerNull() {
|
|
|
filter.setAccessDeniedHandler(null);
|
|
|
}
|
|
|
+
|
|
|
+ private static final CsrfTokenAssert assertToken(Object token) {
|
|
|
+ return new CsrfTokenAssert((CsrfToken)token);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static class CsrfTokenAssert extends
|
|
|
+ GenericAssert<CsrfTokenAssert, CsrfToken> {
|
|
|
+
|
|
|
+ /**
|
|
|
+ * Creates a new </code>{@link ObjectAssert}</code>.
|
|
|
+ *
|
|
|
+ * @param actual
|
|
|
+ * the target to verify.
|
|
|
+ */
|
|
|
+ protected CsrfTokenAssert(CsrfToken actual) {
|
|
|
+ super(CsrfTokenAssert.class, actual);
|
|
|
+ }
|
|
|
+
|
|
|
+ public CsrfTokenAssert isEqualTo(CsrfToken expected) {
|
|
|
+ assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName());
|
|
|
+ assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName());
|
|
|
+ assertThat(actual.getToken()).isEqualTo(expected.getToken());
|
|
|
+ return this;
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|