Bladeren bron

Add XorCsrfChannelInterceptor

Issue gh-12378
Steve Riesenberg 2 jaren geleden
bovenliggende
commit
c306df9b46
12 gewijzigde bestanden met toevoegingen van 501 en 58 verwijderingen
  1. 10 2
      config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java
  2. 5 3
      config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java
  3. 43 0
      config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java
  4. 5 3
      config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java
  5. 31 6
      config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java
  6. 86 0
      messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java
  7. 72 0
      messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java
  8. 18 7
      messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java
  9. 148 0
      messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java
  10. 32 3
      messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java
  11. 2 1
      web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java
  12. 49 33
      web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

+ 10 - 2
config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -56,6 +56,8 @@ final class WebSocketMessageBrokerSecurityConfiguration
 
 	private static final String SIMPLE_URL_HANDLER_MAPPING_BEAN_NAME = "stompWebSocketHandlerMapping";
 
+	private static final String CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME = "csrfChannelInterceptor";
+
 	private MessageMatcherDelegatingAuthorizationManager b;
 
 	private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
@@ -66,7 +68,7 @@ final class WebSocketMessageBrokerSecurityConfiguration
 
 	private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
 
-	private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
+	private ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
 
 	private AuthorizationChannelInterceptor authorizationChannelInterceptor = new AuthorizationChannelInterceptor(
 			ANY_MESSAGE_AUTHENTICATED);
@@ -86,6 +88,12 @@ final class WebSocketMessageBrokerSecurityConfiguration
 
 	@Override
 	public void configureClientInboundChannel(ChannelRegistration registration) {
+		ChannelInterceptor csrfChannelInterceptor = getBeanOrNull(CSRF_CHANNEL_INTERCEPTOR_BEAN_NAME,
+				ChannelInterceptor.class);
+		if (csrfChannelInterceptor != null) {
+			this.csrfChannelInterceptor = csrfChannelInterceptor;
+		}
+
 		this.authorizationChannelInterceptor
 				.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context));
 		this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);

+ 5 - 3
config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2019 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -61,6 +61,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
 import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.csrf.MissingCsrfTokenException;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.util.ReflectionTestUtils;
@@ -79,6 +80,7 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
 
@@ -284,7 +286,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
 
 	private void assertHandshake(HttpServletRequest request) {
 		TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
-		assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
+		assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
 		assertThat(handshakeHandler.attributes.get(this.sessionAttr))
 				.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
 	}
@@ -306,7 +308,7 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
 		request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
 		request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
 		request.getSession().setAttribute(this.sessionAttr, "sessionValue");
-		request.setAttribute(CsrfToken.class.getName(), this.token);
+		request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
 		return request;
 	}
 

+ 43 - 0
config/src/test/java/org/springframework/security/config/annotation/web/socket/TestDeferredCsrfToken.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.web.socket;
+
+import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
+
+/**
+ * @author Steve Riesenberg
+ */
+final class TestDeferredCsrfToken implements DeferredCsrfToken {
+
+	private final CsrfToken csrfToken;
+
+	TestDeferredCsrfToken(CsrfToken csrfToken) {
+		this.csrfToken = csrfToken;
+	}
+
+	@Override
+	public CsrfToken get() {
+		return this.csrfToken;
+	}
+
+	@Override
+	public boolean isGenerated() {
+		return false;
+	}
+
+}

+ 5 - 3
config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -70,6 +70,7 @@ import org.springframework.security.messaging.context.SecurityContextChannelInte
 import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.csrf.MissingCsrfTokenException;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.util.ReflectionTestUtils;
