|
@@ -318,13 +318,13 @@ public final class SecurityMockMvcRequestPostProcessors {
|
|
|
*/
|
|
|
@Override
|
|
|
public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
|
|
|
-
|
|
|
CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
|
|
|
if (!(repository instanceof TestCsrfTokenRepository)) {
|
|
|
repository = new TestCsrfTokenRepository(
|
|
|
new HttpSessionCsrfTokenRepository());
|
|
|
WebTestUtils.setCsrfTokenRepository(request, repository);
|
|
|
}
|
|
|
+ TestCsrfTokenRepository.enable(request);
|
|
|
CsrfToken token = repository.generateToken(request);
|
|
|
repository.saveToken(token, request, new MockHttpServletResponse());
|
|
|
String tokenValue = this.useInvalidToken ? "invalid" + token.getToken()
|
|
@@ -367,9 +367,12 @@ public final class SecurityMockMvcRequestPostProcessors {
|
|
|
* request is wrapped (i.e. Spring Session is in use).
|
|
|
*/
|
|
|
static class TestCsrfTokenRepository implements CsrfTokenRepository {
|
|
|
- final static String ATTR_NAME = TestCsrfTokenRepository.class.getName()
|
|
|
+ final static String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName()
|
|
|
.concat(".TOKEN");
|
|
|
|
|
|
+ final static String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class
|
|
|
+ .getName().concat(".ENABLED");
|
|
|
+
|
|
|
private final CsrfTokenRepository delegate;
|
|
|
|
|
|
private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
|
|
@@ -384,12 +387,30 @@ public final class SecurityMockMvcRequestPostProcessors {
|
|
|
@Override
|
|
|
public void saveToken(CsrfToken token, HttpServletRequest request,
|
|
|
HttpServletResponse response) {
|
|
|
- request.setAttribute(ATTR_NAME, token);
|
|
|
+ if (isEnabled(request)) {
|
|
|
+ request.setAttribute(TOKEN_ATTR_NAME, token);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ this.delegate.saveToken(token, request, response);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
@Override
|
|
|
public CsrfToken loadToken(HttpServletRequest request) {
|
|
|
- return (CsrfToken) request.getAttribute(ATTR_NAME);
|
|
|
+ if (isEnabled(request)) {
|
|
|
+ return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME);
|
|
|
+ }
|
|
|
+ else {
|
|
|
+ return this.delegate.loadToken(request);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ public static void enable(HttpServletRequest request) {
|
|
|
+ request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE);
|
|
|
+ }
|
|
|
+
|
|
|
+ public boolean isEnabled(HttpServletRequest request) {
|
|
|
+ return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
|
|
|
}
|
|
|
}
|
|
|
}
|