Browse Source

SEC-2276: Delay saving CsrfToken until token is accessed

This also removed the CsrfToken from the response headers to prevent the
token from being saved. If user's wish to return the CsrfToken in the
response headers, they should use the CsrfToken found on the request.
Rob Winch 12 years ago
parent
commit
48283ec004

+ 2 - 1
config/src/test/groovy/org/springframework/security/config/annotation/BaseSpringSpec.groovy

@@ -36,6 +36,7 @@ import org.springframework.security.web.access.intercept.FilterSecurityIntercept
 import org.springframework.security.web.context.HttpRequestResponseHolder
 import org.springframework.security.web.context.HttpRequestResponseHolder
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository
 import org.springframework.security.web.csrf.CsrfToken
 import org.springframework.security.web.csrf.CsrfToken
+import org.springframework.security.web.csrf.DefaultCsrfToken;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
 
 
 import spock.lang.AutoCleanup
 import spock.lang.AutoCleanup
@@ -69,7 +70,7 @@ abstract class BaseSpringSpec extends Specification {
     }
     }
 
 
     def setupCsrf(csrfTokenValue="BaseSpringSpec_CSRFTOKEN") {
     def setupCsrf(csrfTokenValue="BaseSpringSpec_CSRFTOKEN") {
-        csrfToken = new CsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue)
+        csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf",csrfTokenValue)
         new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request,response)
         new HttpSessionCsrfTokenRepository().saveToken(csrfToken, request,response)
         request.setParameter(csrfToken.parameterName, csrfToken.token)
         request.setParameter(csrfToken.parameterName, csrfToken.token)
     }
     }

+ 1 - 2
config/src/test/groovy/org/springframework/security/config/annotation/web/WebSecurityConfigurerAdapterTests.groovy

@@ -79,8 +79,7 @@ class WebSecurityConfigurerAdapterTests extends BaseSpringSpec {
                          'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
                          'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
                          'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
                          'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
                          'Pragma':'no-cache',
                          'Pragma':'no-cache',
-                         'X-XSS-Protection' : '1; mode=block',
-                         'X-CSRF-TOKEN' : csrfToken.token]
+                         'X-XSS-Protection' : '1; mode=block']
     }
     }
 
 
     @EnableWebSecurity
     @EnableWebSecurity

+ 10 - 20
config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/NamespaceHttpHeadersTests.groovy

@@ -49,8 +49,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
                 'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
                 'Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
                 'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
                 'Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
                 'Pragma':'no-cache',
                 'Pragma':'no-cache',
-                'X-XSS-Protection' : '1; mode=block',
-                'X-CSRF-TOKEN' : csrfToken.token]
+                'X-XSS-Protection' : '1; mode=block']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -70,8 +69,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
             responseHeaders == ['Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
             responseHeaders == ['Cache-Control': 'no-cache,no-store,max-age=0,must-revalidate',
-                'Pragma':'no-cache',
-                'X-CSRF-TOKEN' : csrfToken.token]
+                'Pragma':'no-cache']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -91,8 +89,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['Strict-Transport-Security': 'max-age=31536000 ; includeSubDomains']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -111,8 +108,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['Strict-Transport-Security': 'max-age=15768000',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['Strict-Transport-Security': 'max-age=15768000']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -133,8 +129,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['X-Frame-Options': 'SAMEORIGIN',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['X-Frame-Options': 'SAMEORIGIN']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -156,8 +151,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['X-Frame-Options': 'ALLOW-FROM https://example.com']
     }
     }
 
 
 
 
@@ -178,8 +172,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['X-XSS-Protection': '1; mode=block',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['X-XSS-Protection': '1; mode=block']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -199,8 +192,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['X-XSS-Protection': '1',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['X-XSS-Protection': '1']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -220,8 +212,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['X-Content-Type-Options': 'nosniff',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['X-Content-Type-Options': 'nosniff']
     }
     }
 
 
     @Configuration
     @Configuration
@@ -243,8 +234,7 @@ public class NamespaceHttpHeadersTests extends BaseSpringSpec {
         when:
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
         then:
-            responseHeaders == ['customHeaderName': 'customHeaderValue',
-                'X-CSRF-TOKEN' : csrfToken.token]
+            responseHeaders == ['customHeaderName': 'customHeaderValue']
     }
     }
 
 
     @Configuration
     @Configuration

+ 6 - 5
config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy

@@ -29,6 +29,7 @@ import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.csrf.CsrfFilter
 import org.springframework.security.web.csrf.CsrfFilter
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.DefaultCsrfToken;
 import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor
 import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor
 import org.springframework.security.web.util.RequestMatcher
 import org.springframework.security.web.util.RequestMatcher
 
 
