|
@@ -43,7 +43,6 @@ import static org.mockito.BDDMockito.given;
|
|
import static org.mockito.Mockito.lenient;
|
|
import static org.mockito.Mockito.lenient;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.mock;
|
|
import static org.mockito.Mockito.never;
|
|
import static org.mockito.Mockito.never;
|
|
-import static org.mockito.Mockito.spy;
|
|
|
|
import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
@@ -87,11 +86,7 @@ public class CsrfFilterTests {
|
|
}
|
|
}
|
|
|
|
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
|
|
private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
|
|
- return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository));
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
|
|
|
|
- CsrfFilter filter = new CsrfFilter(requestHandler);
|
|
|
|
|
|
+ CsrfFilter filter = new CsrfFilter(repository);
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
filter.setAccessDeniedHandler(this.deniedHandler);
|
|
filter.setAccessDeniedHandler(this.deniedHandler);
|
|
return filter;
|
|
return filter;
|
|
@@ -104,7 +99,7 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void constructorNullRepository() {
|
|
public void constructorNullRepository() {
|
|
- assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null));
|
|
|
|
|
|
+ assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
|
|
}
|
|
}
|
|
|
|
|
|
// SEC-2276
|
|
// SEC-2276
|
|
@@ -129,7 +124,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
|
public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -140,7 +136,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
|
public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
@@ -152,7 +149,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
@@ -165,7 +163,8 @@ public class CsrfFilterTests {
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
|
public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
|
|
throws ServletException, IOException {
|
|
throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
@@ -178,7 +177,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
|
public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -189,7 +189,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
- given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, true));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -200,7 +201,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
|
public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
@@ -213,7 +215,8 @@ public class CsrfFilterTests {
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
|
public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
|
|
throws ServletException, IOException {
|
|
throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
@@ -226,7 +229,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
|
public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
@@ -240,7 +244,8 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, true));
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
@@ -248,16 +253,17 @@ public class CsrfFilterTests {
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
- verify(this.tokenRepository).saveToken(this.token, this.request, this.response);
|
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
}
|
|
}
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
|
|
- this.filter = createCsrfFilter(this.tokenRepository);
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
|
resetRequestResponse();
|
|
resetRequestResponse();
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setMethod(method);
|
|
this.request.setMethod(method);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
@@ -273,11 +279,12 @@ public class CsrfFilterTests {
|
|
*/
|
|
*/
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
|
|
- this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
|
|
for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
|
|
resetRequestResponse();
|
|
resetRequestResponse();
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setMethod(method);
|
|
this.request.setMethod(method);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
@@ -288,11 +295,12 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
|
|
public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
|
|
- this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
|
|
for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
|
|
resetRequestResponse();
|
|
resetRequestResponse();
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.request.setMethod(method);
|
|
this.request.setMethod(method);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
|
|
@@ -303,10 +311,11 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
|
|
public void doFilterDefaultAccessDenied() throws ServletException, IOException {
|
|
- this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
|
|
|
|
|
|
+ this.filter = new CsrfFilter(this.tokenRepository);
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
@@ -317,7 +326,7 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
|
|
public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
|
|
CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
|
|
- CsrfFilter filter = createCsrfFilter(repository);
|
|
|
|
|
|
+ CsrfFilter filter = new CsrfFilter(repository);
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
|
|
lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
|
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
|
MockHttpServletRequest request = new MockHttpServletRequest();
|
|
CsrfFilter.skipRequest(request);
|
|
CsrfFilter.skipRequest(request);
|
|
@@ -333,7 +342,8 @@ public class CsrfFilterTests {
|
|
given(token.getToken()).willReturn(null);
|
|
given(token.getToken()).willReturn(null);
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
|
|
given(token.getHeaderName()).willReturn(this.token.getHeaderName());
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName());
|
|
given(token.getParameterName()).willReturn(this.token.getParameterName());
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(token);
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(token, false));
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
|
filter.doFilterInternal(this.request, this.response, this.filterChain);
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
|
|
@@ -341,13 +351,15 @@ public class CsrfFilterTests {
|
|
|
|
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
|
public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
|
- CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
|
|
|
- given(requestHandler.handle(this.request, this.response))
|
|
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
.willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
- this.filter = createCsrfFilter(requestHandler);
|
|
|
|
|
|
+ CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
|
|
|
+ this.filter = createCsrfFilter(this.tokenRepository);
|
|
|
|
+ this.filter.setRequestHandler(requestHandler);
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
- verify(requestHandler).handle(eq(this.request), eq(this.response));
|
|
|
|
|
|
+ verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
|
|
|
|
+ verify(requestHandler).handle(eq(this.request), eq(this.response), any());
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
}
|
|
}
|
|
|
|
|
|
@@ -365,41 +377,20 @@ public class CsrfFilterTests {
|
|
@Test
|
|
@Test
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
|
|
public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
|
|
throws ServletException, IOException {
|
|
throws ServletException, IOException {
|
|
|
|
+ CsrfFilter filter = createCsrfFilter(this.tokenRepository);
|
|
String csrfAttrName = "_csrf";
|
|
String csrfAttrName = "_csrf";
|
|
- CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
|
|
|
|
|
|
+ CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler();
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
|
requestHandler.setCsrfRequestAttributeName(csrfAttrName);
|
|
- this.filter = createCsrfFilter(requestHandler);
|
|
|
|
- CsrfToken expectedCsrfToken = spy(this.token);
|
|
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
|
|
|
|
|
|
+ filter.setRequestHandler(requestHandler);
|
|
|
|
+ CsrfToken expectedCsrfToken = mock(CsrfToken.class);
|
|
|
|
+ given(this.tokenRepository.loadDeferredToken(this.request, this.response))
|
|
|
|
+ .willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
|
|
|
|
|
|
- this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
|
|
|
+ filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
|
|
|
verifyNoInteractions(expectedCsrfToken);
|
|
verifyNoInteractions(expectedCsrfToken);
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
}
|
|
}
|
|
|
|
|
|
- private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
|
|
|
-
|
|
|
|
- private final CsrfToken csrfToken;
|
|
|
|
-
|
|
|
|
- private final boolean isGenerated;
|
|
|
|
-
|
|
|
|
- private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
|
|
|
|
- this.csrfToken = csrfToken;
|
|
|
|
- this.isGenerated = isGenerated;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public CsrfToken get() {
|
|
|
|
- return this.csrfToken;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- @Override
|
|
|
|
- public boolean isGenerated() {
|
|
|
|
- return this.isGenerated;
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
}
|
|
}
|