@@ -92,6 +93,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.fail;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.verify;
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 public class WebSocketMessageBrokerSecurityConfigurationTests {
 
@@ -367,7 +369,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
 
 	private void assertHandshake(HttpServletRequest request) {
 		TestHandshakeHandler handshakeHandler = this.context.getBean(TestHandshakeHandler.class);
-		assertThat(handshakeHandler.attributes.get(CsrfToken.class.getName())).isSameAs(this.token);
+		assertThatCsrfToken(handshakeHandler.attributes.get(CsrfToken.class.getName())).isEqualTo(this.token);
 		assertThat(handshakeHandler.attributes.get(this.sessionAttr))
 				.isEqualTo(request.getSession().getAttribute(this.sessionAttr));
 	}
@@ -389,7 +391,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
 		request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
 		request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
 		request.getSession().setAttribute(this.sessionAttr, "sessionValue");
-		request.setAttribute(CsrfToken.class.getName(), this.token);
+		request.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token));
 		return request;
 	}
 

+ 31 - 6
config/src/test/java/org/springframework/security/config/websocket/WebSocketMessageBrokerConfigTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -61,6 +61,7 @@ import org.springframework.security.test.context.annotation.SecurityTestExecutio
 import org.springframework.security.test.context.support.WithMockUser;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.csrf.InvalidCsrfTokenException;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -77,6 +78,7 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.verify;
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 
 /**
@@ -381,12 +383,14 @@ public class WebSocketMessageBrokerConfigTests {
 		MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
 		String csrfAttributeName = CsrfToken.class.getName();
 		String customAttributeName = this.getClass().getName();
-		MvcResult result = mvc.perform(get("/app").requestAttr(csrfAttributeName, this.token)
-				.sessionAttr(customAttributeName, "attributeValue")).andReturn();
+		MvcResult result = mvc.perform(
+				get("/app").requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
+						.sessionAttr(customAttributeName, "attributeValue"))
+				.andReturn();
 		CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
 		String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
 		String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
-		assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
+		assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
 		assertThat(handshakeValue).isEqualTo(sessionValue)
 				.withFailMessage("Explicitly listed session variables are not overridden");
 	}
@@ -398,12 +402,13 @@ public class WebSocketMessageBrokerConfigTests {
 		MockMvc mvc = MockMvcBuilders.webAppContextSetup(context).build();
 		String csrfAttributeName = CsrfToken.class.getName();
 		String customAttributeName = this.getClass().getName();
-		MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket").requestAttr(csrfAttributeName, this.token)
+		MvcResult result = mvc.perform(get("/app/289/tpyx6mde/websocket")
+				.requestAttr(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(this.token))
 				.sessionAttr(customAttributeName, "attributeValue")).andReturn();
 		CsrfToken handshakeToken = (CsrfToken) this.testHandshakeHandler.attributes.get(csrfAttributeName);
 		String handshakeValue = (String) this.testHandshakeHandler.attributes.get(customAttributeName);
 		String sessionValue = (String) result.getRequest().getSession().getAttribute(customAttributeName);
-		assertThat(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
+		assertThatCsrfToken(handshakeToken).isEqualTo(this.token).withFailMessage("CsrfToken is populated");
 		assertThat(handshakeValue).isEqualTo(sessionValue)
 				.withFailMessage("Explicitly listed session variables are not overridden");
 	}
@@ -526,6 +531,26 @@ public class WebSocketMessageBrokerConfigTests {
 		return SecurityContextHolder.getContextHolderStrategy();
 	}
 
+	private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
+
+		private final CsrfToken csrfToken;
+
+		TestDeferredCsrfToken(CsrfToken csrfToken) {
+			this.csrfToken = csrfToken;
+		}
+
+		@Override
+		public CsrfToken get() {
+			return this.csrfToken;
+		}
+
+		@Override
+		public boolean isGenerated() {
+			return false;
+		}
+
+	}
+
 	@Controller
 	static class MessageController {
 

+ 86 - 0
messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptor.java

@@ -0,0 +1,86 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.messaging.web.csrf;
+
+import java.security.MessageDigest;
+import java.util.Map;
+
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageChannel;
+import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
+import org.springframework.messaging.simp.SimpMessageType;
+import org.springframework.messaging.support.ChannelInterceptor;
+import org.springframework.security.crypto.codec.Utf8;
+import org.springframework.security.messaging.util.matcher.MessageMatcher;
+import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher;
+import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.InvalidCsrfTokenException;
+import org.springframework.security.web.csrf.MissingCsrfTokenException;
+
+/**
+ * {@link ChannelInterceptor} that validates a CSRF token masked by the
+ * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler} in
+ * the header of any {@link SimpMessageType#CONNECT} message.
+ *
+ * @author Steve Riesenberg
+ * @since 5.8
+ */
+public final class XorCsrfChannelInterceptor implements ChannelInterceptor {
+
+	private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT);
+
+	@Override
+	public Message<?> preSend(Message<?> message, MessageChannel channel) {
+		if (!this.matcher.matches(message)) {
+			return message;
+		}
+		Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
+		CsrfToken expectedToken = (sessionAttributes != null)
+				? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
+		if (expectedToken == null) {
+			throw new MissingCsrfTokenException(null);
+		}
+		String actualToken = SimpMessageHeaderAccessor.wrap(message)
+				.getFirstNativeHeader(expectedToken.getHeaderName());
+		String actualTokenValue = XorCsrfTokenUtils.getTokenValue(actualToken, expectedToken.getToken());
+		boolean csrfCheckPassed = equalsConstantTime(expectedToken.getToken(), actualTokenValue);
+		if (!csrfCheckPassed) {
+			throw new InvalidCsrfTokenException(expectedToken, actualToken);
+		}
+		return message;
+	}
+
+	/**
+	 * Constant time comparison to prevent against timing attacks.
+	 * @param expected
+	 * @param actual
+	 * @return
+	 */
+	private static boolean equalsConstantTime(String expected, String actual) {
+		if (expected == actual) {
+			return true;
+		}
+		if (expected == null || actual == null) {
+			return false;
+		}
+		// Encode after ensure that the string is not null
+		byte[] expectedBytes = Utf8.encode(expected);
+		byte[] actualBytes = Utf8.encode(actual);
+		return MessageDigest.isEqual(expectedBytes, actualBytes);
+	}
+
+}