@@ -113,7 +114,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             mockBean(CsrfTokenRepository,'repo')
             mockBean(CsrfTokenRepository,'repo')
             createAppContext()
             createAppContext()
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
-            CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
+            CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             request.setParameter(token.parameterName,token.token)
             request.setParameter(token.parameterName,token.token)
             request.servletPath = "/some-url"
             request.servletPath = "/some-url"
@@ -147,7 +148,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             mockBean(CsrfTokenRepository,'repo')
             mockBean(CsrfTokenRepository,'repo')
             createAppContext()
             createAppContext()
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
-            CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
+            CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             request.setParameter(token.parameterName,token.token)
             request.setParameter(token.parameterName,token.token)
             request.servletPath = "/some-url"
             request.servletPath = "/some-url"
@@ -200,7 +201,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             mockBean(CsrfTokenRepository,'repo')
             mockBean(CsrfTokenRepository,'repo')
             createAppContext()
             createAppContext()
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
-            CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
+            CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             request.setParameter(token.parameterName,token.token)
             request.setParameter(token.parameterName,token.token)
             request.method = "POST"
             request.method = "POST"
@@ -223,7 +224,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             mockBean(CsrfTokenRepository,'repo')
             mockBean(CsrfTokenRepository,'repo')
             createAppContext()
             createAppContext()
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
-            CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
+            CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             request.setParameter(token.parameterName,token.token)
             request.setParameter(token.parameterName,token.token)
             request.method = "POST"
             request.method = "POST"
@@ -244,7 +245,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             mockBean(CsrfTokenRepository,'repo')
             mockBean(CsrfTokenRepository,'repo')
             createAppContext()
             createAppContext()
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
             CsrfTokenRepository repo = appContext.getBean("repo",CsrfTokenRepository)
-            CsrfToken token = new CsrfToken("X-CSRF-TOKEN","_csrf", "abc")
+            CsrfToken token = new DefaultCsrfToken("X-CSRF-TOKEN","_csrf", "abc")
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             request.setParameter(token.parameterName,token.token)
             request.setParameter(token.parameterName,token.token)
             request.method = "POST"
             request.method = "POST"

+ 3 - 2
config/src/test/java/org/springframework/security/config/annotation/web/configurers/SessionManagementConfigurerServlet31Tests.java

