瀏覽代碼

Merge branch '5.8.x'

Closes gh-11894
Rob Winch 2 年之前
父節點
當前提交
0efe26c1fd
共有 30 個文件被更改,包括 421 次插入329 次删除
  1. 11 12
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java
  2. 5 5
      config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java
  3. 2 2
      config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc
  4. 2 2
      config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd
  5. 2 2
      config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc
  6. 2 2
      config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd
  7. 2 3
      config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java
  8. 9 15
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java
  9. 7 3
      config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java
  10. 2 3
      config/src/test/kotlin/org/springframework/security/config/annotation/web/CsrfDslTests.kt
  11. 2 2
      config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml
  12. 2 0
      config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml
  13. 3 3
      docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc
  14. 1 0
      etc/checkstyle/checkstyle.xml
  15. 44 37
      test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java
  16. 9 7
      test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java
  17. 3 3
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java
  18. 3 3
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java
  19. 0 76
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java
  20. 5 7
      test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java
  21. 13 15
      web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java
  22. 11 19
      web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java
  23. 4 7
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java
  24. 97 9
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java
  25. 41 0
      web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java
  26. 3 0
      web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java
  27. 17 18
      web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java
  28. 50 50
      web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
  29. 48 0
      web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java
  30. 21 24
      web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java

+ 11 - 12
config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java

@@ -36,7 +36,7 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfLogoutHandler;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
 import org.springframework.security.web.csrf.CsrfTokenRequestResolver;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
@@ -91,7 +91,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 
 	private SessionAuthenticationStrategy sessionAuthenticationStrategy;
 
-	private CsrfTokenRequestAttributeHandler requestAttributeHandler;
+	private CsrfTokenRequestHandler requestHandler;
 
 	private CsrfTokenRequestResolver requestResolver;
 
@@ -131,14 +131,13 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 	}
 
 	/**
-	 * Specify a {@link CsrfTokenRequestAttributeHandler} to use for making the
-	 * {@code CsrfToken} available as a request attribute.
-	 * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use
+	 * Specify a {@link CsrfTokenRequestHandler} to use for making the {@code CsrfToken}
+	 * available as a request attribute.
+	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use
 	 * @return the {@link CsrfConfigurer} for further customizations
 	 */
-	public CsrfConfigurer<H> csrfTokenRequestAttributeHandler(
-			CsrfTokenRequestAttributeHandler requestAttributeHandler) {
-		this.requestAttributeHandler = requestAttributeHandler;
+	public CsrfConfigurer<H> csrfTokenRequestHandler(CsrfTokenRequestHandler requestHandler) {
+		this.requestHandler = requestHandler;
 		return this;
 	}
 
@@ -247,8 +246,8 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 		if (sessionConfigurer != null) {
 			sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy());
 		}
