|
@@ -23,8 +23,6 @@ import jakarta.servlet.FilterChain;
|
|
|
import jakarta.servlet.ServletException;
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
|
import jakarta.servlet.http.HttpServletResponse;
|
|
|
-import org.assertj.core.api.AbstractObjectAssert;
|
|
|
-import org.assertj.core.api.ObjectAssert;
|
|
|
import org.junit.jupiter.api.BeforeEach;
|
|
|
import org.junit.jupiter.api.Test;
|
|
|
import org.junit.jupiter.api.extension.ExtendWith;
|
|
@@ -45,10 +43,12 @@ import static org.mockito.BDDMockito.given;
|
|
|
import static org.mockito.Mockito.lenient;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.never;
|
|
|
+import static org.mockito.Mockito.spy;
|
|
|
import static org.mockito.Mockito.times;
|
|
|
import static org.mockito.Mockito.verify;
|
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
|
|
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
|
|
|
|
|
/**
|
|
|
* @author Rob Winch
|
|
@@ -127,8 +127,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -139,8 +139,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -151,8 +151,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -165,8 +165,8 @@ public class CsrfFilterTests {
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -176,8 +176,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -187,8 +187,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -199,8 +199,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -213,8 +213,8 @@ public class CsrfFilterTests {
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -225,8 +225,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
|
|
@@ -239,8 +239,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
@@ -254,7 +254,6 @@ public class CsrfFilterTests {
|
|
|
this.filter.setAccessDeniedHandler(this.deniedHandler);
|
|
|
for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
|
|
|
resetRequestResponse();
|
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.setMethod(method);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
@@ -305,8 +304,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -337,14 +336,14 @@ public class CsrfFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenRequestAttributeHandlerThenUsed() throws Exception {
|
|
|
- given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
- CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class);
|
|
|
- this.filter.setRequestAttributeHandler(requestAttributeHandler);
|
|
|
+ public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
|
|
+ CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
|
|
+ given(requestHandler.handle(this.request, this.response))
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
|
+ this.filter.setRequestHandler(requestHandler);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any());
|
|
|
+ verify(requestHandler).handle(eq(this.request), eq(this.response));
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
}
|
|
|
|
|
@@ -377,39 +376,40 @@ public class CsrfFilterTests {
|
|
|
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
|
|
|
String csrfAttrName = "_csrf";
|
|
|
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
|
|
|
+ csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
|
|
|
csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
|
|
|
- filter.setRequestAttributeHandler(csrfTokenRequestProcessor);
|
|
|
- CsrfToken expectedCsrfToken = mock(CsrfToken.class);
|
|
|
+ filter.setRequestHandler(csrfTokenRequestProcessor);
|
|
|
+ CsrfToken expectedCsrfToken = spy(this.token);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
|
|
|
|
|
|
filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
|
|
|
verifyNoInteractions(expectedCsrfToken);
|
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
|
|
- assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
|
+ assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
|
}
|
|
|
|
|
|
- private static CsrfTokenAssert assertToken(Object token) {
|
|
|
- return new CsrfTokenAssert((CsrfToken) token);
|
|
|
- }
|
|
|
+ private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
|
|
|
|
|
- private static class CsrfTokenAssert extends AbstractObjectAssert<CsrfTokenAssert, CsrfToken> {
|
|
|
+ private final CsrfToken csrfToken;
|
|
|
|
|
|
- /**
|
|
|
- * Creates a new {@link ObjectAssert}.
|
|
|
- * @param actual the target to verify.
|
|
|
- */
|
|
|
- protected CsrfTokenAssert(CsrfToken actual) {
|
|
|
- super(actual, CsrfTokenAssert.class);
|
|
|
+ private final boolean isGenerated;
|
|
|
+
|
|
|
+ private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
|
|
|
+ this.csrfToken = csrfToken;
|
|
|
+ this.isGenerated = isGenerated;
|
|
|
}
|
|
|
|
|
|
- CsrfTokenAssert isEqualTo(CsrfToken expected) {
|
|
|
- assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName());
|
|
|
- assertThat(this.actual.getParameterName()).isEqualTo(expected.getParameterName());
|
|
|
- assertThat(this.actual.getToken()).isEqualTo(expected.getToken());
|
|
|
- return this;
|
|
|
+ @Override
|
|
|
+ public CsrfToken get() {
|
|
|
+ return this.csrfToken;
|
|
|
}
|
|
|
|
|
|
- }
|
|
|
+ @Override
|
|
|
+ public boolean isGenerated() {
|
|
|
+ return this.isGenerated;
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
|
|
|
}
|