@@ -40,7 +40,6 @@ import org.springframework.context.annotation.Configuration;
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
-import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
 import org.springframework.security.config.annotation.authentication.builders.AuthenticationManagerBuilder;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@@ -96,7 +95,9 @@ public class SessionManagementConfigurerServlet31Tests {
         request.setMethod("POST");
         request.setMethod("POST");
         request.setParameter("username", "user");
         request.setParameter("username", "user");
         request.setParameter("password", "password");
         request.setParameter("password", "password");
-        CsrfToken token = new HttpSessionCsrfTokenRepository().generateAndSaveToken(request, response);
+        HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository();
+        CsrfToken token = repository.generateToken(request);
+        repository.saveToken(token, request, response);
         request.setParameter(token.getParameterName(),token.getToken());
         request.setParameter(token.getParameterName(),token.getToken());
         when(ReflectionUtils.findMethod(HttpServletRequest.class, "changeSessionId")).thenReturn(method);
         when(ReflectionUtils.findMethod(HttpServletRequest.class, "changeSessionId")).thenReturn(method);
 
 

+ 82 - 3
web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

@@ -70,11 +70,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
             throws ServletException, IOException {
             throws ServletException, IOException {
         CsrfToken csrfToken = tokenRepository.loadToken(request);
         CsrfToken csrfToken = tokenRepository.loadToken(request);
         if(csrfToken == null) {
         if(csrfToken == null) {
-            csrfToken = tokenRepository.generateAndSaveToken(request, response);
+            CsrfToken generatedToken = tokenRepository.generateToken(request);
+            csrfToken = new SaveOnAccessCsrfToken(tokenRepository, request, response, generatedToken);
         }
         }
         request.setAttribute(CsrfToken.class.getName(), csrfToken);
         request.setAttribute(CsrfToken.class.getName(), csrfToken);
         request.setAttribute(csrfToken.getParameterName(), csrfToken);
         request.setAttribute(csrfToken.getParameterName(), csrfToken);
-        response.addHeader(csrfToken.getHeaderName(), csrfToken.getToken());
 
 
         if(!requireCsrfProtectionMatcher.matches(request)) {
         if(!requireCsrfProtectionMatcher.matches(request)) {
             filterChain.doFilter(request, response);
             filterChain.doFilter(request, response);
@@ -128,7 +128,86 @@ public final class CsrfFilter extends OncePerRequestFilter {
         this.accessDeniedHandler = accessDeniedHandler;
         this.accessDeniedHandler = accessDeniedHandler;
     }
     }
 
 
-    private static class DefaultRequiresCsrfMatcher implements RequestMatcher {
+    @SuppressWarnings("serial")
+    private static final class SaveOnAccessCsrfToken implements CsrfToken {
+        private transient CsrfTokenRepository tokenRepository;
+        private transient HttpServletRequest request;
+        private transient HttpServletResponse response;
+
+        private final CsrfToken delegate;
+
+        public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository,
+                HttpServletRequest request, HttpServletResponse response,
+                CsrfToken delegate) {
+            super();
+            this.tokenRepository = tokenRepository;
+            this.request = request;
+            this.response = response;
+            this.delegate = delegate;
+        }
+
+        public String getHeaderName() {
+            return delegate.getHeaderName();
+        }
+
+        public String getParameterName() {
+            return delegate.getParameterName();
+        }
+
+        public String getToken() {
+            saveTokenIfNecessary();
+            return delegate.getToken();
+        }
+
+        @Override
+        public String toString() {
+            return "SaveOnAccessCsrfToken [delegate=" + delegate + "]";
+        }
+
+        @Override
+        public int hashCode() {
+            final int prime = 31;
+            int result = 1;
+            result = prime * result
+                    + ((delegate == null) ? 0 : delegate.hashCode());
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj)
+                return true;
+            if (obj == null)
+                return false;
+            if (getClass() != obj.getClass())
+                return false;
+            SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj;
+            if (delegate == null) {
+                if (other.delegate != null)
+                    return false;
+            } else if (!delegate.equals(other.delegate))
+                return false;
+            return true;
+        }
+
+        private void saveTokenIfNecessary() {
+            if(this.tokenRepository == null) {
+                return;
+            }
+
+            synchronized(this) {
+                if(tokenRepository != null) {
+                    this.tokenRepository.saveToken(delegate, request, response);
+                    this.tokenRepository = null;
+                    this.request = null;
+                    this.response = null;
+                }
+            }
+        }
+
+    }
+
+    private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
         private Pattern allowedMethods = Pattern.compile("^(GET|HEAD|TRACE|OPTIONS)$");
         private Pattern allowedMethods = Pattern.compile("^(GET|HEAD|TRACE|OPTIONS)$");
 
 
         /* (non-Javadoc)
         /* (non-Javadoc)

+ 10 - 36
web/src/main/java/org/springframework/security/web/csrf/CsrfToken.java

@@ -17,37 +17,16 @@ package org.springframework.security.web.csrf;
 
 
 import java.io.Serializable;
 import java.io.Serializable;
 
 
-import org.springframework.util.Assert;
-
 /**
 /**
- * A CSRF token that is used to protect against CSRF attacks.
+ * Provides the information about an expected CSRF token.
+ *
+ * @see DefaultCsrfToken
  *
  *
  * @author Rob Winch
  * @author Rob Winch
  * @since 3.2
  * @since 3.2
+ *
  */
  */
-@SuppressWarnings("serial")
-public final class CsrfToken implements Serializable {
-
-    private final String token;
-
-    private final String parameterName;
-
-    private final String headerName;
-
-    /**
-     * Creates a new instance
-     * @param headerName the HTTP header name to use
-     * @param parameterName the HTTP parameter name to use
-     * @param token the value of the token (i.e. expected value of the HTTP parameter of parametername).
-     */
-    public CsrfToken(String headerName, String parameterName, String token) {
-        Assert.hasLength(headerName, "headerName cannot be null or empty");
-        Assert.hasLength(parameterName, "parameterName cannot be null or empty");
-        Assert.hasLength(token, "token cannot be null or empty");
-        this.headerName = headerName;
-        this.parameterName = parameterName;
-        this.token = token;
-    }
+public interface CsrfToken extends Serializable {
 
 
     /**
     /**
      * Gets the HTTP header that the CSRF is populated on the response and can
      * Gets the HTTP header that the CSRF is populated on the response and can
@@ -56,23 +35,18 @@ public final class CsrfToken implements Serializable {
      * @return the HTTP header that the CSRF is populated on the response and
      * @return the HTTP header that the CSRF is populated on the response and
      *         can be placed on requests instead of the parameter
      *         can be placed on requests instead of the parameter
      */
      */
-    public String getHeaderName() {
-        return headerName;
-    }
+    String getHeaderName();
 
 
     /**
     /**
      * Gets the HTTP parameter name that should contain the token. Cannot be null.
      * Gets the HTTP parameter name that should contain the token. Cannot be null.
      * @return the HTTP parameter name that should contain the token.
      * @return the HTTP parameter name that should contain the token.
      */
      */
-    public String getParameterName() {
-        return parameterName;
-    }
+    String getParameterName();
 
 
     /**
     /**
      * Gets the token value. Cannot be null.
      * Gets the token value. Cannot be null.
      * @return the token value
      * @return the token value
      */
      */
-    public String getToken() {
-        return token;
-    }
-}
+    String getToken();
+
+}

+ 3 - 6
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java

@@ -33,17 +33,14 @@ import javax.servlet.http.HttpSession;
 public interface CsrfTokenRepository {
 public interface CsrfTokenRepository {
 
 
     /**
     /**
-     * Generates and saves the expected {@link CsrfToken}
+     * Generates a {@link CsrfToken}
      *
      *
      * @param request
      * @param request
      *            the {@link HttpServletRequest} to use
      *            the {@link HttpServletRequest} to use
-     * @param response
-     *            the {@link HttpServletResponse} to use
-     * @return the {@link CsrfToken} that was generated and saved. Cannot be
+     * @return the {@link CsrfToken} that was generated. Cannot be
      *         null.
      *         null.
      */
      */
-    CsrfToken generateAndSaveToken(HttpServletRequest request,
-            HttpServletResponse response);
+    CsrfToken generateToken(HttpServletRequest request);
 
 
     /**
     /**
      * Saves the {@link CsrfToken} using the {@link HttpServletRequest} and
      * Saves the {@link CsrfToken} using the {@link HttpServletRequest} and

+ 70 - 0
web/src/main/java/org/springframework/security/web/csrf/DefaultCsrfToken.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.csrf;
+
+import org.springframework.util.Assert;
+
+/**
+ * A CSRF token that is used to protect against CSRF attacks.
+ *
+ * @author Rob Winch
+ * @since 3.2
+ */
+@SuppressWarnings("serial")
+public final class DefaultCsrfToken implements CsrfToken {
+
+    private final String token;
+
+    private final String parameterName;
+
+    private final String headerName;
+
+    /**
+     * Creates a new instance
+     * @param headerName the HTTP header name to use
+     * @param parameterName the HTTP parameter name to use
+     * @param token the value of the token (i.e. expected value of the HTTP parameter of parametername).
+     */
+    public DefaultCsrfToken(String headerName, String parameterName, String token) {
+        Assert.hasLength(headerName, "headerName cannot be null or empty");
+        Assert.hasLength(parameterName, "parameterName cannot be null or empty");
+        Assert.hasLength(token, "token cannot be null or empty");
+        this.headerName = headerName;
+        this.parameterName = parameterName;
+        this.token = token;
+    }
+
+    /* (non-Javadoc)
+     * @see org.springframework.security.web.csrf.CsrfToken#getHeaderName()
+     */
+    public String getHeaderName() {
+        return headerName;
+    }
+
+    /* (non-Javadoc)
+     * @see org.springframework.security.web.csrf.CsrfToken#getParameterName()
+     */
+    public String getParameterName() {
+        return parameterName;
+    }
+
+    /* (non-Javadoc)
+     * @see org.springframework.security.web.csrf.CsrfToken#getToken()
+     */
+    public String getToken() {
+        return token;
+    }
+}

+ 5 - 7
web/src/main/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepository.java

@@ -63,14 +63,12 @@ public final class HttpSessionCsrfTokenRepository implements CsrfTokenRepository
         return (CsrfToken) request.getSession().getAttribute(sessionAttributeName);
         return (CsrfToken) request.getSession().getAttribute(sessionAttributeName);
     }
     }
 
 
-    /* (non-Javadoc)
-     * @see org.springframework.security.web.csrf.CsrfTokenRepository#generateNewToken(javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse)
+    /*
+     * (non-Javadoc)
+     * @see org.springframework.security.web.csrf.CsrfTokenRepository#generateToken(javax.servlet.http.HttpServletRequest)
      */
      */
-    public CsrfToken generateAndSaveToken(HttpServletRequest request,
-            HttpServletResponse response) {
-        CsrfToken token = new CsrfToken(headerName, parameterName, createNewToken());
-        saveToken(token, request, response);
-        return token;
+    public CsrfToken generateToken(HttpServletRequest request) {
+        return new DefaultCsrfToken(headerName, parameterName, createNewToken());
     }
     }
 
 
     /**
     /**

+ 147 - 62
web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

@@ -18,6 +18,7 @@ package org.springframework.security.web.csrf;
 import static org.fest.assertions.Assertions.assertThat;
 import static org.fest.assertions.Assertions.assertThat;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 import static org.mockito.Mockito.when;
@@ -27,8 +28,11 @@ import java.util.Arrays;
 
 
 import javax.servlet.FilterChain;
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpServletResponse;
 
 
+import org.fest.assertions.GenericAssert;
+import org.fest.assertions.ObjectAssert;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runner.RunWith;
@@ -59,12 +63,12 @@ public class CsrfFilterTests {
     private MockHttpServletResponse response;
     private MockHttpServletResponse response;
     private CsrfToken token;
     private CsrfToken token;
 
 
-
     private CsrfFilter filter;
     private CsrfFilter filter;
 
 
     @Before
     @Before
     public void setup() {
     public void setup() {
-        token = new CsrfToken("headerName","paramName", "csrfTokenValue");
+        token = new DefaultCsrfToken("headerName", "paramName",
+                "csrfTokenValue");
         resetRequestResponse();
         resetRequestResponse();
         filter = new CsrfFilter(tokenRepository);
         filter = new CsrfFilter(tokenRepository);
         filter.setRequireCsrfProtectionMatcher(requestMatcher);
         filter.setRequireCsrfProtectionMatcher(requestMatcher);
@@ -81,171 +85,221 @@ public class CsrfFilterTests {
         new CsrfFilter(null);
         new CsrfFilter(null);
     }
     }
 
 
+    // SEC-2276
+    @Test
+    public void doFilterDoesNotSaveCsrfTokenUntilAccessed() throws ServletException,
+            IOException {
+        when(requestMatcher.matches(request)).thenReturn(false);
+        when(tokenRepository.generateToken(request)).thenReturn(token);
+
+        filter.doFilter(request, response, filterChain);
+        CsrfToken attrToken = (CsrfToken) request.getAttribute(token.getParameterName());
+
+        // no CsrfToken should have been saved yet
+        verify(tokenRepository,times(0)).saveToken(any(CsrfToken.class), any(HttpServletRequest.class), any(HttpServletResponse.class));
+        verify(filterChain).doFilter(request, response);
+
+        // access the token
+        attrToken.getToken();
+
+        // now the CsrfToken should have been saved
+        verify(tokenRepository).saveToken(eq(token), any(HttpServletRequest.class), any(HttpServletResponse.class));
+    }
+
     @Test
     @Test
-    public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
+    public void doFilterAccessDeniedNoTokenPresent() throws ServletException,
+            IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
-        verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
+        verify(deniedHandler).handle(eq(request), eq(response),
+                any(InvalidCsrfTokenException.class));
         verifyZeroInteractions(filterChain);
         verifyZeroInteractions(filterChain);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
+    public void doFilterAccessDeniedIncorrectTokenPresent()
+            throws ServletException, IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
-        request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
+        request.setParameter(token.getParameterName(), token.getToken()
+                + " INVALID");
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
-        verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
+        verify(deniedHandler).handle(eq(request), eq(response),
+                any(InvalidCsrfTokenException.class));
         verifyZeroInteractions(filterChain);
         verifyZeroInteractions(filterChain);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
+    public void doFilterAccessDeniedIncorrectTokenPresentHeader()
+            throws ServletException, IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
-        request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
+        request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
-        verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
+        verify(deniedHandler).handle(eq(request), eq(response),
+                any(InvalidCsrfTokenException.class));
         verifyZeroInteractions(filterChain);
         verifyZeroInteractions(filterChain);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter() throws ServletException, IOException {
+    public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
+            throws ServletException, IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         request.setParameter(token.getParameterName(), token.getToken());
         request.setParameter(token.getParameterName(), token.getToken());
-        request.addHeader(token.getHeaderName(), token.getToken()+ " INVALID");
+        request.addHeader(token.getHeaderName(), token.getToken() + " INVALID");
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
-        verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
+        verify(deniedHandler).handle(eq(request), eq(response),
+                any(InvalidCsrfTokenException.class));
         verifyZeroInteractions(filterChain);
         verifyZeroInteractions(filterChain);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
+    public void doFilterNotCsrfRequestExistingToken() throws ServletException,
+            IOException {
         when(requestMatcher.matches(request)).thenReturn(false);
         when(requestMatcher.matches(request)).thenReturn(false);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
         verify(filterChain).doFilter(request, response);
         verify(filterChain).doFilter(request, response);
         verifyZeroInteractions(deniedHandler);
         verifyZeroInteractions(deniedHandler);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
+    public void doFilterNotCsrfRequestGenerateToken() throws ServletException,
+            IOException {
         when(requestMatcher.matches(request)).thenReturn(false);
         when(requestMatcher.matches(request)).thenReturn(false);
-        when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
+        when(tokenRepository.generateToken(request))
+                .thenReturn(token);
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
         verify(filterChain).doFilter(request, response);
         verify(filterChain).doFilter(request, response);
         verifyZeroInteractions(deniedHandler);
         verifyZeroInteractions(deniedHandler);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
+    public void doFilterIsCsrfRequestExistingTokenHeader()
+            throws ServletException, IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         request.addHeader(token.getHeaderName(), token.getToken());
         request.addHeader(token.getHeaderName(), token.getToken());
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
         verify(filterChain).doFilter(request, response);
         verify(filterChain).doFilter(request, response);
         verifyZeroInteractions(deniedHandler);
         verifyZeroInteractions(deniedHandler);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam() throws ServletException, IOException {
+    public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
+            throws ServletException, IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
-        request.setParameter(token.getParameterName(), token.getToken()+ " INVALID");
+        request.setParameter(token.getParameterName(), token.getToken()
+                + " INVALID");
         request.addHeader(token.getHeaderName(), token.getToken());
         request.addHeader(token.getHeaderName(), token.getToken());
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
         verify(filterChain).doFilter(request, response);
         verify(filterChain).doFilter(request, response);
         verifyZeroInteractions(deniedHandler);
         verifyZeroInteractions(deniedHandler);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
+    public void doFilterIsCsrfRequestExistingToken() throws ServletException,
+            IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         when(tokenRepository.loadToken(request)).thenReturn(token);
         request.setParameter(token.getParameterName(), token.getToken());
         request.setParameter(token.getParameterName(), token.getToken());
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
         verify(filterChain).doFilter(request, response);
         verify(filterChain).doFilter(request, response);
         verifyZeroInteractions(deniedHandler);
         verifyZeroInteractions(deniedHandler);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
+    public void doFilterIsCsrfRequestGenerateToken() throws ServletException,
+            IOException {
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
-        when(tokenRepository.generateAndSaveToken(request, response)).thenReturn(token);
+        when(tokenRepository.generateToken(request))
+                .thenReturn(token);
         request.setParameter(token.getParameterName(), token.getToken());
         request.setParameter(token.getParameterName(), token.getToken());
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertToken(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertToken(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
         verify(filterChain).doFilter(request, response);
         verify(filterChain).doFilter(request, response);
         verifyZeroInteractions(deniedHandler);
         verifyZeroInteractions(deniedHandler);
     }
     }
 
 
     @Test
     @Test
-    public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
+    public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods()
+            throws ServletException, IOException {
         filter = new CsrfFilter(tokenRepository);
         filter = new CsrfFilter(tokenRepository);
         filter.setAccessDeniedHandler(deniedHandler);
         filter.setAccessDeniedHandler(deniedHandler);
 
 
-        for(String method : Arrays.asList("GET","TRACE", "OPTIONS", "HEAD")) {
+        for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
             resetRequestResponse();
             resetRequestResponse();
             when(tokenRepository.loadToken(request)).thenReturn(token);
             when(tokenRepository.loadToken(request)).thenReturn(token);
             request.setMethod(method);
             request.setMethod(method);
@@ -258,24 +312,28 @@ public class CsrfFilterTests {
     }
     }
 
 
     @Test
     @Test
-    public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
+    public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods()
+            throws ServletException, IOException {
         filter = new CsrfFilter(tokenRepository);
         filter = new CsrfFilter(tokenRepository);
         filter.setAccessDeniedHandler(deniedHandler);
         filter.setAccessDeniedHandler(deniedHandler);
 
 
-        for(String method : Arrays.asList("POST","PUT", "PATCH", "DELETE", "INVALID")) {
+        for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE",
+                "INVALID")) {
             resetRequestResponse();
             resetRequestResponse();
             when(tokenRepository.loadToken(request)).thenReturn(token);
             when(tokenRepository.loadToken(request)).thenReturn(token);
             request.setMethod(method);
             request.setMethod(method);
 
 
             filter.doFilter(request, response, filterChain);
             filter.doFilter(request, response, filterChain);
 
 
-            verify(deniedHandler).handle(eq(request), eq(response), any(InvalidCsrfTokenException.class));
+            verify(deniedHandler).handle(eq(request), eq(response),
+                    any(InvalidCsrfTokenException.class));
             verifyZeroInteractions(filterChain);
             verifyZeroInteractions(filterChain);
         }
         }
     }
     }
 
 
     @Test
     @Test
-    public void doFilterDefaultAccessDenied() throws ServletException, IOException {
+    public void doFilterDefaultAccessDenied() throws ServletException,
+            IOException {
         filter = new CsrfFilter(tokenRepository);
         filter = new CsrfFilter(tokenRepository);
         filter.setRequireCsrfProtectionMatcher(requestMatcher);
         filter.setRequireCsrfProtectionMatcher(requestMatcher);
         when(requestMatcher.matches(request)).thenReturn(true);
         when(requestMatcher.matches(request)).thenReturn(true);
@@ -283,11 +341,13 @@ public class CsrfFilterTests {
 
 
         filter.doFilter(request, response, filterChain);
         filter.doFilter(request, response, filterChain);
 
 
-        assertThat(response.getHeader(token.getHeaderName())).isEqualTo(token.getToken());
-        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(token);
-        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(token);
+        assertThat(request.getAttribute(token.getParameterName())).isEqualTo(
+                token);
+        assertThat(request.getAttribute(CsrfToken.class.getName())).isEqualTo(
+                token);
 
 
-        assertThat(response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
+        assertThat(response.getStatus()).isEqualTo(
+                HttpServletResponse.SC_FORBIDDEN);
         verifyZeroInteractions(filterChain);
         verifyZeroInteractions(filterChain);
     }
     }
 
 
@@ -300,4 +360,29 @@ public class CsrfFilterTests {
     public void setAccessDeniedHandlerNull() {
     public void setAccessDeniedHandlerNull() {
         filter.setAccessDeniedHandler(null);
         filter.setAccessDeniedHandler(null);
     }
     }
+
+    private static final CsrfTokenAssert assertToken(Object token) {
+        return new CsrfTokenAssert((CsrfToken)token);
+    }
+
+    private static class CsrfTokenAssert extends
+            GenericAssert<CsrfTokenAssert, CsrfToken> {
+
+        /**
+         * Creates a new </code>{@link ObjectAssert}</code>.
+         *
+         * @param actual
+         *            the target to verify.
+         */
+        protected CsrfTokenAssert(CsrfToken actual) {
+            super(CsrfTokenAssert.class, actual);
+        }
+
+        public CsrfTokenAssert isEqualTo(CsrfToken expected) {
+            assertThat(actual.getHeaderName()).isEqualTo(expected.getHeaderName());
+            assertThat(actual.getParameterName()).isEqualTo(expected.getParameterName());
+            assertThat(actual.getToken()).isEqualTo(expected.getToken());
+            return this;
+        }
+    }
 }
 }