+ 72 - 0
messaging/src/main/java/org/springframework/security/messaging/web/csrf/XorCsrfTokenUtils.java

@@ -0,0 +1,72 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.messaging.web.csrf;
+
+import java.util.Base64;
+
+import org.springframework.security.crypto.codec.Utf8;
+
+/**
+ * Copied from
+ * {@link org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler}.
+ *
+ * @see <a href=
+ * "https://github.com/spring-projects/spring-security/issues/12378">gh-12378</a>
+ */
+final class XorCsrfTokenUtils {
+
+	private XorCsrfTokenUtils() {
+	}
+
+	static String getTokenValue(String actualToken, String token) {
+		byte[] actualBytes;
+		try {
+			actualBytes = Base64.getUrlDecoder().decode(actualToken);
+		}
+		catch (Exception ex) {
+			return null;
+		}
+
+		byte[] tokenBytes = Utf8.encode(token);
+		int tokenSize = tokenBytes.length;
+		if (actualBytes.length < tokenSize) {
+			return null;
+		}
+
+		// extract token and random bytes
+		int randomBytesSize = actualBytes.length - tokenSize;
+		byte[] xoredCsrf = new byte[tokenSize];
+		byte[] randomBytes = new byte[randomBytesSize];
+
+		System.arraycopy(actualBytes, 0, randomBytes, 0, randomBytesSize);
+		System.arraycopy(actualBytes, randomBytesSize, xoredCsrf, 0, tokenSize);
+
+		byte[] csrfBytes = xorCsrf(randomBytes, xoredCsrf);
+		return Utf8.decode(csrfBytes);
+	}
+
+	private static byte[] xorCsrf(byte[] randomBytes, byte[] csrfBytes) {
+		int len = Math.min(randomBytes.length, csrfBytes.length);
+		byte[] xoredCsrf = new byte[len];
+		System.arraycopy(csrfBytes, 0, xoredCsrf, 0, csrfBytes.length);
+		for (int i = 0; i < len; i++) {
+			xoredCsrf[i] ^= randomBytes[i];
+		}
+		return xoredCsrf;
+	}
+
+}

