|
@@ -44,7 +44,6 @@ import static org.mockito.BDDMockito.given;
|
|
import static org.mockito.Mockito.lenient;
|
|
import static org.mockito.Mockito.lenient;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.never;
|
|
import static org.mockito.Mockito.never;
|
|
-import static org.mockito.Mockito.spy;
|
|
|
|
import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
@@ -86,11 +85,7 @@ public class CsrfFilterTests {
|
|
}
|
|
}
|
|
|
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
|
|
- return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository));
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
|
|
|
|
- CsrfFilter filter = new CsrfFilter(requestHandler);
|
|
|
|
|
|
+ CsrfFilter filter = new CsrfFilter(repository);
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
filter.setAccessDeniedHandler(this.deniedHandler);
|
|
filter.setAccessDeniedHandler(this.deniedHandler);
|
|
return filter;
|
|
return filter;
|
|
@@ -103,7 +98,7 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void constructorNullRepository() {
|
|
public void constructorNullRepository() {
|
|
- assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null));
|
|
|
|
|
|
+ assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
|
|
}
|
|
}
|
|
|
|
|
|
// SEC-2276
|
|
// SEC-2276
|
|
@@ -128,7 +123,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -139,7 +135,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
@@ -151,7 +148,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
@@ -164,7 +162,8 @@ public class CsrfFilterTests {
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
|
throws ServletException, IOException {
|
|
throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
@@ -177,7 +176,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -188,7 +188,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
- given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, true));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -199,7 +200,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
@@ -212,7 +214,8 @@ public class CsrfFilterTests {
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
|
throws ServletException, IOException {
|
|
throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
@@ -225,7 +228,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
@@ -239,7 +243,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, true));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
@@ -247,17 +252,17 @@ public class CsrfFilterTests {
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
- verify(this.tokenRepository).saveToken(this.token, this.request, this.response);
|
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
|
|
- this.filter = createCsrfFilter(this.tokenRepository);
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
|
resetRequestResponse();
|
|
resetRequestResponse();
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setMethod(method);
|
|
this.request.setMethod(method);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
@@ -273,11 +278,12 @@ public class CsrfFilterTests {
|
|
*/
|
|
*/
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
|
|
- this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
|
|
resetRequestResponse();
|
|
resetRequestResponse();
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setMethod(method);
|
|
this.request.setMethod(method);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
@@ -288,11 +294,12 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
|
|
- this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
|
|
resetRequestResponse();
|
|
resetRequestResponse();
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setMethod(method);
|
|
this.request.setMethod(method);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
@@ -303,10 +310,11 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
|
|
- this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -317,7 +325,7 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
|
|
- CsrfFilter filter = createCsrfFilter(repository);
|
|
|
|
|
|
+ CsrfFilter filter = new CsrfFilter(repository);
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
|
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
|
CsrfFilter.skipRequest(request);
|
|
CsrfFilter.skipRequest(request);
|
|
@@ -333,7 +341,8 @@ public class CsrfFilterTests {
|
|
given(token.getToken()).willReturn(null);
|
|
given(token.getToken()).willReturn(null);
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName());
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName());
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(token, false));
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
|
@@ -341,13 +350,15 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
|
- CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
|
|
|
- given(requestHandler.handle(this.request, this.response))
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
- this.filter = createCsrfFilter(requestHandler);
|
|
|
|
|
|
+ CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
|
|
|
+ this.filter = createCsrfFilter(this.tokenRepository);
|
|
|
|
+ this.filter.setRequestHandler(requestHandler);
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
- verify(requestHandler).handle(eq(this.request), eq(this.response));
|
|
|
|
|
|
+ verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
|
|
|
|
+ verify(requestHandler).handle(eq(this.request), eq(this.response), any());
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
}
|
|
}
|
|
|
|
|
|
@@ -365,41 +376,20 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
|
|
throws ServletException, IOException {
|
|
throws ServletException, IOException {
|
|
|
|
+ CsrfFilter filter = createCsrfFilter(this.tokenRepository);
|
|
String csrfAttrName = "_csrf";
|
|
String csrfAttrName = "_csrf";
|
|
- CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
|
|
|
|
|
|
+ CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler();
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
|
- this.filter = createCsrfFilter(requestHandler);
|
|
|
|
- CsrfToken expectedCsrfToken = spy(this.token);
|
|
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
|
|
|
|
|
|
+ filter.setRequestHandler(requestHandler);
|
|
|
|
+ CsrfToken expectedCsrfToken = mock(CsrfToken.class);
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
|
|
|
|
|
|
- this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
|
|
|
+ filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
|
|
|
verifyNoInteractions(expectedCsrfToken);
|
|
verifyNoInteractions(expectedCsrfToken);
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
}
|
|
}
|
|
|
|
|
|
- private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
|
|
|
-
|
|
|
|
- private final CsrfToken csrfToken;
|
|
|
|
-
|
|
|
|
- private final boolean isGenerated;
|
|
|
|
-
|
|
|
|
- private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
|
|
|
|
- this.csrfToken = csrfToken;
|
|
|
|
- this.isGenerated = isGenerated;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public CsrfToken get() {
|
|
|
|
- return this.csrfToken;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public boolean isGenerated() {
|
|
|
|
- return this.isGenerated;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
}
|
|
}
|