+ 7 - 7
web/src/test/java/org/springframework/security/web/csrf/CsrfTokenTests.java → web/src/test/java/org/springframework/security/web/csrf/DefaultCsrfTokenTests.java

@@ -21,38 +21,38 @@ import org.junit.Test;
  * @author Rob Winch
  * @author Rob Winch
  *
  *
  */
  */
-public class CsrfTokenTests {
+public class DefaultCsrfTokenTests {
     private final String headerName = "headerName";
     private final String headerName = "headerName";
     private final String parameterName = "parameterName";
     private final String parameterName = "parameterName";
     private final String tokenValue = "tokenValue";
     private final String tokenValue = "tokenValue";
 
 
     @Test(expected = IllegalArgumentException.class)
     @Test(expected = IllegalArgumentException.class)
     public void constructorNullHeaderName() {
     public void constructorNullHeaderName() {
-        new CsrfToken(null,parameterName, tokenValue);
+        new DefaultCsrfToken(null,parameterName, tokenValue);
     }
     }
 
 
     @Test(expected = IllegalArgumentException.class)
     @Test(expected = IllegalArgumentException.class)
     public void constructorEmptyHeaderName() {
     public void constructorEmptyHeaderName() {
-        new CsrfToken("",parameterName, tokenValue);
+        new DefaultCsrfToken("",parameterName, tokenValue);
     }
     }
 
 
     @Test(expected = IllegalArgumentException.class)
     @Test(expected = IllegalArgumentException.class)
     public void constructorNullParameterName() {
     public void constructorNullParameterName() {
-        new CsrfToken(headerName,null, tokenValue);
+        new DefaultCsrfToken(headerName,null, tokenValue);
     }
     }
 
 
     @Test(expected = IllegalArgumentException.class)
     @Test(expected = IllegalArgumentException.class)
     public void constructorEmptyParameterName() {
     public void constructorEmptyParameterName() {
-        new CsrfToken(headerName,"", tokenValue);
+        new DefaultCsrfToken(headerName,"", tokenValue);
     }
     }
 
 
     @Test(expected = IllegalArgumentException.class)
     @Test(expected = IllegalArgumentException.class)
     public void constructorNullTokenValue() {
     public void constructorNullTokenValue() {
-        new CsrfToken(headerName,parameterName, null);
+        new DefaultCsrfToken(headerName,parameterName, null);
     }
     }
 
 
     @Test(expected = IllegalArgumentException.class)
     @Test(expected = IllegalArgumentException.class)
     public void constructorEmptyTokenValue() {
     public void constructorEmptyTokenValue() {
-        new CsrfToken(headerName,parameterName, "");
+        new DefaultCsrfToken(headerName,parameterName, "");
     }
     }
 }
 }