-		if (this.requestAttributeHandler != null) {
-			filter.setRequestAttributeHandler(this.requestAttributeHandler);
+		if (this.requestHandler != null) {
+			filter.setRequestHandler(this.requestHandler);
 		}
 		if (this.requestResolver != null) {
 			filter.setRequestResolver(this.requestResolver);
@@ -343,8 +342,8 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 		}
 		CsrfAuthenticationStrategy csrfAuthenticationStrategy = new CsrfAuthenticationStrategy(
 				this.csrfTokenRepository);
-		if (this.requestAttributeHandler != null) {
-			csrfAuthenticationStrategy.setRequestAttributeHandler(this.requestAttributeHandler);
+		if (this.requestHandler != null) {
+			csrfAuthenticationStrategy.setRequestHandler(this.requestHandler);
 		}
 		return csrfAuthenticationStrategy;
 	}

+ 5 - 5
config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java

@@ -70,7 +70,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 
 	private static final String ATT_REPOSITORY = "token-repository-ref";
 
-	private static final String ATT_REQUEST_ATTRIBUTE_HANDLER = "request-attribute-handler-ref";
+	private static final String ATT_REQUEST_HANDLER = "request-handler-ref";
 
 	private static final String ATT_REQUEST_RESOLVER = "request-resolver-ref";
 
@@ -80,7 +80,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 
 	private String requestMatcherRef;
 
-	private String requestAttributeHandlerRef;
+	private String requestHandlerRef;
 
 	private String requestResolverRef;
 
@@ -102,7 +102,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 		if (element != null) {
 			this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY);
 			this.requestMatcherRef = element.getAttribute(ATT_MATCHER);
-			this.requestAttributeHandlerRef = element.getAttribute(ATT_REQUEST_ATTRIBUTE_HANDLER);
+			this.requestHandlerRef = element.getAttribute(ATT_REQUEST_HANDLER);
 			this.requestResolverRef = element.getAttribute(ATT_REQUEST_RESOLVER);
 		}
 		if (!StringUtils.hasText(this.csrfRepositoryRef)) {
@@ -119,8 +119,8 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 		if (StringUtils.hasText(this.requestMatcherRef)) {
 			builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
 		}
-		if (StringUtils.hasText(this.requestAttributeHandlerRef)) {
-			builder.addPropertyReference("requestAttributeHandler", this.requestAttributeHandlerRef);
+		if (StringUtils.hasText(this.requestHandlerRef)) {
+			builder.addPropertyReference("requestHandler", this.requestHandlerRef);
 		}
 		if (StringUtils.hasText(this.requestResolverRef)) {
 			builder.addPropertyReference("requestResolver", this.requestResolverRef);

+ 2 - 2
config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc

@@ -1152,8 +1152,8 @@ csrf-options.attlist &=
 	## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository.
 	attribute token-repository-ref { xsd:token }?
 csrf-options.attlist &=
-	## The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor.
-	attribute request-attribute-handler-ref { xsd:token }?
+	## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
+	attribute request-handler-ref { xsd:token }?
 csrf-options.attlist &=
 	## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor.
 	attribute request-resolver-ref { xsd:token }?

+ 2 - 2
config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd

@@ -3256,9 +3256,9 @@
                 </xs:documentation>
          </xs:annotation>
       </xs:attribute>
-      <xs:attribute name="request-attribute-handler-ref" type="xs:token">
+      <xs:attribute name="request-handler-ref" type="xs:token">
          <xs:annotation>
-            <xs:documentation>The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor.
+            <xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
                 </xs:documentation>
          </xs:annotation>
       </xs:attribute>

+ 2 - 2
config/src/main/resources/org/springframework/security/config/spring-security-6.0.rnc

@@ -1124,8 +1124,8 @@ csrf-options.attlist &=
 	## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository.
 	attribute token-repository-ref { xsd:token }?
 csrf-options.attlist &=
-	## The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor.
-	attribute request-attribute-handler-ref { xsd:token }?
+	## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
+	attribute request-handler-ref { xsd:token }?
 csrf-options.attlist &=
 	## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor.
 	attribute request-resolver-ref { xsd:token }?

+ 2 - 2
config/src/main/resources/org/springframework/security/config/spring-security-6.0.xsd

@@ -3166,9 +3166,9 @@
                 </xs:documentation>
          </xs:annotation>
       </xs:attribute>
-      <xs:attribute name="request-attribute-handler-ref" type="xs:token">
+      <xs:attribute name="request-handler-ref" type="xs:token">
          <xs:annotation>
-            <xs:documentation>The CsrfTokenRequestAttributeHandler to use. The default is CsrfTokenRequestProcessor.
+            <xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
                 </xs:documentation>
          </xs:annotation>
       </xs:attribute>

+ 2 - 3
config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java

@@ -32,8 +32,8 @@ import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.web.DefaultSecurityFilterChain;
 import org.springframework.security.web.FilterChainProxy;
+import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
-import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
 
 import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Mockito.never;
@@ -78,8 +78,7 @@ public class DeferHttpSessionJavaConfigTests {
 
 		@Bean
 		DefaultSecurityFilterChain springSecurity(HttpSecurity http) throws Exception {
-			LazyCsrfTokenRepository csrfRepository = new LazyCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
-			csrfRepository.setDeferLoadToken(true);
+			CsrfTokenRepository csrfRepository = new HttpSessionCsrfTokenRepository();
 			// @formatter:off
 			http
 				.authorizeHttpRequests((requests) -> requests

+ 9 - 15
config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java

@@ -67,7 +67,6 @@ import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 import static org.springframework.security.config.Customizer.withDefaults;
@@ -234,8 +233,6 @@ public class CsrfConfigurerTests {
 		this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
 				.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
 				.andExpect(redirectedUrl(redirectUrl));
-		verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
-				.loadToken(any(HttpServletRequest.class));
 	}
 
 	// SEC-2422
@@ -280,12 +277,12 @@ public class CsrfConfigurerTests {
 	}
 
 	@Test
-	public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception {
+	public void postWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception {
 		CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
 		given(CsrfTokenRepositoryConfig.REPO.loadToken(any()))
 				.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
 		this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
-		this.mvc.perform(get("/")).andExpect(status().isOk());
+		this.mvc.perform(post("/"));
 		verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class));
 	}
 
@@ -322,7 +319,7 @@ public class CsrfConfigurerTests {
 		given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any()))
 				.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
 		this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire();
-		this.mvc.perform(get("/")).andExpect(status().isOk());
+		this.mvc.perform(post("/"));
 		verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class));
 	}
 
@@ -427,8 +424,8 @@ public class CsrfConfigurerTests {
 		CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
 		CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
 		given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
-		CsrfTokenRequestProcessorConfig.REPO = csrfTokenRepository;
 		CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor();
+		CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository);
 		this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
 		this.mvc.perform(get("/login")).andExpect(status().isOk())
 				.andExpect(content().string(containsString(csrfToken.getToken())));
@@ -443,10 +440,11 @@ public class CsrfConfigurerTests {
 	public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception {
 		CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
 		CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
-		given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(csrfToken);
+		given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken);
 		given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
-		CsrfTokenRequestProcessorConfig.REPO = csrfTokenRepository;
 		CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor();
+		CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository);
+
 		this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
 		// @formatter:off
 		MockHttpServletRequestBuilder loginRequest = post("/login")
@@ -455,8 +453,7 @@ public class CsrfConfigurerTests {
 				.param("password", "password");
 		// @formatter:on
 		this.mvc.perform(loginRequest).andExpect(redirectedUrl("/"));
-		verify(csrfTokenRepository, times(2)).loadToken(any(HttpServletRequest.class));
-		verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class));
 		verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
 		verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
 				any(HttpServletResponse.class));
