|
@@ -86,7 +86,11 @@ public class CsrfFilterTests {
|
|
|
}
|
|
|
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
|
|
|
- CsrfFilter filter = new CsrfFilter(repository);
|
|
|
+ return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository));
|
|
|
+ }
|
|
|
+
|
|
|
+ private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
|
|
|
+ CsrfFilter filter = new CsrfFilter(requestHandler);
|
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
|
filter.setAccessDeniedHandler(this.deniedHandler);
|
|
|
return filter;
|
|
@@ -99,7 +103,7 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
|
public void constructorNullRepository() {
|
|
|
- assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
|
|
|
+ assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null));
|
|
|
}
|
|
|
|
|
|
// SEC-2276
|
|
@@ -249,7 +253,7 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
|
|
|
- this.filter = new CsrfFilter(this.tokenRepository);
|
|
|
+ this.filter = createCsrfFilter(this.tokenRepository);
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
|
|
resetRequestResponse();
|
|
@@ -269,7 +273,7 @@ public class CsrfFilterTests {
|
|
|
*/
|
|
|
@Test
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
|
|
|
- this.filter = new CsrfFilter(this.tokenRepository);
|
|
|
+ this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
|
|
|
resetRequestResponse();
|
|
@@ -284,7 +288,7 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
|
|
|
- this.filter = new CsrfFilter(this.tokenRepository);
|
|
|
+ this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
|
|
|
resetRequestResponse();
|
|
@@ -299,7 +303,7 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
|
|
|
- this.filter = new CsrfFilter(this.tokenRepository);
|
|
|
+ this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
@@ -313,7 +317,7 @@ public class CsrfFilterTests {
|
|
|
@Test
|
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
|
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
|
|
|
- CsrfFilter filter = new CsrfFilter(repository);
|
|
|
+ CsrfFilter filter = createCsrfFilter(repository);
|
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
|
|
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
|
|
CsrfFilter.skipRequest(request);
|
|
@@ -340,25 +344,13 @@ public class CsrfFilterTests {
|
|
|
CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
|
|
given(requestHandler.handle(this.request, this.response))
|
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
|
- this.filter.setRequestHandler(requestHandler);
|
|
|
+ this.filter = createCsrfFilter(requestHandler);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
verify(requestHandler).handle(eq(this.request), eq(this.response));
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
}
|
|
|
|
|
|
- @Test
|
|
|
- public void doFilterWhenRequestResolverThenUsed() throws Exception {
|
|
|
- given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
- CsrfTokenRequestResolver requestResolver = mock(CsrfTokenRequestResolver.class);
|
|
|
- given(requestResolver.resolveCsrfTokenValue(this.request, this.token)).willReturn(this.token.getToken());
|
|
|
- this.filter.setRequestResolver(requestResolver);
|
|
|
- this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- verify(requestResolver).resolveCsrfTokenValue(this.request, this.token);
|
|
|
- verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
- }
|
|
|
-
|
|
|
@Test
|
|
|
public void setRequireCsrfProtectionMatcherNull() {
|
|
|
assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null));
|
|
@@ -373,16 +365,14 @@ public class CsrfFilterTests {
|
|
|
@Test
|
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
|
|
|
throws ServletException, IOException {
|
|
|
- CsrfFilter filter = createCsrfFilter(this.tokenRepository);
|
|
|
String csrfAttrName = "_csrf";
|
|
|
- CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
|
|
|
- csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
|
|
|
- csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
|
|
|
- filter.setRequestHandler(csrfTokenRequestProcessor);
|
|
|
+ CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
|
|
|
+ requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
|
|
+ this.filter = createCsrfFilter(requestHandler);
|
|
|
CsrfToken expectedCsrfToken = spy(this.token);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
|
|
|
|
|
|
- filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
+ this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
|
|
|
verifyNoInteractions(expectedCsrfToken);
|
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
|
@@ -410,6 +400,6 @@ public class CsrfFilterTests {
|
|
|
return this.isGenerated;
|
|
|
}
|
|
|
|
|
|
- };
|
|
|
+ }
|
|
|
|
|
|
}
|