+ 18 - 7
messaging/src/main/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptor.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2015 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,15 +24,18 @@ import org.springframework.http.server.ServerHttpRequest;
 import org.springframework.http.server.ServerHttpResponse;
 import org.springframework.http.server.ServletServerHttpRequest;
 import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.web.socket.WebSocketHandler;
 import org.springframework.web.socket.server.HandshakeInterceptor;
 
 /**
- * Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket
- * attributes. This is used as the expected CsrfToken when validating connection requests
- * to ensure only the same origin connects.
+ * Loads a CsrfToken from the HttpServletRequest and HttpServletResponse to populate the
+ * WebSocket attributes. This is used as the expected CsrfToken when validating connection
+ * requests to ensure only the same origin connects.
  *
  * @author Rob Winch
+ * @author Steve Riesenberg
  * @since 4.0
  */
 public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor {
@@ -41,11 +44,19 @@ public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor
 	public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler,
 			Map<String, Object> attributes) {
 		HttpServletRequest httpRequest = ((ServletServerHttpRequest) request).getServletRequest();
-		CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName());
-		if (token == null) {
+		DeferredCsrfToken deferredCsrfToken = (DeferredCsrfToken) httpRequest
+				.getAttribute(DeferredCsrfToken.class.getName());
+		if (deferredCsrfToken == null) {
 			return true;
 		}
-		attributes.put(CsrfToken.class.getName(), token);
+		CsrfToken csrfToken = deferredCsrfToken.get();
+		// Ensure the values of the CsrfToken are copied into a new token so the old token
+		// is available for garbage collection.
+		// This is required because the original token could hold a reference to the
+		// HttpServletRequest/Response of the handshake request.
+		CsrfToken resolvedCsrfToken = new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(),
+				csrfToken.getToken());
+		attributes.put(CsrfToken.class.getName(), resolvedCsrfToken);
 		return true;
 	}
 

+ 148 - 0
messaging/src/test/java/org/springframework/security/messaging/web/csrf/XorCsrfChannelInterceptorTests.java

