2
0
Эх сурвалжийг харах

SEC-2872: CsrfAuthenticationStrategy Delay Saving CsrfToken

Rob Winch 10 жил өмнө
parent
commit
dfaebfa63b

+ 84 - 4
web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java

@@ -53,11 +53,91 @@ public final class CsrfAuthenticationStrategy implements
             throws SessionAuthenticationException {
         boolean containsToken = this.csrfTokenRepository.loadToken(request) != null;
         if(containsToken) {
-            CsrfToken newToken = this.csrfTokenRepository.generateToken(request);
             this.csrfTokenRepository.saveToken(null, request, response);
-            this.csrfTokenRepository.saveToken(newToken, request, response);
-            request.setAttribute(CsrfToken.class.getName(), newToken);
-            request.setAttribute(newToken.getParameterName(), newToken);
+
+            CsrfToken newToken = this.csrfTokenRepository.generateToken(request);
+            CsrfToken tokenForRequest = new SaveOnAccessCsrfToken(csrfTokenRepository, request, response, newToken);
+
+            request.setAttribute(CsrfToken.class.getName(), tokenForRequest);
+            request.setAttribute(newToken.getParameterName(), tokenForRequest);
+        }
+    }
+
+    private static final class SaveOnAccessCsrfToken implements CsrfToken {
+        private transient CsrfTokenRepository tokenRepository;
+        private transient HttpServletRequest request;
+        private transient HttpServletResponse response;
+
+        private final CsrfToken delegate;
+
+        public SaveOnAccessCsrfToken(CsrfTokenRepository tokenRepository,
+                                     HttpServletRequest request, HttpServletResponse response,
+                                     CsrfToken delegate) {
+            super();
+            this.tokenRepository = tokenRepository;
+            this.request = request;
+            this.response = response;
+            this.delegate = delegate;
+        }
+
+        public String getHeaderName() {
+            return delegate.getHeaderName();
+        }
+
+        public String getParameterName() {
+            return delegate.getParameterName();
         }
+
+        public String getToken() {
+            saveTokenIfNecessary();
+            return delegate.getToken();
+        }
+
+        @Override
+        public String toString() {
+            return "SaveOnAccessCsrfToken [delegate=" + delegate + "]";
+        }
+
+        @Override
+        public int hashCode() {
+            final int prime = 31;
+            int result = 1;
+            result = prime * result
+                    + ((delegate == null) ? 0 : delegate.hashCode());
+            return result;
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (this == obj)
+                return true;
+            if (obj == null)
+                return false;
+            if (getClass() != obj.getClass())
+                return false;
+            SaveOnAccessCsrfToken other = (SaveOnAccessCsrfToken) obj;
+            if (delegate == null) {
+                if (other.delegate != null)
+                    return false;
+            } else if (!delegate.equals(other.delegate))
+                return false;
+            return true;
+        }
+
+        private void saveTokenIfNecessary() {
+            if(this.tokenRepository == null) {
+                return;
+            }
+
+            synchronized(this) {
+                if(tokenRepository != null) {
+                    this.tokenRepository.saveToken(delegate, request, response);
+                    this.tokenRepository = null;
+                    this.request = null;
+                    this.response = null;
+                }
+            }
+        }
+
     }
 }

+ 15 - 1
web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java

@@ -15,6 +15,7 @@
  */
 package org.springframework.security.web.csrf;
 
+import static org.fest.assertions.Assertions.assertThat;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.never;
@@ -73,7 +74,7 @@ public class CsrfAuthenticationStrategyTests {
         strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), request, response);
 
         verify(csrfTokenRepository).saveToken(null, request, response);
-        verify(csrfTokenRepository).saveToken(eq(generatedToken), eq(request), eq(response));
+        verify(csrfTokenRepository,never()).saveToken(eq(generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class));
         // SEC-2404, SEC-2832
         CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
         assertThat(tokenInRequest.getToken()).isSameAs(generatedToken.getToken());
@@ -82,6 +83,19 @@ public class CsrfAuthenticationStrategyTests {
         assertThat(request.getAttribute(generatedToken.getParameterName())).isSameAs(tokenInRequest);
     }
 
+    // SEC-2872
+    @Test
+    public void delaySavingCsrf() {
+        when(csrfTokenRepository.loadToken(request)).thenReturn(existingToken);
+        when(csrfTokenRepository.generateToken(request)).thenReturn(generatedToken);
+        strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), request, response);
+
+        verify(csrfTokenRepository).saveToken(null, request, response);
+        verify(csrfTokenRepository,never()).saveToken(eq(generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class));
+
+        CsrfToken tokenInRequest = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
+        tokenInRequest.getToken();
+        verify(csrfTokenRepository).saveToken(eq(generatedToken), any(HttpServletRequest.class), any(HttpServletResponse.class));
     }
 
     @Test