+ 7 - 7
web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java

@@ -42,23 +42,23 @@ public class HttpSessionCsrfTokenRepositoryTests {
     }
     }
 
 
     @Test
     @Test
-    public void generateAndSaveToken() {
-        token = repo.generateAndSaveToken(request, response);
+    public void generateToken() {
+        token = repo.generateToken(request);
 
 
         assertThat(token.getParameterName()).isEqualTo("_csrf");
         assertThat(token.getParameterName()).isEqualTo("_csrf");
         assertThat(token.getToken()).isNotEmpty();
         assertThat(token.getToken()).isNotEmpty();
 
 
         CsrfToken loadedToken = repo.loadToken(request);
         CsrfToken loadedToken = repo.loadToken(request);
 
 
-        assertThat(loadedToken).isEqualTo(token);
+        assertThat(loadedToken).isNull();
     }
     }
 
 
     @Test
     @Test
-    public void generateAndSaveTokenCustomParameter() {
+    public void generateCustomParameter() {
         String paramName = "_csrf";
         String paramName = "_csrf";
         repo.setParameterName(paramName);
         repo.setParameterName(paramName);
 
 
-        token = repo.generateAndSaveToken(request, response);
+        token = repo.generateToken(request);
 
 
         assertThat(token.getParameterName()).isEqualTo(paramName);
         assertThat(token.getParameterName()).isEqualTo(paramName);
         assertThat(token.getToken()).isNotEmpty();
         assertThat(token.getToken()).isNotEmpty();
@@ -71,7 +71,7 @@ public class HttpSessionCsrfTokenRepositoryTests {
 
 
     @Test
     @Test
     public void saveToken() {
     public void saveToken() {
-        CsrfToken tokenToSave = new CsrfToken("123", "abc", "def");
+        CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
         repo.saveToken(tokenToSave, request, response);
         repo.saveToken(tokenToSave, request, response);
 
 
         String attrName = request.getSession().getAttributeNames()
         String attrName = request.getSession().getAttributeNames()
@@ -84,7 +84,7 @@ public class HttpSessionCsrfTokenRepositoryTests {
 
 
     @Test
     @Test
     public void saveTokenCustomSessionAttribute() {
     public void saveTokenCustomSessionAttribute() {
-        CsrfToken tokenToSave = new CsrfToken("123", "abc", "def");
+        CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
         String sessionAttributeName = "custom";
         String sessionAttributeName = "custom";
         repo.setSessionAttributeName(sessionAttributeName);
         repo.setSessionAttributeName(sessionAttributeName);
         repo.saveToken(tokenToSave, request, response);
         repo.saveToken(tokenToSave, request, response);

+ 2 - 1
web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java

@@ -25,6 +25,7 @@ import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.DefaultCsrfToken;
 
 
 /**
 /**
  * @author Rob Winch
  * @author Rob Winch
@@ -51,7 +52,7 @@ public class CsrfRequestDataValueProcessorTests {
 
 
     @Test
     @Test
     public void getExtraHiddenFieldsHasCsrfToken() {
     public void getExtraHiddenFieldsHasCsrfToken() {
-        CsrfToken token = new CsrfToken("1", "a", "b");
+        CsrfToken token = new DefaultCsrfToken("1", "a", "b");
         request.setAttribute(CsrfToken.class.getName(), token);
         request.setAttribute(CsrfToken.class.getName(), token);
         Map<String,String> expected = new HashMap<String,String>();
         Map<String,String> expected = new HashMap<String,String>();
         expected.put(token.getParameterName(),token.getToken());
         expected.put(token.getParameterName(),token.getToken());