@@ -0,0 +1,148 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.messaging.web.csrf;
+
+import java.util.HashMap;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.messaging.Message;
+import org.springframework.messaging.MessageChannel;
+import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
+import org.springframework.messaging.simp.SimpMessageType;
+import org.springframework.messaging.support.MessageBuilder;
+import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.web.csrf.InvalidCsrfTokenException;
+import org.springframework.security.web.csrf.MissingCsrfTokenException;
+
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link XorCsrfChannelInterceptor}.
+ *
+ * @author Steve Riesenberg
+ */
+public class XorCsrfChannelInterceptorTests {
+
+	private static final String XOR_CSRF_TOKEN_VALUE = "wpe7zB62-NCpcA==";
+
+	private static final String INVALID_XOR_CSRF_TOKEN_VALUE = "KneoaygbRZtfHQ==";
+
+	private CsrfToken token;
+
+	private SimpMessageHeaderAccessor messageHeaders;
+
+	private MessageChannel channel;
+
+	private XorCsrfChannelInterceptor interceptor;
+
+	@BeforeEach
+	public void setup() {
+		this.token = new DefaultCsrfToken("header", "param", "token");
+		this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
+		this.messageHeaders.setSessionAttributes(new HashMap<>());
+		this.channel = mock(MessageChannel.class);
+		this.interceptor = new XorCsrfChannelInterceptor();
+	}
+
+	@Test
+	public void preSendWhenConnectWithValidTokenThenSuccess() {
+		this.messageHeaders.setNativeHeader(this.token.getHeaderName(), XOR_CSRF_TOKEN_VALUE);
+		this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
+		this.interceptor.preSend(message(), this.channel);
+	}
+
+	@Test
+	public void preSendWhenConnectWithInvalidTokenThenThrowsInvalidCsrfTokenException() {
+		this.messageHeaders.setNativeHeader(this.token.getHeaderName(), INVALID_XOR_CSRF_TOKEN_VALUE);
+		this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
+		// @formatter:off
+		assertThatExceptionOfType(InvalidCsrfTokenException.class)
+				.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
+		// @formatter:on
+	}
+
+	@Test
+	public void preSendWhenConnectWithNoTokenThenThrowsInvalidCsrfTokenException() {
+		this.messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
+		// @formatter:off
+		assertThatExceptionOfType(InvalidCsrfTokenException.class)
+				.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
+		// @formatter:on
+	}
+
+	@Test
+	public void preSendWhenConnectWithMissingTokenThenThrowsMissingCsrfTokenException() {
+		// @formatter:off
+		assertThatExceptionOfType(MissingCsrfTokenException.class)
+				.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
+		// @formatter:on
+	}
+
+	@Test
+	public void preSendWhenConnectWithNullSessionAttributesThenThrowsMissingCsrfTokenException() {
+		this.messageHeaders.setSessionAttributes(null);
+		// @formatter:off
+		assertThatExceptionOfType(MissingCsrfTokenException.class)
+				.isThrownBy(() -> this.interceptor.preSend(message(), mock(MessageChannel.class)));
+		// @formatter:on
+	}
+
+	@Test
+	public void preSendWhenAckThenIgnores() {
+		this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK);
+		this.interceptor.preSend(message(), this.channel);
+	}
+
+	@Test
+	public void preSendWhenDisconnectThenIgnores() {
+		this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT);
+		this.interceptor.preSend(message(), this.channel);
+	}
+
+	@Test
+	public void preSendWhenHeartbeatThenIgnores() {
+		this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT);
+		this.interceptor.preSend(message(), this.channel);
+	}
+
+	@Test
+	public void preSendWhenMessageThenIgnores() {
+		this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE);
+		this.interceptor.preSend(message(), this.channel);
+	}
+
+	@Test
+	public void preSendWhenOtherThenIgnores() {
+		this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER);
+		this.interceptor.preSend(message(), this.channel);
+	}
+
+	@Test
+	public void preSendWhenUnsubscribeThenIgnores() {
+		this.messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE);
+		this.interceptor.preSend(message(), this.channel);
+	}
+
+	private Message<String> message() {
+		return MessageBuilder.withPayload("message").copyHeaders(this.messageHeaders.toMap()).build();
+	}
+
+}

+ 32 - 3
messaging/src/test/java/org/springframework/security/messaging/web/socket/server/CsrfTokenHandshakeInterceptorTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -31,6 +31,7 @@ import org.springframework.http.server.ServletServerHttpRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.web.socket.WebSocketHandler;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -72,10 +73,38 @@ public class CsrfTokenHandshakeInterceptorTests {
 	@Test
 	public void beforeHandshake() throws Exception {
 		CsrfToken token = new DefaultCsrfToken("header", "param", "token");
-		this.httpRequest.setAttribute(CsrfToken.class.getName(), token);
+		this.httpRequest.setAttribute(DeferredCsrfToken.class.getName(), new TestDeferredCsrfToken(token));
 		this.interceptor.beforeHandshake(this.request, this.response, this.wsHandler, this.attributes);
 		assertThat(this.attributes.keySet()).containsOnly(CsrfToken.class.getName());
-		assertThat(this.attributes.values()).containsOnly(token);
+		CsrfToken csrfToken = (CsrfToken) this.attributes.get(CsrfToken.class.getName());
+		assertThat(csrfToken.getHeaderName()).isEqualTo(token.getHeaderName());
+		assertThat(csrfToken.getParameterName()).isEqualTo(token.getParameterName());
+		assertThat(csrfToken.getToken()).isEqualTo(token.getToken());
+		// Ensure the values of the CsrfToken are copied into a new token so the old token
+		// is available for garbage collection.
+		// This is required because the original token could hold a reference to the
+		// HttpServletRequest/Response of the handshake request.
+		assertThat(csrfToken).isNotSameAs(token);
+	}
+
+	private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
+
+		private final CsrfToken csrfToken;
+
+		private TestDeferredCsrfToken(CsrfToken csrfToken) {
+			this.csrfToken = csrfToken;
+		}
+
+		@Override
+		public CsrfToken get() {
+			return this.csrfToken;
+		}
+
+		@Override
+		public boolean isGenerated() {
+			return false;
+		}
+
 	}
 
 }

