|
@@ -24,8 +24,6 @@ import javax.servlet.ServletException;
|
|
|
import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
|
|
|
-import org.assertj.core.api.AbstractObjectAssert;
|
|
|
-import org.assertj.core.api.ObjectAssert;
|
|
|
import org.junit.jupiter.api.BeforeEach;
|
|
|
import org.junit.jupiter.api.Test;
|
|
|
import org.junit.jupiter.api.extension.ExtendWith;
|
|
@@ -46,10 +44,12 @@ import static org.mockito.BDDMockito.given;
|
|
|
import static org.mockito.Mockito.lenient;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
import static org.mockito.Mockito.never;
|
|
|
+import static org.mockito.Mockito.spy;
|
|
|
import static org.mockito.Mockito.times;
|
|
|
import static org.mockito.Mockito.verify;
|
|
|
import static org.mockito.Mockito.verifyNoInteractions;
|
|
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
|
|
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
|
|
|
|
|
|
/**
|
|
|
* @author Rob Winch
|
|
@@ -126,8 +126,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -138,8 +138,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -150,8 +150,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -164,8 +164,8 @@ public class CsrfFilterTests {
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -175,8 +175,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -186,8 +186,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(false);
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -198,8 +198,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -212,8 +212,8 @@ public class CsrfFilterTests {
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
|
|
|
this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
}
|
|
@@ -224,8 +224,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
verifyNoMoreInteractions(this.deniedHandler);
|
|
|
verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
|
|
@@ -238,8 +238,8 @@ public class CsrfFilterTests {
|
|
|
given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
// LazyCsrfTokenRepository requires the response as an attribute
|
|
|
assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
@@ -304,8 +304,8 @@ public class CsrfFilterTests {
|
|
|
given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- assertThat(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
- assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
|
|
|
+ assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
|
|
|
assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
|
|
|
verifyNoMoreInteractions(this.filterChain);
|
|
|
}
|
|
@@ -336,14 +336,14 @@ public class CsrfFilterTests {
|
|
|
}
|
|
|
|
|
|
@Test
|
|
|
- public void doFilterWhenRequestAttributeHandlerThenUsed() throws Exception {
|
|
|
- given(this.requestMatcher.matches(this.request)).willReturn(true);
|
|
|
- given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
|
|
|
- CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class);
|
|
|
- this.filter.setRequestAttributeHandler(requestAttributeHandler);
|
|
|
+ public void doFilterWhenRequestHandlerThenUsed() throws Exception {
|
|
|
+ CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
|
|
|
+ given(requestHandler.handle(this.request, this.response))
|
|
|
+ .willReturn(new TestDeferredCsrfToken(this.token, false));
|
|
|
+ this.filter.setRequestHandler(requestHandler);
|
|
|
this.request.setParameter(this.token.getParameterName(), this.token.getToken());
|
|
|
this.filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
- verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any());
|
|
|
+ verify(requestHandler).handle(eq(this.request), eq(this.response));
|
|
|
verify(this.filterChain).doFilter(this.request, this.response);
|
|
|
}
|
|
|
|
|
@@ -376,39 +376,40 @@ public class CsrfFilterTests {
|
|
|
CsrfFilter filter = createCsrfFilter(this.tokenRepository);
|
|
|
String csrfAttrName = "_csrf";
|
|
|
CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
|
|
|
+ csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
|
|
|
csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
|
|
|
- filter.setRequestAttributeHandler(csrfTokenRequestProcessor);
|
|
|
- CsrfToken expectedCsrfToken = mock(CsrfToken.class);
|
|
|
+ filter.setRequestHandler(csrfTokenRequestProcessor);
|
|
|
+ CsrfToken expectedCsrfToken = spy(this.token);
|
|
|
given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
|
|
|
|
|
|
filter.doFilter(this.request, this.response, this.filterChain);
|
|
|
|
|
|
verifyNoInteractions(expectedCsrfToken);
|
|
|
CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
|
|
|
- assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
|
+ assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
|
|
|
}
|
|
|
|
|
|
- private static CsrfTokenAssert assertToken(Object token) {
|
|
|
- return new CsrfTokenAssert((CsrfToken) token);
|
|
|
- }
|
|
|
+ private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
|
|
|
+
|
|
|
+ private final CsrfToken csrfToken;
|
|
|
|
|
|
- private static class CsrfTokenAssert extends AbstractObjectAssert<CsrfTokenAssert, CsrfToken> {
|
|
|
+ private final boolean isGenerated;
|
|
|
|
|
|
- /**
|
|
|
- * Creates a new {@link ObjectAssert}.
|
|
|
- * @param actual the target to verify.
|
|
|
- */
|
|
|
- protected CsrfTokenAssert(CsrfToken actual) {
|
|
|
- super(actual, CsrfTokenAssert.class);
|
|
|
+ private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
|
|
|
+ this.csrfToken = csrfToken;
|
|
|
+ this.isGenerated = isGenerated;
|
|
|
}
|
|
|
|
|
|
- CsrfTokenAssert isEqualTo(CsrfToken expected) {
|
|
|
- assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName());
|
|
|
- assertThat(this.actual.getParameterName()).isEqualTo(expected.getParameterName());
|
|
|
- assertThat(this.actual.getToken()).isEqualTo(expected.getToken());
|
|
|
- return this;
|
|
|
+ @Override
|
|
|
+ public CsrfToken get() {
|
|
|
+ return this.csrfToken;
|
|
|
}
|
|
|
|
|
|
- }
|
|
|
+ @Override
|
|
|
+ public boolean isGenerated() {
|
|
|
+ return this.isGenerated;
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
|
|
|
}
|