@@ -826,8 +823,6 @@ public class CsrfConfigurerTests {
 	@EnableWebSecurity
 	static class CsrfTokenRequestProcessorConfig {
 
-		static CsrfTokenRepository REPO;
-
 		static CsrfTokenRequestProcessor PROCESSOR;
 
 		@Bean
@@ -839,8 +834,7 @@ public class CsrfConfigurerTests {
 				)
 				.formLogin(Customizer.withDefaults())
 				.csrf((csrf) -> csrf
-					.csrfTokenRepository(REPO)
-					.csrfTokenRequestAttributeHandler(PROCESSOR)
+					.csrfTokenRequestHandler(PROCESSOR)
 					.csrfTokenRequestResolver(PROCESSOR)
 				);
 			// @formatter:on

+ 7 - 3
config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java

@@ -29,6 +29,7 @@ import org.junit.jupiter.api.extension.ExtendWith;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.http.HttpMethod;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.config.test.SpringTestContext;
@@ -41,6 +42,7 @@ import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -544,8 +546,9 @@ public class CsrfConfigTests {
 		@Override
 		public void match(MvcResult result) {
 			MockHttpServletRequest request = result.getRequest();
-			CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
-			assertThat(token).isNotNull();
+			MockHttpServletResponse response = result.getResponse();
+			DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response);
+			assertThat(token.isGenerated()).isFalse();
 		}
 
 	}
@@ -561,7 +564,8 @@ public class CsrfConfigTests {
 		@Override
 		public void match(MvcResult result) throws Exception {
 			MockHttpServletRequest request = result.getRequest();
-			CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
+			MockHttpServletResponse response = result.getResponse();
+			CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get();
 			assertThat(token).isNotNull();
 			assertThat(token.getToken()).isEqualTo(this.token.apply(result));
 		}

+ 2 - 3
config/src/test/kotlin/org/springframework/security/config/annotation/web/CsrfDslTests.kt

@@ -41,7 +41,6 @@ import org.springframework.security.web.csrf.DefaultCsrfToken
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher
 import org.springframework.test.web.servlet.MockMvc
-import org.springframework.test.web.servlet.get
 import org.springframework.test.web.servlet.post
 import org.springframework.web.bind.annotation.PostMapping
 import org.springframework.web.bind.annotation.RestController
@@ -125,9 +124,9 @@ class CsrfDslTests {
             CustomRepositoryConfig.REPO.loadToken(any())
         } returns DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")
 
-        this.mockMvc.get("/test1")
+		this.mockMvc.post("/test1")
 
-        verify(exactly = 1) { CustomRepositoryConfig.REPO.loadToken(any()) }
+		verify(exactly = 1) { CustomRepositoryConfig.REPO.loadToken(any()) }
     }
 
     @Configuration

+ 2 - 2
config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml

@@ -23,10 +23,10 @@
 		http://www.springframework.org/schema/beans https://www.springframework.org/schema/beans/spring-beans.xsd">
 
 	<http auto-config="true">
-		<csrf request-attribute-handler-ref="requestAttributeHandler"/>
+		<csrf request-handler-ref="requestHandler"/>
 	</http>
 
-	<b:bean id="requestAttributeHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestProcessor"
+	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestProcessor"
 		p:csrfRequestAttributeName="csrf-attribute-name"/>
 	<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
 </b:beans>

+ 2 - 0
config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml

@@ -40,5 +40,7 @@
 	<b:bean id="csrfRepository" class="org.springframework.security.web.csrf.LazyCsrfTokenRepository"
 		c:delegate-ref="httpSessionCsrfRepository"
 	 	p:deferLoadToken="true"/>
+	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestProcessor"
+		p:csrfRequestAttributeName="_csrf"/>
 	<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
 </b:beans>

+ 3 - 3
docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc

@@ -774,9 +774,9 @@ It is highly recommended to leave CSRF protection enabled.
 The CsrfTokenRepository to use.
 The default is `HttpSessionCsrfTokenRepository`.
 
-[[nsa-csrf-request-attribute-handler-ref]]
-* **request-attribute-handler-ref**
-The optional `CsrfTokenRequestAttributeHandler` to use. The default is `CsrfTokenRequestProcessor`.
+[[nsa-csrf-request-handler-ref]]
+* **request-handler-ref**
+The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestProcessor`.
 
 [[nsa-csrf-request-resolver-ref]]
 * **request-resolver-ref**

+ 1 - 0
etc/checkstyle/checkstyle.xml

@@ -17,6 +17,7 @@
 		<property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*" />
 		<property name="avoidStaticImportExcludes" value="org.springframework.security.test.web.servlet.response.SecurityMockMvcResultHandlers.*" />
 		<property name="avoidStaticImportExcludes" value="org.springframework.security.config.annotation.SecurityContextChangedListenerArgumentMatchers.*" />
+		<property name="avoidStaticImportExcludes" value="org.springframework.security.web.csrf.CsrfTokenAssert.*" />
 	</module>
 	<module name="com.puppycrawl.tools.checkstyle.TreeWalker">
  		<module name="com.puppycrawl.tools.checkstyle.checks.regexp.RegexpSinglelineJavaCheck">

+ 44 - 37
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -94,7 +94,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfToken;
-import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.servlet.MockMvc;
@@ -508,14 +509,13 @@ public final class SecurityMockMvcRequestPostProcessors {
 
 		@Override
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
-			CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
-			if (!(repository instanceof TestCsrfTokenRepository)) {
-				repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
-				WebTestUtils.setCsrfTokenRepository(request, repository);
+			CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request);
+			if (!(handler instanceof TestCsrfTokenRequestHandler)) {
+				handler = new TestCsrfTokenRequestHandler(handler);
+				WebTestUtils.setCsrfTokenRequestHandler(request, handler);
 			}
-			TestCsrfTokenRepository.enable(request);
-			CsrfToken token = repository.generateToken(request);
-			repository.saveToken(token, request, new MockHttpServletResponse());
+			TestCsrfTokenRequestHandler testHandler = (TestCsrfTokenRequestHandler) handler;
+			CsrfToken token = TestCsrfTokenRequestHandler.createTestCsrfToken(request);
 			String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken();
 			if (this.asHeader) {
 				request.addHeader(token.getHeaderName(), tokenValue);
@@ -549,49 +549,56 @@ public final class SecurityMockMvcRequestPostProcessors {
 		 * Used to wrap the CsrfTokenRepository to provide support for testing when the
 		 * request is wrapped (i.e. Spring Session is in use).
 		 */
-		static class TestCsrfTokenRepository implements CsrfTokenRepository {
+		static class TestCsrfTokenRequestHandler implements CsrfTokenRequestHandler {
 
-			static final String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".TOKEN");
+			static final String TOKEN_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".TOKEN");
 
-			static final String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".ENABLED");
+			static final String ENABLED_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".ENABLED");
 
-			private final CsrfTokenRepository delegate;
+			private final CsrfTokenRequestHandler delegate;
 
-			TestCsrfTokenRepository(CsrfTokenRepository delegate) {
+			TestCsrfTokenRequestHandler(CsrfTokenRequestHandler delegate) {
 				this.delegate = delegate;
 			}
 
-			@Override
-			public CsrfToken generateToken(HttpServletRequest request) {
-				return this.delegate.generateToken(request);
-			}
-
-			@Override
-			public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
-				if (isEnabled(request)) {
-					request.setAttribute(TOKEN_ATTR_NAME, token);
-				}
-				else {
-					this.delegate.saveToken(token, request, response);
+			static CsrfToken createTestCsrfToken(HttpServletRequest request) {
+				CsrfToken existingToken = getExistingToken(request);
+				if (existingToken != null) {
+					return existingToken;
 				}
+				HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository();
+				CsrfToken csrfToken = repository.generateToken(request);
+				request.setAttribute(ENABLED_ATTR_NAME, true);
+				request.setAttribute(TOKEN_ATTR_NAME, csrfToken);
+				return csrfToken;
 			}
 
-			@Override
-			public CsrfToken loadToken(HttpServletRequest request) {
-				if (isEnabled(request)) {
-					return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME);
-				}
-				else {
-					return this.delegate.loadToken(request);
-				}
+			private static CsrfToken getExistingToken(HttpServletRequest request) {
+				Object existingToken = request.getAttribute(TOKEN_ATTR_NAME);
+				return (CsrfToken) existingToken;
 			}
 
-			static void enable(HttpServletRequest request) {
-				request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE);
+			boolean isEnabled(HttpServletRequest request) {
+				return getExistingToken(request) != null;
 			}
 
-			boolean isEnabled(HttpServletRequest request) {
-				return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
+			@Override
+			public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) {
+				request.setAttribute(HttpServletResponse.class.getName(), response);
+				if (!isEnabled(request)) {
+					return this.delegate.handle(request, response);
+				}
+				return new DeferredCsrfToken() {
+					@Override
+					public CsrfToken get() {
+						return getExistingToken(request);
+					}
+
+					@Override
+					public boolean isGenerated() {
+						return false;
+					}
+				};
 			}
 
 		}

+ 9 - 7
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@@ -31,6 +31,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
+import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.web.context.WebApplicationContext;
@@ -46,7 +48,7 @@ public abstract class WebTestUtils {
 
 	private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository();
 
-	private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
+	private static final CsrfTokenRequestProcessor DEFAULT_CSRF_PROCESSOR = new CsrfTokenRequestProcessor();
 
 	private WebTestUtils() {
 	}
@@ -99,24 +101,24 @@ public abstract class WebTestUtils {
 	 * @return the {@link CsrfTokenRepository} for the specified
 	 * {@link HttpServletRequest}
 	 */
-	public static CsrfTokenRepository getCsrfTokenRepository(HttpServletRequest request) {
+	public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
 		CsrfFilter filter = findFilter(request, CsrfFilter.class);
 		if (filter == null) {
-			return DEFAULT_TOKEN_REPO;
+			return DEFAULT_CSRF_PROCESSOR;
 		}
-		return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository");
+		return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
 	}
 
 	/**
 	 * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}.
 	 * @param request the {@link HttpServletRequest} to obtain the
 	 * {@link CsrfTokenRepository}
-	 * @param repository the {@link CsrfTokenRepository} to set
+	 * @param handler the {@link CsrfTokenRepository} to set
 	 */
-	public static void setCsrfTokenRepository(HttpServletRequest request, CsrfTokenRepository repository) {
+	public static void setCsrfTokenRequestHandler(HttpServletRequest request, CsrfTokenRequestHandler handler) {
 		CsrfFilter filter = findFilter(request, CsrfFilter.class);
 		if (filter != null) {
-			ReflectionTestUtils.setField(filter, "tokenRepository", repository);
+			ReflectionTestUtils.setField(filter, "requestHandler", handler);
 		}
 	}
 

+ 3 - 3
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

@@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 	public void defaults() {
 		MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
 		assertThat(request.getParameter("username")).isEqualTo("user");
 		assertThat(request.getParameter("password")).isEqualTo("password");
 		assertThat(request.getMethod()).isEqualTo("POST");
@@ -67,7 +67,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 		MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret")
 				.buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getMethod()).isEqualTo("POST");
@@ -80,7 +80,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 		MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2")
 				.user("username", "admin").password("password", "secret").buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getMethod()).isEqualTo("POST");

+ 3 - 3
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

@@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	public void defaults() {
 		MockHttpServletRequest request = logout().buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/logout");
@@ -63,7 +63,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	public void custom() {
 		MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
@@ -74,7 +74,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 		MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2")
 				.buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");

+ 0 - 76
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java

@@ -1,76 +0,0 @@
-/*
- * Copyright 2002-2016 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.security.test.web.servlet.request;
-
-import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
-
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.context.annotation.Configuration;
-import org.springframework.mock.web.MockHttpServletRequest;
-import org.springframework.security.config.annotation.web.builders.HttpSecurity;
-import org.springframework.security.config.annotation.web.builders.WebSecurity;
-import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
-import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
-import org.springframework.security.test.web.support.WebTestUtils;
-import org.springframework.security.web.csrf.CookieCsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.test.context.ContextConfiguration;
-import org.springframework.test.context.junit.jupiter.SpringExtension;
-import org.springframework.test.context.web.WebAppConfiguration;
-import org.springframework.web.context.WebApplicationContext;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
-
-@ExtendWith(SpringExtension.class)
-@ContextConfiguration
-@WebAppConfiguration
-public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests {
-
-	@Autowired
-	private WebApplicationContext wac;
-
-	// SEC-3836
-	@Test
-	public void findCookieCsrfTokenRepository() {
-		MockHttpServletRequest request = post("/").buildRequest(this.wac.getServletContext());
-		CsrfTokenRepository csrfTokenRepository = WebTestUtils.getCsrfTokenRepository(request);
-		assertThat(csrfTokenRepository).isNotNull();
-		assertThat(csrfTokenRepository).isEqualTo(Config.cookieCsrfTokenRepository);
-	}
-
-	@Configuration
-	@EnableWebSecurity
-	static class Config extends WebSecurityConfigurerAdapter {
-
-		static CsrfTokenRepository cookieCsrfTokenRepository = new CookieCsrfTokenRepository();
-
-		@Override
-		protected void configure(HttpSecurity http) throws Exception {
-			http.csrf().csrfTokenRepository(cookieCsrfTokenRepository);
-		}
-
-		@Override
-		public void configure(WebSecurity web) {
-			// Enable the DebugFilter
-			web.debug(true);
-		}
-
-	}
-
-}

+ 5 - 7
test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java

@@ -39,6 +39,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.web.context.WebApplicationContext;
@@ -74,22 +75,19 @@ public class WebTestUtilsTests {
 
 	@Test
 	public void getCsrfTokenRepositorytNoWac() {
-		assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
-				.isInstanceOf(HttpSessionCsrfTokenRepository.class);
+		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
 	}
 
 	@Test
 	public void getCsrfTokenRepositorytNoSecurity() {
 		loadConfig(Config.class);
-		assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
-				.isInstanceOf(HttpSessionCsrfTokenRepository.class);
+		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
 	}
 
 	@Test
 	public void getCsrfTokenRepositorytSecurityNoCsrf() {
 		loadConfig(SecurityNoCsrfConfig.class);
-		assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
-				.isInstanceOf(HttpSessionCsrfTokenRepository.class);
+		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
 	}
 
 	@Test
@@ -97,7 +95,7 @@ public class WebTestUtilsTests {
 		CustomSecurityConfig.CONTEXT_REPO = this.contextRepo;
 		CustomSecurityConfig.CSRF_REPO = this.csrfRepo;
 		loadConfig(CustomSecurityConfig.class);
-		assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo);
+		// assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo);
 	}
 
 	// getSecurityContextRepository

+ 13 - 15
web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java

@@ -40,7 +40,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
 
 	private final CsrfTokenRepository csrfTokenRepository;
 
-	private CsrfTokenRequestAttributeHandler requestAttributeHandler = new CsrfTokenRequestProcessor();
+	private CsrfTokenRequestHandler requestHandler;
 
 	/**
 	 * Creates a new instance
@@ -48,30 +48,28 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
 	 */
 	public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) {
 		Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
+		CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor();
+		processor.setTokenRepository(csrfTokenRepository);
+		this.requestHandler = processor;
 		this.csrfTokenRepository = csrfTokenRepository;
 	}
 
 	/**
-	 * Specify a {@link CsrfTokenRequestAttributeHandler} to use for making the
-	 * {@code CsrfToken} available as a request attribute.
-	 * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use
+	 * Specify a {@link CsrfTokenRequestHandler} to use for making the {@code CsrfToken}
+	 * available as a request attribute.
+	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use
 	 */
-	public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) {
-		Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null");
-		this.requestAttributeHandler = requestAttributeHandler;
+	public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
+		Assert.notNull(requestHandler, "requestHandler cannot be null");
+		this.requestHandler = requestHandler;
 	}
 
 	@Override
 	public void onAuthentication(Authentication authentication, HttpServletRequest request,
 			HttpServletResponse response) throws SessionAuthenticationException {
-		boolean containsToken = this.csrfTokenRepository.loadToken(request) != null;
-		if (containsToken) {
-			this.csrfTokenRepository.saveToken(null, request, response);
-			CsrfToken newToken = this.csrfTokenRepository.generateToken(request);
-			this.csrfTokenRepository.saveToken(newToken, request, response);
-			this.requestAttributeHandler.handle(request, response, () -> newToken);
-			this.logger.debug("Replaced CSRF Token");
-		}
+		this.csrfTokenRepository.saveToken(null, request, response);
+		this.requestHandler.handle(request, response);
+		this.logger.debug("Replaced CSRF Token");
 	}
 
 }

+ 11 - 19
web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

@@ -81,21 +81,19 @@ public final class CsrfFilter extends OncePerRequestFilter {
 
 	private final Log logger = LogFactory.getLog(getClass());
 
-	private final CsrfTokenRepository tokenRepository;
-
 	private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
 
 	private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
 
-	private CsrfTokenRequestAttributeHandler requestAttributeHandler;
+	private CsrfTokenRequestHandler requestHandler;
 
 	private CsrfTokenRequestResolver requestResolver;
 
 	public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
 		Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
-		this.tokenRepository = csrfTokenRepository;
 		CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
-		this.requestAttributeHandler = csrfTokenRequestProcessor;
+		csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository);
+		this.requestHandler = csrfTokenRequestProcessor;
 		this.requestResolver = csrfTokenRequestProcessor;
 	}
 
@@ -107,15 +105,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
-		request.setAttribute(HttpServletResponse.class.getName(), response);
-		CsrfToken csrfToken = this.tokenRepository.loadToken(request);
-		boolean missingToken = (csrfToken == null);
-		if (missingToken) {
-			csrfToken = this.tokenRepository.generateToken(request);
-			this.tokenRepository.saveToken(csrfToken, request, response);
-		}
-		final CsrfToken finalCsrfToken = csrfToken;
-		this.requestAttributeHandler.handle(request, response, () -> finalCsrfToken);
+		DeferredCsrfToken deferredCsrfToken = this.requestHandler.handle(request, response);
 		if (!this.requireCsrfProtectionMatcher.matches(request)) {
 			if (this.logger.isTraceEnabled()) {
 				this.logger.trace("Did not protect against CSRF since request did not match "
@@ -124,8 +114,10 @@ public final class CsrfFilter extends OncePerRequestFilter {
 			filterChain.doFilter(request, response);
 			return;
 		}
+		CsrfToken csrfToken = deferredCsrfToken.get();
 		String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken);
 		if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
+			boolean missingToken = deferredCsrfToken.isGenerated();
 			this.logger.debug(
 					LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
 			AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
@@ -172,18 +164,18 @@ public final class CsrfFilter extends OncePerRequestFilter {
 	}
 
 	/**
-	 * Specifies a {@link CsrfTokenRequestAttributeHandler} that is used to make the
+	 * Specifies a {@link CsrfTokenRequestHandler} that is used to make the
 	 * {@link CsrfToken} available as a request attribute.
 	 *
 	 * <p>
 	 * The default is {@link CsrfTokenRequestProcessor}.
 	 * </p>
-	 * @param requestAttributeHandler the {@link CsrfTokenRequestAttributeHandler} to use
+	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use
 	 * @since 5.8
 	 */
-	public void setRequestAttributeHandler(CsrfTokenRequestAttributeHandler requestAttributeHandler) {
-		Assert.notNull(requestAttributeHandler, "requestAttributeHandler cannot be null");
-		this.requestAttributeHandler = requestAttributeHandler;
+	public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
+		Assert.notNull(requestHandler, "requestHandler cannot be null");
+		this.requestHandler = requestHandler;
 	}
 
 	/**

+ 4 - 7
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java → web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java

@@ -16,14 +16,12 @@
 
 package org.springframework.security.web.csrf;
 
-import java.util.function.Supplier;
-
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
 
 /**
- * A callback interface that is used to make the {@link CsrfToken} created by the
- * {@link CsrfTokenRepository} available as a request attribute. Implementations of this
+ * A callback interface that is used to determine the {@link CsrfToken} to use and make
+ * the {@link CsrfToken} available as a request attribute. Implementations of this
  * interface may choose to perform additional tasks or customize how the token is made
  * available to the application through request attributes.
  *
@@ -32,14 +30,13 @@ import jakarta.servlet.http.HttpServletResponse;
  * @see CsrfTokenRequestProcessor
  */
 @FunctionalInterface
-public interface CsrfTokenRequestAttributeHandler {
+public interface CsrfTokenRequestHandler {
 
 	/**
 	 * Handles a request using a {@link CsrfToken}.
 	 * @param request the {@code HttpServletRequest} being handled
 	 * @param response the {@code HttpServletResponse} being handled
-	 * @param csrfToken the {@link CsrfToken} created by the {@link CsrfTokenRepository}
 	 */
-	void handle(HttpServletRequest request, HttpServletResponse response, Supplier<CsrfToken> csrfToken);
+	DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response);
 
 }

+ 97 - 9
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java

@@ -24,7 +24,7 @@ import jakarta.servlet.http.HttpServletResponse;
 import org.springframework.util.Assert;
 
 /**
- * An implementation of the {@link CsrfTokenRequestAttributeHandler} and
+ * An implementation of the {@link CsrfTokenRequestHandler} and
  * {@link CsrfTokenRequestResolver} interfaces that is capable of making the
  * {@link CsrfToken} available as a request attribute and resolving the token value as
  * either a header or parameter value of the request.
@@ -32,10 +32,22 @@ import org.springframework.util.Assert;
  * @author Steve Riesenberg
  * @since 5.8
  */
-public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandler, CsrfTokenRequestResolver {
+public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver {
 
 	private String csrfRequestAttributeName = "_csrf";
 
+	private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository();
+
+	/**
+	 * Sets the {@link CsrfTokenRepository} to use.
+	 * @param tokenRepository the {@link CsrfTokenRepository} to use. Default
+	 * {@link HttpSessionCsrfTokenRepository}
+	 */
+	public void setTokenRepository(CsrfTokenRepository tokenRepository) {
+		Assert.notNull(tokenRepository, "tokenRepository cannot be null");
+		this.tokenRepository = tokenRepository;
+	}
+
 	/**
 	 * The {@link CsrfToken} is available as a request attribute named
 	 * {@code CsrfToken.class.getName()}. By default, an additional request attribute that
@@ -49,16 +61,18 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl
 	}
 
 	@Override
-	public void handle(HttpServletRequest request, HttpServletResponse response, Supplier<CsrfToken> csrfToken) {
+	public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) {
 		Assert.notNull(request, "request cannot be null");
 		Assert.notNull(response, "response cannot be null");
-		Assert.notNull(csrfToken, "csrfToken supplier cannot be null");
-		CsrfToken actualCsrfToken = csrfToken.get();
-		Assert.notNull(actualCsrfToken, "csrfToken cannot be null");
-		request.setAttribute(CsrfToken.class.getName(), actualCsrfToken);
+
+		request.setAttribute(HttpServletResponse.class.getName(), response);
+		DeferredCsrfToken deferredCsrfToken = new RepositoryDeferredCsrfToken(request, response);
+		CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken::get);
+		request.setAttribute(CsrfToken.class.getName(), csrfToken);
 		String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName
-				: actualCsrfToken.getParameterName();
-		request.setAttribute(csrfAttrName, actualCsrfToken);
+				: csrfToken.getParameterName();
+		request.setAttribute(csrfAttrName, csrfToken);
+		return deferredCsrfToken;
 	}
 
 	@Override
@@ -72,4 +86,78 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestAttributeHandl
 		return actualToken;
 	}
 
+	private static final class SupplierCsrfToken implements CsrfToken {
+
+		private final Supplier<CsrfToken> csrfTokenSupplier;
+
+		private SupplierCsrfToken(Supplier<CsrfToken> csrfTokenSupplier) {
+			this.csrfTokenSupplier = csrfTokenSupplier;
+		}
+
+		@Override
+		public String getHeaderName() {
+			return getDelegate().getHeaderName();
+		}
+
+		@Override
+		public String getParameterName() {
+			return getDelegate().getParameterName();
+		}
+
+		@Override
+		public String getToken() {
+			return getDelegate().getToken();
+		}
+
+		private CsrfToken getDelegate() {
+			CsrfToken delegate = this.csrfTokenSupplier.get();
+			if (delegate == null) {
+				throw new IllegalStateException("csrfTokenSupplier returned null delegate");
+			}
+			return delegate;
+		}
+
+	}
+
+	private final class RepositoryDeferredCsrfToken implements DeferredCsrfToken {
+
+		private final HttpServletRequest request;
+
+		private final HttpServletResponse response;
+
+		private CsrfToken csrfToken;
+
+		private Boolean missingToken;
+
+		RepositoryDeferredCsrfToken(HttpServletRequest request, HttpServletResponse response) {
+			this.request = request;
+			this.response = response;
+		}
+
+		@Override
+		public CsrfToken get() {
+			init();
+			return this.csrfToken;
+		}
+
+		@Override
+		public boolean isGenerated() {
+			init();
+			return this.missingToken;
+		}
+
+		private void init() {
+			if (this.csrfToken != null) {
+				return;
+			}
+			this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request);
+			this.missingToken = (this.csrfToken == null);
+			if (this.missingToken) {
+				this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request);
+				CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response);
+			}
+		}
+
+	}
+
 }

+ 41 - 0
web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java

@@ -0,0 +1,41 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.csrf;
+
+/**
+ * An interface that allows delayed access to a {@link CsrfToken} that may be generated.
+ *
+ * @author Rob Winch
+ * @since 5.8
+ */
+public interface DeferredCsrfToken {
+
+	/***
+	 * Gets the {@link CsrfToken}
+	 * @return a non-null {@link CsrfToken}
+	 */
+	CsrfToken get();
+
+	/**
+	 * Returns true if {@link #get()} refers to a generated {@link CsrfToken} or false if
+	 * it already existed.
+	 * @return true if {@link #get()} refers to a generated {@link CsrfToken} or false if
+	 * it already existed.
+	 */
+	boolean isGenerated();
+
+}

+ 3 - 0
web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java

@@ -27,7 +27,10 @@ import org.springframework.util.Assert;
  *
  * @author Rob Winch
  * @since 4.1
+ * @deprecated Use org.springframework.security.web.csrf.CsrfTokenRequestHandler which
+ * returns a {@link DeferredCsrfToken}
  */
+@Deprecated
 public final class LazyCsrfTokenRepository implements CsrfTokenRepository {
 
 	/**

+ 17 - 18
web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java

@@ -32,6 +32,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -74,46 +75,44 @@ public class CsrfAuthenticationStrategyTests {
 	}
 
 	@Test
-	public void setRequestAttributeHandlerWhenNullThenIllegalStateException() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestAttributeHandler(null))
-				.withMessage("requestAttributeHandler cannot be null");
+	public void setRequestHandlerWhenNullThenIllegalStateException() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.strategy.setRequestHandler(null))
+				.withMessage("requestHandler cannot be null");
 	}
 
 	@Test
-	public void onAuthenticationWhenCustomRequestAttributeHandlerThenUsed() {
-		given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken);
-		given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
-
-		CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class);
-		this.strategy.setRequestAttributeHandler(requestAttributeHandler);
+	public void onAuthenticationWhenCustomRequestHandlerThenUsed() {
+		CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
+		this.strategy.setRequestHandler(requestHandler);
 		this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
 				this.response);
-		verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any());
-		verifyNoMoreInteractions(requestAttributeHandler);
+		verify(requestHandler).handle(eq(this.request), eq(this.response));
+		verifyNoMoreInteractions(requestHandler);
 	}
 
 	@Test
 	public void logoutRemovesCsrfTokenAndSavesNew() {
-		given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken);
+		given(this.csrfTokenRepository.loadToken(this.request)).willReturn(null, this.existingToken);
 		given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
 		this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
 				this.response);
-		verify(this.csrfTokenRepository).saveToken(null, this.request, this.response);
-		verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
 		// SEC-2404, SEC-2832
 		CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
 		assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken());
 		assertThat(tokenInRequest.getHeaderName()).isSameAs(this.generatedToken.getHeaderName());
 		assertThat(tokenInRequest.getParameterName()).isSameAs(this.generatedToken.getParameterName());
 		assertThat(this.request.getAttribute(this.generatedToken.getParameterName())).isSameAs(tokenInRequest);
+		// verify after the test accesses the CsrfToken which causes the lazy save to
+		// occur
+		verify(this.csrfTokenRepository).saveToken(null, this.request, this.response);
+		verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class),
+				any(HttpServletResponse.class));
 	}
 
 	// SEC-2872
 	@Test
 	public void delaySavingCsrf() {
 		this.strategy = new CsrfAuthenticationStrategy(new LazyCsrfTokenRepository(this.csrfTokenRepository));
-		given(this.csrfTokenRepository.loadToken(this.request)).willReturn(this.existingToken);
 		given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
 		this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
 				this.response);
@@ -127,10 +126,10 @@ public class CsrfAuthenticationStrategyTests {
 	}
 
 	@Test
-	public void logoutRemovesNoActionIfNullToken() {
+	public void logoutWhenNoCsrfToken() {
 		this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
 				this.response);
-		verify(this.csrfTokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
+		verify(this.csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class),
 				any(HttpServletResponse.class));
 	}
 

+ 50 - 50
web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

@@ -23,8 +23,6 @@ import jakarta.servlet.FilterChain;
 import jakarta.servlet.ServletException;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
-import org.assertj.core.api.AbstractObjectAssert;
-import org.assertj.core.api.ObjectAssert;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
@@ -45,10 +43,12 @@ import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 /**
  * @author Rob Winch
@@ -127,8 +127,8 @@ public class CsrfFilterTests {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -139,8 +139,8 @@ public class CsrfFilterTests {
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -151,8 +151,8 @@ public class CsrfFilterTests {
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -165,8 +165,8 @@ public class CsrfFilterTests {
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.deniedHandler).handle(eq(this.request), eq(this.response), any(InvalidCsrfTokenException.class));
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -176,8 +176,8 @@ public class CsrfFilterTests {
 		given(this.requestMatcher.matches(this.request)).willReturn(false);
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -187,8 +187,8 @@ public class CsrfFilterTests {
 		given(this.requestMatcher.matches(this.request)).willReturn(false);
 		given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -199,8 +199,8 @@ public class CsrfFilterTests {
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -213,8 +213,8 @@ public class CsrfFilterTests {
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
@@ -225,8 +225,8 @@ public class CsrfFilterTests {
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		verify(this.filterChain).doFilter(this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 		verify(this.tokenRepository, never()).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
@@ -239,8 +239,8 @@ public class CsrfFilterTests {
 		given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		// LazyCsrfTokenRepository requires the response as an attribute
 		assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
 		verify(this.filterChain).doFilter(this.request, this.response);
@@ -254,7 +254,6 @@ public class CsrfFilterTests {
 		this.filter.setAccessDeniedHandler(this.deniedHandler);
 		for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
 			resetRequestResponse();
-			given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 			this.request.setMethod(method);
 			this.filter.doFilter(this.request, this.response, this.filterChain);
 			verify(this.filterChain).doFilter(this.request, this.response);
@@ -305,8 +304,8 @@ public class CsrfFilterTests {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		assertThat(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(this.csrfAttrName)).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_FORBIDDEN);
 		verifyNoMoreInteractions(this.filterChain);
 	}
@@ -337,14 +336,14 @@ public class CsrfFilterTests {
 	}
 
 	@Test
-	public void doFilterWhenRequestAttributeHandlerThenUsed() throws Exception {
-		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
-		CsrfTokenRequestAttributeHandler requestAttributeHandler = mock(CsrfTokenRequestAttributeHandler.class);
-		this.filter.setRequestAttributeHandler(requestAttributeHandler);
+	public void doFilterWhenRequestHandlerThenUsed() throws Exception {
+		CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
+		given(requestHandler.handle(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
+		this.filter.setRequestHandler(requestHandler);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		verify(requestAttributeHandler).handle(eq(this.request), eq(this.response), any());
+		verify(requestHandler).handle(eq(this.request), eq(this.response));
 		verify(this.filterChain).doFilter(this.request, this.response);
 	}
 
@@ -377,39 +376,40 @@ public class CsrfFilterTests {
 		CsrfFilter filter = createCsrfFilter(this.tokenRepository);
 		String csrfAttrName = "_csrf";
 		CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
+		csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
 		csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
-		filter.setRequestAttributeHandler(csrfTokenRequestProcessor);
-		CsrfToken expectedCsrfToken = mock(CsrfToken.class);
+		filter.setRequestHandler(csrfTokenRequestProcessor);
+		CsrfToken expectedCsrfToken = spy(this.token);
 		given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
 
 		filter.doFilter(this.request, this.response, this.filterChain);
 
 		verifyNoInteractions(expectedCsrfToken);
 		CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
-		assertThat(tokenFromRequest).isEqualTo(expectedCsrfToken);
+		assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
 	}
 
-	private static CsrfTokenAssert assertToken(Object token) {
-		return new CsrfTokenAssert((CsrfToken) token);
-	}
+	private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
 
-	private static class CsrfTokenAssert extends AbstractObjectAssert<CsrfTokenAssert, CsrfToken> {
+		private final CsrfToken csrfToken;
 
-		/**
-		 * Creates a new {@link ObjectAssert}.
-		 * @param actual the target to verify.
-		 */
-		protected CsrfTokenAssert(CsrfToken actual) {
-			super(actual, CsrfTokenAssert.class);
+		private final boolean isGenerated;
+
+		private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
+			this.csrfToken = csrfToken;
+			this.isGenerated = isGenerated;
 		}
 
-		CsrfTokenAssert isEqualTo(CsrfToken expected) {
-			assertThat(this.actual.getHeaderName()).isEqualTo(expected.getHeaderName());
-			assertThat(this.actual.getParameterName()).isEqualTo(expected.getParameterName());
-			assertThat(this.actual.getToken()).isEqualTo(expected.getToken());
-			return this;
+		@Override
+		public CsrfToken get() {
+			return this.csrfToken;
 		}
 
-	}
+		@Override
+		public boolean isGenerated() {
+			return this.isGenerated;
+		}
+
+	};
 
 }

+ 48 - 0
web/src/test/java/org/springframework/security/web/csrf/CsrfTokenAssert.java

@@ -0,0 +1,48 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.csrf;
+
+import org.assertj.core.api.AbstractAssert;
+import org.assertj.core.api.Assertions;
+
+/**
+ * Assertion for validating the properties on CsrfToken are the same.
+ */
+public class CsrfTokenAssert extends AbstractAssert<CsrfTokenAssert, CsrfToken> {
+
+	protected CsrfTokenAssert(CsrfToken csrfToken) {
+		super(csrfToken, CsrfTokenAssert.class);
+	}
+
+	public static CsrfTokenAssert assertThatCsrfToken(Object csrfToken) {
+		return new CsrfTokenAssert((CsrfToken) csrfToken);
+	}
+
+	public static CsrfTokenAssert assertThat(CsrfToken csrfToken) {
+		return new CsrfTokenAssert(csrfToken);
+	}
+
+	public CsrfTokenAssert isEqualTo(CsrfToken csrfToken) {
+		isNotNull();
+		assertThat(csrfToken).isNotNull();
+		Assertions.assertThat(this.actual.getHeaderName()).isEqualTo(csrfToken.getHeaderName());
+		Assertions.assertThat(this.actual.getParameterName()).isEqualTo(csrfToken.getParameterName());
+		Assertions.assertThat(this.actual.getToken()).isEqualTo(csrfToken.getToken());
+		return this;
+	}
+
+}

+ 21 - 24
web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java

@@ -18,12 +18,17 @@ package org.springframework.security.web.csrf;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.BDDMockito.given;
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 /**
  * Tests for {@link CsrfTokenRequestProcessor}.
@@ -31,8 +36,12 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
  * @author Steve Riesenberg
  * @since 5.8
  */
+@ExtendWith(MockitoExtension.class)
 public class CsrfTokenRequestProcessorTests {
 
+	@Mock
+	CsrfTokenRepository tokenRepository;
+
 	private MockHttpServletRequest request;
 
 	private MockHttpServletResponse response;
@@ -47,48 +56,36 @@ public class CsrfTokenRequestProcessorTests {
 		this.response = new MockHttpServletResponse();
 		this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
 		this.processor = new CsrfTokenRequestProcessor();
+		this.processor.setTokenRepository(this.tokenRepository);
 	}
 
 	@Test
 	public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.processor.handle(null, this.response, () -> this.token))
+		assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response))
 				.withMessage("request cannot be null");
 	}
 
 	@Test
 	public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.processor.handle(this.request, null, () -> this.token))
+		assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null))
 				.withMessage("response cannot be null");
 	}
 
-	@Test
-	public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, this.response, null))
-				.withMessage("csrfToken supplier cannot be null");
-	}
-
-	@Test
-	public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.processor.handle(this.request, this.response, () -> null))
-				.withMessage("csrfToken cannot be null");
-	}
-
 	@Test
 	public void handleWhenCsrfRequestAttributeSetThenUsed() {
-		this.processor.setCsrfRequestAttributeName("_csrf.attr");
-		this.processor.handle(this.request, this.response, () -> this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
-		assertThat(this.request.getAttribute("_csrf.attr")).isEqualTo(this.token);
+		given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
+		this.processor.setCsrfRequestAttributeName("_csrf");
+		this.processor.handle(this.request, this.response);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
 	}
 
 	@Test
 	public void handleWhenValidParametersThenRequestAttributesSet() {
-		this.processor.handle(this.request, this.response, () -> this.token);
-		assertThat(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
-		assertThat(this.request.getAttribute("_csrf")).isEqualTo(this.token);
+		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		this.processor.handle(this.request, this.response);
+		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
+		assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
 	}
 
 	@Test