+ 2 - 1
web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -108,6 +108,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
 		DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
+		request.setAttribute(DeferredCsrfToken.class.getName(), deferredCsrfToken);
 		this.requestHandler.handle(request, response, deferredCsrfToken::get);
 		if (!this.requireCsrfProtectionMatcher.matches(request)) {
 			if (this.logger.isTraceEnabled()) {

+ 49 - 33
web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -126,11 +126,12 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -138,12 +139,13 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -151,12 +153,13 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -165,13 +168,14 @@ public class CsrfFilterTests {
 	public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
 			throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -179,11 +183,12 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(false);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -191,11 +196,12 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(false);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, true));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -203,12 +209,13 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -217,13 +224,14 @@ public class CsrfFilterTests {
 	public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
 			throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -231,12 +239,13 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 		verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
@@ -246,12 +255,13 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, true));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, true);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		// LazyCsrfTokenRepository requires the response as an attribute
 		assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
 		verify(this.filterChain).doFilter(this.request, this.response);
@@ -316,11 +326,12 @@ public class CsrfFilterTests {
 		this.filter = new CsrfFilter(this.tokenRepository);
 		this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -344,22 +355,24 @@ public class CsrfFilterTests {
 		given(token.getToken()).willReturn(null);
 		given(token.getHeaderName()).willReturn(this.token.getHeaderName());
 		given(token.getParameterName()).willReturn(this.token.getParameterName());
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
 		filter.doFilterInternal(this.request, this.response, this.filterChain);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
 	}
 
 	@Test
 	public void doFilterWhenRequestHandlerThenUsed() throws Exception {
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
 		this.filter = createCsrfFilter(this.tokenRepository);
 		this.filter.setRequestHandler(requestHandler);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
 		verify(requestHandler).handle(eq(this.request), eq(this.response), any());
 		verify(this.filterChain).doFilter(this.request, this.response);
@@ -368,14 +381,15 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() throws Exception {
 		given(this.requestMatcher.matches(this.request)).willReturn(false);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
 		requestHandler.setCsrfRequestAttributeName(this.token.getParameterName());
 		this.filter.setRequestHandler(requestHandler);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isNotNull();
 		assertThat(this.request.getAttribute(this.token.getParameterName())).isNotNull();
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
 
@@ -394,12 +408,13 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterWhenXorCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() throws Exception {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(this.token, false);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 		XorCsrfTokenRequestAttributeHandler requestHandler = new XorCsrfTokenRequestAttributeHandler();
 		this.filter.setRequestHandler(requestHandler);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(AccessDeniedException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -424,10 +439,11 @@ public class CsrfFilterTests {
 		requestHandler.setCsrfRequestAttributeName(csrfAttrName);
 		filter.setRequestHandler(requestHandler);
 		CsrfToken expectedCsrfToken = mock(CsrfToken.class);
-		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
-				.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
+		DeferredCsrfToken deferredCsrfToken = new TestDeferredCsrfToken(expectedCsrfToken, true);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response)).willReturn(deferredCsrfToken);
 
 		filter.doFilter(this.request, this.response, this.filterChain);
+		assertThat(this.request.getAttribute(DeferredCsrfToken.class.getName())).isSameAs(deferredCsrfToken);
 
 		verifyNoInteractions(expectedCsrfToken);
 		CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);