Ver Fonte

Add deferred CsrfTokenRepository.loadDeferredToken

* Move DeferredCsrfToken to top-level and implement Supplier<CsrfToken>
* Move RepositoryDeferredCsrfToken to top-level and make package-private
* Add CsrfTokenRepository.loadToken(HttpServletRequest, HttpServletResponse)
* Update CsrfFilter
* Rename CsrfTokenRepositoryRequestHandler to CsrfTokenRequestAttributeHandler

Issue gh-11892
Closes gh-11918
Steve Riesenberg há 2 anos atrás
pai
commit
475b3bb6bb
31 ficheiros alterados com 533 adições e 350 exclusões
  1. 4 8
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java
  2. 5 11
      config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java
  3. 1 1
      config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc
  4. 1 1
      config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd
  5. 2 2
      config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java
  6. 64 35
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java
  7. 3 7
      config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java
  8. 1 1
      config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml
  9. 1 1
      config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml
  10. 1 1
      docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc
  11. 37 44
      test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java
  12. 7 9
      test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java
  13. 3 3
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java
  14. 3 3
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java
  15. 74 0
      test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java
  16. 7 8
      test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java
  17. 9 8
      web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java
  18. 23 16
      web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java
  19. 18 1
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java
  20. 5 66
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java
  21. 9 6
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java
  22. 1 1
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java
  23. 2 1
      web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java
  24. 3 2
      web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java
  25. 71 0
      web/src/main/java/org/springframework/security/web/csrf/RepositoryDeferredCsrfToken.java
  26. 29 1
      web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java
  27. 11 16
      web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java
  28. 50 60
      web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
  29. 26 36
      web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandlerTests.java
  30. 22 1
      web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java
  31. 40 0
      web/src/test/java/org/springframework/security/web/csrf/TestDeferredCsrfToken.java

+ 4 - 8
config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java

@@ -36,7 +36,6 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfLogoutHandler;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
@@ -249,13 +248,7 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 	@SuppressWarnings("unchecked")
 	@Override
 	public void configure(H http) {
-		CsrfFilter filter;
-		if (this.requestHandler != null) {
-			filter = new CsrfFilter(this.requestHandler);
-		}
-		else {
-			filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository));
-		}
+		CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
 		RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
 		if (requireCsrfProtectionMatcher != null) {
 			filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
@@ -272,6 +265,9 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 		if (sessionConfigurer != null) {
 			sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy());
 		}
+		if (this.requestHandler != null) {
+			filter.setRequestHandler(this.requestHandler);
+		}
 		filter = postProcess(filter);
 		http.addFilter(filter);
 	}

+ 5 - 11
config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -41,7 +41,6 @@ import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
 import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfLogoutHandler;
-import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
 import org.springframework.security.web.csrf.MissingCsrfTokenException;
@@ -112,18 +111,13 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 					new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef));
 		}
 		BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class);
-		if (!StringUtils.hasText(this.requestHandlerRef)) {
-			BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder
-					.rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class)
-					.addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition();
-			builder.addConstructorArgValue(csrfTokenRequestHandler);
-		}
-		else {
-			builder.addConstructorArgReference(this.requestHandlerRef);
-		}
+		builder.addConstructorArgReference(this.csrfRepositoryRef);
 		if (StringUtils.hasText(this.requestMatcherRef)) {
 			builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
 		}
+		if (StringUtils.hasText(this.requestHandlerRef)) {
+			builder.addPropertyReference("requestHandler", this.requestHandlerRef);
+		}
 		this.csrfFilter = builder.getBeanDefinition();
 		return this.csrfFilter;
 	}

+ 1 - 1
config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc

@@ -1152,7 +1152,7 @@ csrf-options.attlist &=
 	## The CsrfTokenRepository to use. The default is HttpSessionCsrfTokenRepository wrapped by LazyCsrfTokenRepository.
 	attribute token-repository-ref { xsd:token }?
 csrf-options.attlist &=
-	## The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
+	## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
 	attribute request-handler-ref { xsd:token }?
 
 headers =

+ 1 - 1
config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd

@@ -3258,7 +3258,7 @@
       </xs:attribute>
       <xs:attribute name="request-handler-ref" type="xs:token">
          <xs:annotation>
-            <xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
+            <xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestAttributeHandler.
                 </xs:documentation>
          </xs:annotation>
       </xs:attribute>

+ 2 - 2
config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java

@@ -33,7 +33,7 @@ import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.web.DefaultSecurityFilterChain;
 import org.springframework.security.web.FilterChainProxy;
-import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
+import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
 import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
@@ -85,7 +85,7 @@ public class DeferHttpSessionJavaConfigTests {
 			csrfRepository.setDeferLoadToken(true);
 			HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
 			requestCache.setMatchingRequestParameterName("continue");
-			CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler();
+			CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler();
 			requestHandler.setCsrfRequestAttributeName("_csrf");
 			// @formatter:off
 			http

+ 64 - 35
config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java

@@ -44,8 +44,10 @@ import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
+import org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler;
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.firewall.StrictHttpFirewall;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -61,7 +63,6 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.hamcrest.Matchers.containsString;
 import static org.mockito.ArgumentMatchers.any;
-import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.ArgumentMatchers.isNull;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.atLeastOnce;
@@ -207,30 +208,30 @@ public class CsrfConfigurerTests {
 	public void loginWhenCsrfEnabledThenDoesNotRedirectToPreviousPostRequest() throws Exception {
 		CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
 		DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
-		given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken);
-		given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken);
+		given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
+				any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
 		this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
 		MvcResult mvcResult = this.mvc.perform(post("/some-url")).andReturn();
 		this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
 				.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
 				.andExpect(redirectedUrl("/"));
 		verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
-				.loadToken(any(HttpServletRequest.class));
+				.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	@Test
 	public void loginWhenCsrfEnabledThenRedirectsToPreviousGetRequest() throws Exception {
 		CsrfDisablesPostRequestFromRequestCacheConfig.REPO = mock(CsrfTokenRepository.class);
 		DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
-		given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadToken(any())).willReturn(csrfToken);
-		given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.generateToken(any())).willReturn(csrfToken);
+		given(CsrfDisablesPostRequestFromRequestCacheConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
+				any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
 		this.spring.register(CsrfDisablesPostRequestFromRequestCacheConfig.class).autowire();
 		MvcResult mvcResult = this.mvc.perform(get("/some-url")).andReturn();
 		this.mvc.perform(post("/login").param("username", "user").param("password", "password").with(csrf())
 				.session((MockHttpSession) mvcResult.getRequest().getSession())).andExpect(status().isFound())
 				.andExpect(redirectedUrl("http://localhost/some-url"));
 		verify(CsrfDisablesPostRequestFromRequestCacheConfig.REPO, atLeastOnce())
-				.loadToken(any(HttpServletRequest.class));
+				.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
 	}
 
 	// SEC-2422
@@ -277,11 +278,13 @@ public class CsrfConfigurerTests {
 	@Test
 	public void getWhenCustomCsrfTokenRepositoryThenRepositoryIsUsed() throws Exception {
 		CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
-		given(CsrfTokenRepositoryConfig.REPO.loadToken(any()))
-				.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
+		given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
+				any(HttpServletResponse.class)))
+						.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
 		this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
 		this.mvc.perform(get("/")).andExpect(status().isOk());
-		verify(CsrfTokenRepositoryConfig.REPO).loadToken(any(HttpServletRequest.class));
+		verify(CsrfTokenRepositoryConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
+				any(HttpServletResponse.class));
 	}
 
 	@Test
@@ -297,8 +300,8 @@ public class CsrfConfigurerTests {
 	public void loginWhenCustomCsrfTokenRepositoryThenCsrfTokenIsCleared() throws Exception {
 		CsrfTokenRepositoryConfig.REPO = mock(CsrfTokenRepository.class);
 		DefaultCsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
-		given(CsrfTokenRepositoryConfig.REPO.loadToken(any())).willReturn(csrfToken);
-		given(CsrfTokenRepositoryConfig.REPO.generateToken(any())).willReturn(csrfToken);
+		given(CsrfTokenRepositoryConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
+				any(HttpServletResponse.class))).willReturn(new TestDeferredCsrfToken(csrfToken));
 		this.spring.register(CsrfTokenRepositoryConfig.class, BasicController.class).autowire();
 		// @formatter:off
 		MockHttpServletRequestBuilder loginRequest = post("/login")
@@ -314,11 +317,13 @@ public class CsrfConfigurerTests {
 	@Test
 	public void getWhenCustomCsrfTokenRepositoryInLambdaThenRepositoryIsUsed() throws Exception {
 		CsrfTokenRepositoryInLambdaConfig.REPO = mock(CsrfTokenRepository.class);
-		given(CsrfTokenRepositoryInLambdaConfig.REPO.loadToken(any()))
-				.willReturn(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token"));
+		given(CsrfTokenRepositoryInLambdaConfig.REPO.loadDeferredToken(any(HttpServletRequest.class),
+				any(HttpServletResponse.class)))
+						.willReturn(new TestDeferredCsrfToken(new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token")));
 		this.spring.register(CsrfTokenRepositoryInLambdaConfig.class, BasicController.class).autowire();
 		this.mvc.perform(get("/")).andExpect(status().isOk());
-		verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadToken(any(HttpServletRequest.class));
+		verify(CsrfTokenRepositoryInLambdaConfig.REPO).loadDeferredToken(any(HttpServletRequest.class),
+				any(HttpServletResponse.class));
 	}
 
 	@Test
@@ -418,30 +423,30 @@ public class CsrfConfigurerTests {
 	}
 
 	@Test
-	public void getLoginWhenCsrfTokenRequestProcessorSetThenRespondsWithNormalCsrfToken() throws Exception {
+	public void getLoginWhenCsrfTokenRequestHandlerSetThenRespondsWithNormalCsrfToken() throws Exception {
 		CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
 		CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
-		given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
-		CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
-		this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
+		given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
+				.willReturn(new TestDeferredCsrfToken(csrfToken));
+		CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
+		CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
+		this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
 		this.mvc.perform(get("/login")).andExpect(status().isOk())
 				.andExpect(content().string(containsString(csrfToken.getToken())));
-		verify(csrfTokenRepository).loadToken(any(HttpServletRequest.class));
-		verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
-		verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
+		verify(csrfTokenRepository).loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class));
 		verifyNoMoreInteractions(csrfTokenRepository);
 	}
 
 	@Test
-	public void loginWhenCsrfTokenRequestProcessorSetAndNormalCsrfTokenThenSuccess() throws Exception {
+	public void loginWhenCsrfTokenRequestHandlerSetAndNormalCsrfTokenThenSuccess() throws Exception {
 		CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
 		CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
-		given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken);
-		given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
-		CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
+		given(csrfTokenRepository.loadDeferredToken(any(HttpServletRequest.class), any(HttpServletResponse.class)))
+				.willReturn(new TestDeferredCsrfToken(csrfToken));
+		CsrfTokenRequestHandlerConfig.REPO = csrfTokenRepository;
+		CsrfTokenRequestHandlerConfig.HANDLER = new CsrfTokenRequestAttributeHandler();
+		this.spring.register(CsrfTokenRequestHandlerConfig.class, BasicController.class).autowire();
 
-		this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
 		// @formatter:off
 		MockHttpServletRequestBuilder loginRequest = post("/login")
 				.header(csrfToken.getHeaderName(), csrfToken.getToken())
@@ -449,9 +454,8 @@ public class CsrfConfigurerTests {
 				.param("password", "password");
 		// @formatter:on
 		this.mvc.perform(loginRequest).andExpect(redirectedUrl("/"));
-		verify(csrfTokenRepository, times(2)).loadToken(any(HttpServletRequest.class));
-		verify(csrfTokenRepository).generateToken(any(HttpServletRequest.class));
-		verify(csrfTokenRepository).saveToken(eq(csrfToken), any(HttpServletRequest.class),
+		verify(csrfTokenRepository).saveToken(isNull(), any(HttpServletRequest.class), any(HttpServletResponse.class));
+		verify(csrfTokenRepository, times(2)).loadDeferredToken(any(HttpServletRequest.class),
 				any(HttpServletResponse.class));
 		verifyNoMoreInteractions(csrfTokenRepository);
 	}
@@ -799,9 +803,11 @@ public class CsrfConfigurerTests {
 
 	@Configuration
 	@EnableWebSecurity
-	static class CsrfTokenRequestProcessorConfig {
+	static class CsrfTokenRequestHandlerConfig {
+
+		static CsrfTokenRepository REPO;
 
-		static CsrfTokenRepositoryRequestHandler HANDLER;
+		static CsrfTokenRequestHandler HANDLER;
 
 		@Bean
 		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
@@ -811,7 +817,10 @@ public class CsrfConfigurerTests {
 					.anyRequest().authenticated()
 				)
 				.formLogin(Customizer.withDefaults())
-				.csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER));
+				.csrf((csrf) -> csrf
+					.csrfTokenRepository(REPO)
+					.csrfTokenRequestHandler(HANDLER)
+				);
 			// @formatter:on
 
 			return http.build();
@@ -841,4 +850,24 @@ public class CsrfConfigurerTests {
 
 	}
 
+	private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
+
+		private final CsrfToken csrfToken;
+
+		private TestDeferredCsrfToken(CsrfToken csrfToken) {
+			this.csrfToken = csrfToken;
+		}
+
+		@Override
+		public CsrfToken get() {
+			return this.csrfToken;
+		}
+
+		@Override
+		public boolean isGenerated() {
+			return false;
+		}
+
+	}
+
 }

+ 3 - 7
config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java

@@ -30,7 +30,6 @@ import org.junit.jupiter.api.extension.ExtendWith;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.http.HttpMethod;
 import org.springframework.mock.web.MockHttpServletRequest;
-import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.config.test.SpringTestContext;
@@ -42,7 +41,6 @@ import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfToken;
-import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -546,9 +544,8 @@ public class CsrfConfigTests {
 		@Override
 		public void match(MvcResult result) {
 			MockHttpServletRequest request = result.getRequest();
-			MockHttpServletResponse response = result.getResponse();
-			DeferredCsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response);
-			assertThat(token.isGenerated()).isFalse();
+			CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
+			assertThat(token).isNotNull();
 		}
 
 	}
@@ -564,8 +561,7 @@ public class CsrfConfigTests {
 		@Override
 		public void match(MvcResult result) throws Exception {
 			MockHttpServletRequest request = result.getRequest();
-			MockHttpServletResponse response = result.getResponse();
-			CsrfToken token = WebTestUtils.getCsrfTokenRequestHandler(request).handle(request, response).get();
+			CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
 			assertThat(token).isNotNull();
 			assertThat(token.getToken()).isEqualTo(this.token.apply(result));
 		}

+ 1 - 1
config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml

@@ -26,7 +26,7 @@
 		<csrf request-handler-ref="requestHandler"/>
 	</http>
 
-	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
+	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler"
 		p:csrfRequestAttributeName="csrf-attribute-name"/>
 	<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
 </b:beans>

+ 1 - 1
config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml

@@ -42,7 +42,7 @@
 	<b:bean id="csrfRepository" class="org.springframework.security.web.csrf.LazyCsrfTokenRepository"
 		c:delegate-ref="httpSessionCsrfRepository"
 	 	p:deferLoadToken="true"/>
-	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
+	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestAttributeHandler"
 		p:csrfRequestAttributeName="_csrf"/>
 	<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
 </b:beans>

+ 1 - 1
docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc

@@ -783,7 +783,7 @@ The default is `HttpSessionCsrfTokenRepository`.
 
 [[nsa-csrf-request-handler-ref]]
 * **request-handler-ref**
-The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRepositoryRequestHandler`.
+The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestAttributeHandler`.
 
 [[nsa-csrf-request-matcher-ref]]
 * **request-matcher-ref**

+ 37 - 44
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -94,8 +94,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfToken;
-import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
-import org.springframework.security.web.csrf.DeferredCsrfToken;
+import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.servlet.MockMvc;
@@ -509,13 +508,14 @@ public final class SecurityMockMvcRequestPostProcessors {
 
 		@Override
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
-			CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request);
-			if (!(handler instanceof TestCsrfTokenRequestHandler)) {
-				handler = new TestCsrfTokenRequestHandler(handler);
-				WebTestUtils.setCsrfTokenRequestHandler(request, handler);
+			CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
+			if (!(repository instanceof TestCsrfTokenRepository)) {
+				repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
+				WebTestUtils.setCsrfTokenRepository(request, repository);
 			}
-			TestCsrfTokenRequestHandler testHandler = (TestCsrfTokenRequestHandler) handler;
-			CsrfToken token = TestCsrfTokenRequestHandler.createTestCsrfToken(request);
+			TestCsrfTokenRepository.enable(request);
+			CsrfToken token = repository.generateToken(request);
+			repository.saveToken(token, request, new MockHttpServletResponse());
 			String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken();
 			if (this.asHeader) {
 				request.addHeader(token.getHeaderName(), tokenValue);
@@ -549,56 +549,49 @@ public final class SecurityMockMvcRequestPostProcessors {
 		 * Used to wrap the CsrfTokenRepository to provide support for testing when the
 		 * request is wrapped (i.e. Spring Session is in use).
 		 */
-		static class TestCsrfTokenRequestHandler implements CsrfTokenRequestHandler {
+		static class TestCsrfTokenRepository implements CsrfTokenRepository {
 
-			static final String TOKEN_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".TOKEN");
+			static final String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".TOKEN");
 
-			static final String ENABLED_ATTR_NAME = TestCsrfTokenRequestHandler.class.getName().concat(".ENABLED");
+			static final String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class.getName().concat(".ENABLED");
 
-			private final CsrfTokenRequestHandler delegate;
+			private final CsrfTokenRepository delegate;
 
-			TestCsrfTokenRequestHandler(CsrfTokenRequestHandler delegate) {
+			TestCsrfTokenRepository(CsrfTokenRepository delegate) {
 				this.delegate = delegate;
 			}
 
-			static CsrfToken createTestCsrfToken(HttpServletRequest request) {
-				CsrfToken existingToken = getExistingToken(request);
-				if (existingToken != null) {
-					return existingToken;
-				}
-				HttpSessionCsrfTokenRepository repository = new HttpSessionCsrfTokenRepository();
-				CsrfToken csrfToken = repository.generateToken(request);
-				request.setAttribute(ENABLED_ATTR_NAME, true);
-				request.setAttribute(TOKEN_ATTR_NAME, csrfToken);
-				return csrfToken;
-			}
-
-			private static CsrfToken getExistingToken(HttpServletRequest request) {
-				Object existingToken = request.getAttribute(TOKEN_ATTR_NAME);
-				return (CsrfToken) existingToken;
+			@Override
+			public CsrfToken generateToken(HttpServletRequest request) {
+				return this.delegate.generateToken(request);
 			}
 
-			boolean isEnabled(HttpServletRequest request) {
-				return getExistingToken(request) != null;
+			@Override
+			public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
+				if (isEnabled(request)) {
+					request.setAttribute(TOKEN_ATTR_NAME, token);
+				}
+				else {
+					this.delegate.saveToken(token, request, response);
+				}
 			}
 
 			@Override
-			public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) {
-				request.setAttribute(HttpServletResponse.class.getName(), response);
-				if (!isEnabled(request)) {
-					return this.delegate.handle(request, response);
+			public CsrfToken loadToken(HttpServletRequest request) {
+				if (isEnabled(request)) {
+					return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME);
 				}
-				return new DeferredCsrfToken() {
-					@Override
-					public CsrfToken get() {
-						return getExistingToken(request);
-					}
+				else {
+					return this.delegate.loadToken(request);
+				}
+			}
 
-					@Override
-					public boolean isGenerated() {
-						return false;
-					}
-				};
+			static void enable(HttpServletRequest request) {
+				request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE);
+			}
+
+			boolean isEnabled(HttpServletRequest request) {
+				return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
 			}
 
 		}

+ 7 - 9
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@@ -31,8 +31,6 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
-import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.web.context.WebApplicationContext;
@@ -48,7 +46,7 @@ public abstract class WebTestUtils {
 
 	private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository();
 
-	private static final CsrfTokenRepositoryRequestHandler DEFAULT_CSRF_HANDLER = new CsrfTokenRepositoryRequestHandler();
+	private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
 
 	private WebTestUtils() {
 	}
@@ -101,24 +99,24 @@ public abstract class WebTestUtils {
 	 * @return the {@link CsrfTokenRepository} for the specified
 	 * {@link HttpServletRequest}
 	 */
-	public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
+	public static CsrfTokenRepository getCsrfTokenRepository(HttpServletRequest request) {
 		CsrfFilter filter = findFilter(request, CsrfFilter.class);
 		if (filter == null) {
-			return DEFAULT_CSRF_HANDLER;
+			return DEFAULT_TOKEN_REPO;
 		}
-		return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
+		return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository");
 	}
 
 	/**
 	 * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}.
 	 * @param request the {@link HttpServletRequest} to obtain the
 	 * {@link CsrfTokenRepository}
-	 * @param handler the {@link CsrfTokenRepository} to set
+	 * @param repository the {@link CsrfTokenRepository} to set
 	 */
-	public static void setCsrfTokenRequestHandler(HttpServletRequest request, CsrfTokenRequestHandler handler) {
+	public static void setCsrfTokenRepository(HttpServletRequest request, CsrfTokenRepository repository) {
 		CsrfFilter filter = findFilter(request, CsrfFilter.class);
 		if (filter != null) {
-			ReflectionTestUtils.setField(filter, "requestHandler", handler);
+			ReflectionTestUtils.setField(filter, "tokenRepository", repository);
 		}
 	}
 

+ 3 - 3
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

@@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 	public void defaults() {
 		MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 		assertThat(request.getParameter("username")).isEqualTo("user");
 		assertThat(request.getParameter("password")).isEqualTo("password");
 		assertThat(request.getMethod()).isEqualTo("POST");
@@ -67,7 +67,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 		MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret")
 				.buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getMethod()).isEqualTo("POST");
@@ -80,7 +80,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 		MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2")
 				.user("username", "admin").password("password", "secret").buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getMethod()).isEqualTo("POST");

+ 3 - 3
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

@@ -53,7 +53,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	public void defaults() {
 		MockHttpServletRequest request = logout().buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/logout");
@@ -63,7 +63,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	public void custom() {
 		MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
@@ -74,7 +74,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 		MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2")
 				.buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRequestHandler.TOKEN_ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");

+ 74 - 0
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests.java

@@ -0,0 +1,74 @@
+/*
+ * Copyright 2002-2016 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.test.web.servlet.request;
+
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.config.annotation.web.builders.HttpSecurity;
+import org.springframework.security.config.annotation.web.builders.WebSecurity;
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.test.web.support.WebTestUtils;
+import org.springframework.security.web.csrf.CookieCsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit.jupiter.SpringExtension;
+import org.springframework.test.context.web.WebAppConfiguration;
+import org.springframework.web.context.WebApplicationContext;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
+
+@ExtendWith(SpringExtension.class)
+@ContextConfiguration
+@WebAppConfiguration
+public class SecurityMockMvcRequestPostProcessorsCsrfDebugFilterTests {
+
+	@Autowired
+	private WebApplicationContext wac;
+
+	// SEC-3836
+	@Test
+	public void findCookieCsrfTokenRepository() {
+		MockHttpServletRequest request = post("/").buildRequest(this.wac.getServletContext());
+		CsrfTokenRepository csrfTokenRepository = WebTestUtils.getCsrfTokenRepository(request);
+		assertThat(csrfTokenRepository).isNotNull();
+		assertThat(csrfTokenRepository).isEqualTo(Config.cookieCsrfTokenRepository);
+	}
+
+	@EnableWebSecurity
+	static class Config extends WebSecurityConfigurerAdapter {
+
+		static CsrfTokenRepository cookieCsrfTokenRepository = new CookieCsrfTokenRepository();
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			http.csrf().csrfTokenRepository(cookieCsrfTokenRepository);
+		}
+
+		@Override
+		public void configure(WebSecurity web) {
+			// Enable the DebugFilter
+			web.debug(true);
+		}
+
+	}
+
+}

+ 7 - 8
test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java

@@ -39,7 +39,6 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.web.context.WebApplicationContext;
@@ -75,22 +74,22 @@ public class WebTestUtilsTests {
 
 	@Test
 	public void getCsrfTokenRepositorytNoWac() {
-		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
-				.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
+		assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
+				.isInstanceOf(HttpSessionCsrfTokenRepository.class);
 	}
 
 	@Test
 	public void getCsrfTokenRepositorytNoSecurity() {
 		loadConfig(Config.class);
-		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
-				.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
+		assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
+				.isInstanceOf(HttpSessionCsrfTokenRepository.class);
 	}
 
 	@Test
 	public void getCsrfTokenRepositorytSecurityNoCsrf() {
 		loadConfig(SecurityNoCsrfConfig.class);
-		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
-				.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
+		assertThat(WebTestUtils.getCsrfTokenRepository(this.request))
+				.isInstanceOf(HttpSessionCsrfTokenRepository.class);
 	}
 
 	@Test
@@ -98,7 +97,7 @@ public class WebTestUtilsTests {
 		CustomSecurityConfig.CONTEXT_REPO = this.contextRepo;
 		CustomSecurityConfig.CSRF_REPO = this.csrfRepo;
 		loadConfig(CustomSecurityConfig.class);
-		// assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo);
+		assertThat(WebTestUtils.getCsrfTokenRepository(this.request)).isSameAs(this.csrfRepo);
 	}
 
 	// getSecurityContextRepository

+ 9 - 8
web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java

@@ -39,17 +39,17 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
 
 	private final Log logger = LogFactory.getLog(getClass());
 
-	private final CsrfTokenRepository csrfTokenRepository;
+	private final CsrfTokenRepository tokenRepository;
 
-	private CsrfTokenRequestHandler requestHandler;
+	private CsrfTokenRequestHandler requestHandler = new CsrfTokenRequestAttributeHandler();
 
 	/**
 	 * Creates a new instance
-	 * @param csrfTokenRepository the {@link CsrfTokenRepository} to use
+	 * @param tokenRepository the {@link CsrfTokenRepository} to use
 	 */
-	public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) {
-		this.requestHandler = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
-		this.csrfTokenRepository = csrfTokenRepository;
+	public CsrfAuthenticationStrategy(CsrfTokenRepository tokenRepository) {
+		Assert.notNull(tokenRepository, "tokenRepository cannot be null");
+		this.tokenRepository = tokenRepository;
 	}
 
 	/**
@@ -65,8 +65,9 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
 	@Override
 	public void onAuthentication(Authentication authentication, HttpServletRequest request,
 			HttpServletResponse response) throws SessionAuthenticationException {
-		this.csrfTokenRepository.saveToken(null, request, response);
-		this.requestHandler.handle(request, response);
+		this.tokenRepository.saveToken(null, request, response);
+		DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
+		this.requestHandler.handle(request, response, deferredCsrfToken::get);
 		this.logger.debug("Replaced CSRF Token");
 	}
 

+ 23 - 16
web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

@@ -82,30 +82,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
 
 	private final Log logger = LogFactory.getLog(getClass());
 
-	private final CsrfTokenRequestHandler requestHandler;
+	private final CsrfTokenRepository tokenRepository;
 
 	private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
 
 	private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
 
-	/**
-	 * Creates a new instance.
-	 * @param csrfTokenRepository the {@link CsrfTokenRepository} to use
-	 * @deprecated Use {@link CsrfFilter#CsrfFilter(CsrfTokenRequestHandler)} instead
-	 */
-	@Deprecated
-	public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
-		this(new CsrfTokenRepositoryRequestHandler(csrfTokenRepository));
-	}
+	private CsrfTokenRequestHandler requestHandler = new CsrfTokenRequestAttributeHandler();
 
 	/**
 	 * Creates a new instance.
-	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use. Default is
-	 * {@link CsrfTokenRepositoryRequestHandler}.
+	 * @param tokenRepository the {@link CsrfTokenRepository} to use
 	 */
-	public CsrfFilter(CsrfTokenRequestHandler requestHandler) {
-		Assert.notNull(requestHandler, "requestHandler cannot be null");
-		this.requestHandler = requestHandler;
+	public CsrfFilter(CsrfTokenRepository tokenRepository) {
+		Assert.notNull(tokenRepository, "tokenRepository cannot be null");
+		this.tokenRepository = tokenRepository;
 	}
 
 	@Override
@@ -116,7 +107,8 @@ public final class CsrfFilter extends OncePerRequestFilter {
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
-		DeferredCsrfToken deferredCsrfToken = this.requestHandler.handle(request, response);
+		DeferredCsrfToken deferredCsrfToken = this.tokenRepository.loadDeferredToken(request, response);
+		this.requestHandler.handle(request, response, deferredCsrfToken::get);
 		if (!this.requireCsrfProtectionMatcher.matches(request)) {
 			if (this.logger.isTraceEnabled()) {
 				this.logger.trace("Did not protect against CSRF since request did not match "
@@ -174,6 +166,21 @@ public final class CsrfFilter extends OncePerRequestFilter {
 		this.accessDeniedHandler = accessDeniedHandler;
 	}
 
+	/**
+	 * Specifies a {@link CsrfTokenRequestHandler} that is used to make the
+	 * {@link CsrfToken} available as a request attribute.
+	 *
+	 * <p>
+	 * The default is {@link CsrfTokenRequestAttributeHandler}.
+	 * </p>
+	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use
+	 * @since 5.8
+	 */
+	public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
+		Assert.notNull(requestHandler, "requestHandler cannot be null");
+		this.requestHandler = requestHandler;
+	}
+
 	/**
 	 * Constant time comparison to prevent against timing attacks.
 	 * @param expected

+ 18 - 1
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepository.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2013 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ import javax.servlet.http.HttpSession;
  * {@link HttpSession}.
  *
  * @author Rob Winch
+ * @author Steve Riesenberg
  * @since 3.2
  * @see HttpSessionCsrfTokenRepository
  */
@@ -55,4 +56,20 @@ public interface CsrfTokenRepository {
 	 */
 	CsrfToken loadToken(HttpServletRequest request);
 
+	/**
+	 * Defers loading the {@link CsrfToken} using the {@link HttpServletRequest} and
+	 * {@link HttpServletResponse} until it is needed by the application.
+	 * <p>
+	 * The returned {@link DeferredCsrfToken} is cached to allow subsequent calls to
+	 * {@link DeferredCsrfToken#get()} to return the same {@link CsrfToken} without the
+	 * cost of loading or generating the token again.
+	 * @param request the {@link HttpServletRequest} to use
+	 * @param response the {@link HttpServletResponse} to use
+	 * @return a {@link DeferredCsrfToken} that will load the {@link CsrfToken}
+	 * @since 5.8
+	 */
+	default DeferredCsrfToken loadDeferredToken(HttpServletRequest request, HttpServletResponse response) {
+		return new RepositoryDeferredCsrfToken(this, request, response);
+	}
+
 }

+ 5 - 66
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java → web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandler.java

@@ -31,29 +31,10 @@ import org.springframework.util.Assert;
  * @author Steve Riesenberg
  * @since 5.8
  */
-public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandler {
-
-	private final CsrfTokenRepository csrfTokenRepository;
+public class CsrfTokenRequestAttributeHandler implements CsrfTokenRequestHandler {
 
 	private String csrfRequestAttributeName;
 
-	/**
-	 * Creates a new instance.
-	 */
-	public CsrfTokenRepositoryRequestHandler() {
-		this(new HttpSessionCsrfTokenRepository());
-	}
-
-	/**
-	 * Creates a new instance.
-	 * @param csrfTokenRepository the {@link CsrfTokenRepository} to use. Default
-	 * {@link HttpSessionCsrfTokenRepository}
-	 */
-	public CsrfTokenRepositoryRequestHandler(CsrfTokenRepository csrfTokenRepository) {
-		Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
-		this.csrfTokenRepository = csrfTokenRepository;
-	}
-
 	/**
 	 * The {@link CsrfToken} is available as a request attribute named
 	 * {@code CsrfToken.class.getName()}. By default, an additional request attribute that
@@ -67,18 +48,18 @@ public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandle
 	}
 
 	@Override
-	public DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response) {
+	public void handle(HttpServletRequest request, HttpServletResponse response,
+			Supplier<CsrfToken> deferredCsrfToken) {
 		Assert.notNull(request, "request cannot be null");
 		Assert.notNull(response, "response cannot be null");
+		Assert.notNull(deferredCsrfToken, "deferredCsrfToken cannot be null");
 
 		request.setAttribute(HttpServletResponse.class.getName(), response);
-		DeferredCsrfToken deferredCsrfToken = new RepositoryDeferredCsrfToken(request, response);
-		CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken::get);
+		CsrfToken csrfToken = new SupplierCsrfToken(deferredCsrfToken);
 		request.setAttribute(CsrfToken.class.getName(), csrfToken);
 		String csrfAttrName = (this.csrfRequestAttributeName != null) ? this.csrfRequestAttributeName
 				: csrfToken.getParameterName();
 		request.setAttribute(csrfAttrName, csrfToken);
-		return deferredCsrfToken;
 	}
 
 	private static final class SupplierCsrfToken implements CsrfToken {
@@ -114,46 +95,4 @@ public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandle
 
 	}
 
-	private final class RepositoryDeferredCsrfToken implements DeferredCsrfToken {
-
-		private final HttpServletRequest request;
-
-		private final HttpServletResponse response;
-
-		private CsrfToken csrfToken;
-
-		private Boolean missingToken;
-
-		RepositoryDeferredCsrfToken(HttpServletRequest request, HttpServletResponse response) {
-			this.request = request;
-			this.response = response;
-		}
-
-		@Override
-		public CsrfToken get() {
-			init();
-			return this.csrfToken;
-		}
-
-		@Override
-		public boolean isGenerated() {
-			init();
-			return this.missingToken;
-		}
-
-		private void init() {
-			if (this.csrfToken != null) {
-				return;
-			}
-			this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.loadToken(this.request);
-			this.missingToken = (this.csrfToken == null);
-			if (this.missingToken) {
-				this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.generateToken(this.request);
-				CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.saveToken(this.csrfToken, this.request,
-						this.response);
-			}
-		}
-
-	}
-
 }

+ 9 - 6
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java

@@ -16,20 +16,22 @@
 
 package org.springframework.security.web.csrf;
 
+import java.util.function.Supplier;
+
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
 import org.springframework.util.Assert;
 
 /**
- * An interface that is used to determine the {@link CsrfToken} to use and make the
- * {@link CsrfToken} available as a request attribute. Implementations of this interface
- * may choose to perform additional tasks or customize how the token is made available to
- * the application through request attributes.
+ * A callback interface that is used to make the {@link CsrfToken} created by the
+ * {@link CsrfTokenRepository} available as a request attribute. Implementations of this
+ * interface may choose to perform additional tasks or customize how the token is made
+ * available to the application through request attributes.
  *
  * @author Steve Riesenberg
  * @since 5.8
- * @see CsrfTokenRepositoryRequestHandler
+ * @see CsrfTokenRequestAttributeHandler
  */
 @FunctionalInterface
 public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver {
@@ -38,8 +40,9 @@ public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver {
 	 * Handles a request using a {@link CsrfToken}.
 	 * @param request the {@code HttpServletRequest} being handled
 	 * @param response the {@code HttpServletResponse} being handled
+	 * @param csrfToken the {@link CsrfToken} created by the {@link CsrfTokenRepository}
 	 */
-	DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response);
+	void handle(HttpServletRequest request, HttpServletResponse response, Supplier<CsrfToken> csrfToken);
 
 	@Override
 	default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {

+ 1 - 1
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java

@@ -25,7 +25,7 @@ import javax.servlet.http.HttpServletRequest;
  *
  * @author Steve Riesenberg
  * @since 5.8
- * @see CsrfTokenRepositoryRequestHandler
+ * @see CsrfTokenRequestAttributeHandler
  */
 @FunctionalInterface
 public interface CsrfTokenRequestResolver {

+ 2 - 1
web/src/main/java/org/springframework/security/web/csrf/DeferredCsrfToken.java

@@ -20,11 +20,12 @@ package org.springframework.security.web.csrf;
  * An interface that allows delayed access to a {@link CsrfToken} that may be generated.
  *
  * @author Rob Winch
+ * @author Steve Riesenberg
  * @since 5.8
  */
 public interface DeferredCsrfToken {
 
-	/***
+	/**
 	 * Gets the {@link CsrfToken}
 	 * @return a non-null {@link CsrfToken}
 	 */

+ 3 - 2
web/src/main/java/org/springframework/security/web/csrf/LazyCsrfTokenRepository.java

@@ -27,8 +27,9 @@ import org.springframework.util.Assert;
  *
  * @author Rob Winch
  * @since 4.1
- * @deprecated Use org.springframework.security.web.csrf.CsrfTokenRequestHandler which
- * returns a {@link DeferredCsrfToken}
+ * @deprecated Use
+ * {@link CsrfTokenRepository#loadDeferredToken(HttpServletRequest, HttpServletResponse)}
+ * which returns a {@link DeferredCsrfToken}
  */
 @Deprecated
 public final class LazyCsrfTokenRepository implements CsrfTokenRepository {

+ 71 - 0
web/src/main/java/org/springframework/security/web/csrf/RepositoryDeferredCsrfToken.java

@@ -0,0 +1,71 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.csrf;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+/**
+ * @author Rob Winch
+ * @author Steve Riesenberg
+ * @since 5.8
+ */
+final class RepositoryDeferredCsrfToken implements DeferredCsrfToken {
+
+	private final CsrfTokenRepository csrfTokenRepository;
+
+	private final HttpServletRequest request;
+
+	private final HttpServletResponse response;
+
+	private CsrfToken csrfToken;
+
+	private boolean missingToken;
+
+	RepositoryDeferredCsrfToken(CsrfTokenRepository csrfTokenRepository, HttpServletRequest request,
+			HttpServletResponse response) {
+		this.csrfTokenRepository = csrfTokenRepository;
+		this.request = request;
+		this.response = response;
+	}
+
+	@Override
+	public CsrfToken get() {
+		init();
+		return this.csrfToken;
+	}
+
+	@Override
+	public boolean isGenerated() {
+		init();
+		return this.missingToken;
+	}
+
+	private void init() {
+		if (this.csrfToken != null) {
+			return;
+		}
+
+		this.csrfToken = this.csrfTokenRepository.loadToken(this.request);
+		this.missingToken = (this.csrfToken == null);
+		if (this.missingToken) {
+			this.csrfToken = this.csrfTokenRepository.generateToken(this.request);
+			this.csrfTokenRepository.saveToken(this.csrfToken, this.request, this.response);
+		}
+	}
+
+}

+ 29 - 1
web/src/test/java/org/springframework/security/web/csrf/CookieCsrfTokenRepositoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 /**
  * @author Rob Winch
@@ -246,6 +247,33 @@ public class CookieCsrfTokenRepositoryTests {
 		assertThat(loadToken.getToken()).isEqualTo(value);
 	}
 
+	@Test
+	public void loadDeferredTokenWhenDoesNotExistThenGeneratedAndSaved() {
+		DeferredCsrfToken deferredCsrfToken = this.repository.loadDeferredToken(this.request, this.response);
+		CsrfToken csrfToken = deferredCsrfToken.get();
+		assertThat(csrfToken).isNotNull();
+		assertThat(deferredCsrfToken.isGenerated()).isTrue();
+		Cookie tokenCookie = this.response.getCookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
+		assertThat(tokenCookie).isNotNull();
+		assertThat(tokenCookie.getMaxAge()).isEqualTo(-1);
+		assertThat(tokenCookie.getName()).isEqualTo(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME);
+		assertThat(tokenCookie.getPath()).isEqualTo(this.request.getContextPath());
+		assertThat(tokenCookie.getSecure()).isEqualTo(this.request.isSecure());
+		assertThat(tokenCookie.getValue()).isEqualTo(csrfToken.getToken());
+		assertThat(tokenCookie.isHttpOnly()).isEqualTo(true);
+	}
+
+	@Test
+	public void loadDeferredTokenWhenExistsThenLoaded() {
+		CsrfToken generatedToken = this.repository.generateToken(this.request);
+		this.request
+				.setCookies(new Cookie(CookieCsrfTokenRepository.DEFAULT_CSRF_COOKIE_NAME, generatedToken.getToken()));
+		DeferredCsrfToken deferredCsrfToken = this.repository.loadDeferredToken(this.request, this.response);
+		CsrfToken csrfToken = deferredCsrfToken.get();
+		assertThatCsrfToken(csrfToken).isEqualTo(generatedToken);
+		assertThat(deferredCsrfToken.isGenerated()).isFalse();
+	}
+
 	@Test
 	public void setCookieNameNullIllegalArgumentException() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.repository.setCookieName(null));

+ 11 - 16
web/src/test/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategyTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2013 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -82,23 +82,25 @@ public class CsrfAuthenticationStrategyTests {
 
 	@Test
 	public void onAuthenticationWhenCustomRequestHandlerThenUsed() {
+		given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.existingToken, false));
+
 		CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
 		this.strategy.setRequestHandler(requestHandler);
 		this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
 				this.response);
-		verify(requestHandler).handle(eq(this.request), eq(this.response));
+		verify(requestHandler).handle(eq(this.request), eq(this.response), any());
 		verifyNoMoreInteractions(requestHandler);
 	}
 
 	@Test
-	public void logoutRemovesCsrfTokenAndSavesNew() {
-		given(this.csrfTokenRepository.loadToken(this.request)).willReturn(null, this.existingToken);
-		given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
+	public void logoutRemovesCsrfTokenAndLoadsNewDeferredCsrfToken() {
+		given(this.csrfTokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.generatedToken, false));
 		this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
 				this.response);
 		verify(this.csrfTokenRepository).saveToken(null, this.request, this.response);
-		verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
+		verify(this.csrfTokenRepository).loadDeferredToken(this.request, this.response);
 		// SEC-2404, SEC-2832
 		CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
 		assertThat(tokenInRequest.getToken()).isSameAs(this.generatedToken.getToken());
@@ -119,17 +121,10 @@ public class CsrfAuthenticationStrategyTests {
 				any(HttpServletResponse.class));
 		CsrfToken tokenInRequest = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
 		tokenInRequest.getToken();
+		verify(this.csrfTokenRepository).loadToken(this.request);
+		verify(this.csrfTokenRepository).generateToken(this.request);
 		verify(this.csrfTokenRepository).saveToken(eq(this.generatedToken), any(HttpServletRequest.class),
 				any(HttpServletResponse.class));
 	}
 
-	@Test
-	public void logoutWhenNoCsrfToken() {
-		given(this.csrfTokenRepository.generateToken(this.request)).willReturn(this.generatedToken);
-		this.strategy.onAuthentication(new TestingAuthenticationToken("user", "password", "ROLE_USER"), this.request,
-				this.response);
-		verify(this.csrfTokenRepository).saveToken(any(CsrfToken.class), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
-	}
-
 }

+ 50 - 60
web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

@@ -44,7 +44,6 @@ import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.lenient;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
@@ -86,11 +85,7 @@ public class CsrfFilterTests {
 	}
 
 	private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
-		return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository));
-	}
-
-	private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
-		CsrfFilter filter = new CsrfFilter(requestHandler);
+		CsrfFilter filter = new CsrfFilter(repository);
 		filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
 		filter.setAccessDeniedHandler(this.deniedHandler);
 		return filter;
@@ -103,7 +98,7 @@ public class CsrfFilterTests {
 
 	@Test
 	public void constructorNullRepository() {
-		assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null));
+		assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
 	}
 
 	// SEC-2276
@@ -128,7 +123,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterAccessDeniedNoTokenPresent() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@@ -139,7 +135,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterAccessDeniedIncorrectTokenPresent() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
@@ -151,7 +148,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterAccessDeniedIncorrectTokenPresentHeader() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
@@ -164,7 +162,8 @@ public class CsrfFilterTests {
 	public void doFilterAccessDeniedIncorrectTokenPresentHeaderPreferredOverParameter()
 			throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken() + " INVALID");
 		this.filter.doFilter(this.request, this.response, this.filterChain);
@@ -177,7 +176,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterNotCsrfRequestExistingToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(false);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@@ -188,7 +188,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterNotCsrfRequestGenerateToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(false);
-		given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, true));
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@@ -199,7 +200,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterIsCsrfRequestExistingTokenHeader() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
@@ -212,7 +214,8 @@ public class CsrfFilterTests {
 	public void doFilterIsCsrfRequestExistingTokenHeaderPreferredOverInvalidParam()
 			throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken() + " INVALID");
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
@@ -225,7 +228,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterIsCsrfRequestExistingToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
@@ -239,7 +243,8 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterIsCsrfRequestGenerateToken() throws ServletException, IOException {
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, true));
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
@@ -247,17 +252,17 @@ public class CsrfFilterTests {
 		// LazyCsrfTokenRepository requires the response as an attribute
 		assertThat(this.request.getAttribute(HttpServletResponse.class.getName())).isEqualTo(this.response);
 		verify(this.filterChain).doFilter(this.request, this.response);
-		verify(this.tokenRepository).saveToken(this.token, this.request, this.response);
 		verifyNoMoreInteractions(this.deniedHandler);
 	}
 
 	@Test
 	public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
-		this.filter = createCsrfFilter(this.tokenRepository);
+		this.filter = new CsrfFilter(this.tokenRepository);
 		this.filter.setAccessDeniedHandler(this.deniedHandler);
 		for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
 			resetRequestResponse();
-			given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+			given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+					.willReturn(new TestDeferredCsrfToken(this.token, false));
 			this.request.setMethod(method);
 			this.filter.doFilter(this.request, this.response, this.filterChain);
 			verify(this.filterChain).doFilter(this.request, this.response);
@@ -273,11 +278,12 @@ public class CsrfFilterTests {
 	 */
 	@Test
 	public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
-		this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
+		this.filter = new CsrfFilter(this.tokenRepository);
 		this.filter.setAccessDeniedHandler(this.deniedHandler);
 		for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
 			resetRequestResponse();
-			given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+			given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+					.willReturn(new TestDeferredCsrfToken(this.token, false));
 			this.request.setMethod(method);
 			this.filter.doFilter(this.request, this.response, this.filterChain);
 			verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
@@ -288,11 +294,12 @@ public class CsrfFilterTests {
 
 	@Test
 	public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
-		this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
+		this.filter = new CsrfFilter(this.tokenRepository);
 		this.filter.setAccessDeniedHandler(this.deniedHandler);
 		for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
 			resetRequestResponse();
-			given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+			given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+					.willReturn(new TestDeferredCsrfToken(this.token, false));
 			this.request.setMethod(method);
 			this.filter.doFilter(this.request, this.response, this.filterChain);
 			verify(this.deniedHandler).handle(eq(this.request), eq(this.response),
@@ -303,10 +310,11 @@ public class CsrfFilterTests {
 
 	@Test
 	public void doFilterDefaultAccessDenied() throws ServletException, IOException {
-		this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
+		this.filter = new CsrfFilter(this.tokenRepository);
 		this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(this.token, false));
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
@@ -317,7 +325,7 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
 		CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
-		CsrfFilter filter = createCsrfFilter(repository);
+		CsrfFilter filter = new CsrfFilter(repository);
 		lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		CsrfFilter.skipRequest(request);
@@ -333,7 +341,8 @@ public class CsrfFilterTests {
 		given(token.getToken()).willReturn(null);
 		given(token.getHeaderName()).willReturn(this.token.getHeaderName());
 		given(token.getParameterName()).willReturn(this.token.getParameterName());
-		given(this.tokenRepository.loadToken(this.request)).willReturn(token);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(token, false));
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
 		filter.doFilterInternal(this.request, this.response, this.filterChain);
 		assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
@@ -341,13 +350,15 @@ public class CsrfFilterTests {
 
 	@Test
 	public void doFilterWhenRequestHandlerThenUsed() throws Exception {
-		CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
-		given(requestHandler.handle(this.request, this.response))
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
 				.willReturn(new TestDeferredCsrfToken(this.token, false));
-		this.filter = createCsrfFilter(requestHandler);
+		CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
+		this.filter = createCsrfFilter(this.tokenRepository);
+		this.filter.setRequestHandler(requestHandler);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
-		verify(requestHandler).handle(eq(this.request), eq(this.response));
+		verify(this.tokenRepository).loadDeferredToken(this.request, this.response);
+		verify(requestHandler).handle(eq(this.request), eq(this.response), any());
 		verify(this.filterChain).doFilter(this.request, this.response);
 	}
 
@@ -365,41 +376,20 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
 			throws ServletException, IOException {
+		CsrfFilter filter = createCsrfFilter(this.tokenRepository);
 		String csrfAttrName = "_csrf";
-		CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
+		CsrfTokenRequestAttributeHandler requestHandler = new CsrfTokenRequestAttributeHandler();
 		requestHandler.setCsrfRequestAttributeName(csrfAttrName);
-		this.filter = createCsrfFilter(requestHandler);
-		CsrfToken expectedCsrfToken = spy(this.token);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
+		filter.setRequestHandler(requestHandler);
+		CsrfToken expectedCsrfToken = mock(CsrfToken.class);
+		given(this.tokenRepository.loadDeferredToken(this.request, this.response))
+				.willReturn(new TestDeferredCsrfToken(expectedCsrfToken, true));
 
-		this.filter.doFilter(this.request, this.response, this.filterChain);
+		filter.doFilter(this.request, this.response, this.filterChain);
 
 		verifyNoInteractions(expectedCsrfToken);
 		CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
 		assertThatCsrfToken(tokenFromRequest).isEqualTo(expectedCsrfToken);
 	}
 
-	private static final class TestDeferredCsrfToken implements DeferredCsrfToken {
-
-		private final CsrfToken csrfToken;
-
-		private final boolean isGenerated;
-
-		private TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
-			this.csrfToken = csrfToken;
-			this.isGenerated = isGenerated;
-		}
-
-		@Override
-		public CsrfToken get() {
-			return this.csrfToken;
-		}
-
-		@Override
-		public boolean isGenerated() {
-			return this.isGenerated;
-		}
-
-	}
-
 }

+ 26 - 36
web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java → web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestAttributeHandlerTests.java

@@ -18,29 +18,22 @@ package org.springframework.security.web.csrf;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
-import org.junit.jupiter.api.extension.ExtendWith;
-import org.mockito.Mock;
-import org.mockito.junit.jupiter.MockitoExtension;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
-import static org.mockito.BDDMockito.given;
+import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
 import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 /**
- * Tests for {@link CsrfTokenRepositoryRequestHandler}.
+ * Tests for {@link CsrfTokenRequestAttributeHandler}.
  *
  * @author Steve Riesenberg
  * @since 5.8
  */
-@ExtendWith(MockitoExtension.class)
-public class CsrfTokenRepositoryRequestHandlerTests {
-
-	@Mock
-	CsrfTokenRepository tokenRepository;
+public class CsrfTokenRequestAttributeHandlerTests {
 
 	private MockHttpServletRequest request;
 
@@ -48,76 +41,73 @@ public class CsrfTokenRepositoryRequestHandlerTests {
 
 	private CsrfToken token;
 
-	private CsrfTokenRepositoryRequestHandler handler;
+	private CsrfTokenRequestAttributeHandler handler;
 
 	@BeforeEach
 	public void setup() {
 		this.request = new MockHttpServletRequest();
 		this.response = new MockHttpServletResponse();
 		this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
-		this.handler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
+		this.handler = new CsrfTokenRequestAttributeHandler();
 	}
 
 	@Test
-	public void constructorWhenCsrfTokenRepositoryIsNullThenThrowsIllegalArgumentException() {
-		// @formatter:off
+	public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> new CsrfTokenRepositoryRequestHandler(null))
-				.withMessage("csrfTokenRepository cannot be null");
-		// @formatter:on
+				.isThrownBy(() -> this.handler.handle(null, this.response, () -> this.token))
+				.withMessage("request cannot be null");
 	}
 
 	@Test
-	public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
+	public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
 		// @formatter:off
 		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.handler.handle(null, this.response))
-				.withMessage("request cannot be null");
+				.isThrownBy(() -> this.handler.handle(this.request, null, () -> this.token))
+				.withMessage("response cannot be null");
 		// @formatter:on
 	}
 
 	@Test
-	public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
+	public void handleWhenCsrfTokenSupplierIsNullThenThrowsIllegalArgumentException() {
+		assertThatIllegalArgumentException().isThrownBy(() -> this.handler.handle(this.request, this.response, null))
+				.withMessage("deferredCsrfToken cannot be null");
+	}
+
+	@Test
+	public void handleWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
 		// @formatter:off
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.handler.handle(this.request, null))
-				.withMessage("response cannot be null");
+		this.handler.setCsrfRequestAttributeName(null);
+		assertThatIllegalStateException()
+				.isThrownBy(() -> this.handler.handle(this.request, this.response, () -> null))
+				.withMessage("csrfTokenSupplier returned null delegate");
 		// @formatter:on
 	}
 
 	@Test
 	public void handleWhenCsrfRequestAttributeSetThenUsed() {
-		given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
 		this.handler.setCsrfRequestAttributeName("_csrf");
-		this.handler.handle(this.request, this.response);
+		this.handler.handle(this.request, this.response, () -> this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
 	}
 
 	@Test
 	public void handleWhenValidParametersThenRequestAttributesSet() {
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
-		this.handler.handle(this.request, this.response);
+		this.handler.handle(this.request, this.response, () -> this.token);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 	}
 
 	@Test
 	public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
-		// @formatter:off
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
+		assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
 				.withMessage("request cannot be null");
-		// @formatter:on
 	}
 
 	@Test
 	public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
-		// @formatter:off
-		assertThatIllegalArgumentException()
-				.isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
+		assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
 				.withMessage("csrfToken cannot be null");
-		// @formatter:on
 	}
 
 	@Test

+ 22 - 1
web/src/test/java/org/springframework/security/web/csrf/HttpSessionCsrfTokenRepositoryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2013 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import org.springframework.mock.web.MockHttpServletResponse;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 /**
  * @author Rob Winch
@@ -85,6 +86,26 @@ public class HttpSessionCsrfTokenRepositoryTests {
 		assertThat(this.repo.loadToken(this.request)).isNull();
 	}
 
+	@Test
+	public void loadDeferredTokenWhenDoesNotExistThenGeneratedAndSaved() {
+		DeferredCsrfToken deferredCsrfToken = this.repo.loadDeferredToken(this.request, this.response);
+		CsrfToken csrfToken = deferredCsrfToken.get();
+		assertThat(csrfToken).isNotNull();
+		assertThat(deferredCsrfToken.isGenerated()).isTrue();
+		String attrName = this.request.getSession().getAttributeNames().nextElement();
+		assertThatCsrfToken(this.request.getSession().getAttribute(attrName)).isEqualTo(csrfToken);
+	}
+
+	@Test
+	public void loadDeferredTokenWhenExistsThenLoaded() {
+		CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");
+		this.repo.saveToken(tokenToSave, this.request, this.response);
+		DeferredCsrfToken deferredCsrfToken = this.repo.loadDeferredToken(this.request, this.response);
+		CsrfToken csrfToken = deferredCsrfToken.get();
+		assertThatCsrfToken(csrfToken).isEqualTo(tokenToSave);
+		assertThat(deferredCsrfToken.isGenerated()).isFalse();
+	}
+
 	@Test
 	public void saveToken() {
 		CsrfToken tokenToSave = new DefaultCsrfToken("123", "abc", "def");

+ 40 - 0
web/src/test/java/org/springframework/security/web/csrf/TestDeferredCsrfToken.java

@@ -0,0 +1,40 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.csrf;
+
+final class TestDeferredCsrfToken implements DeferredCsrfToken {
+
+	private final CsrfToken csrfToken;
+
+	private final boolean isGenerated;
+
+	TestDeferredCsrfToken(CsrfToken csrfToken, boolean isGenerated) {
+		this.csrfToken = csrfToken;
+		this.isGenerated = isGenerated;
+	}
+
+	@Override
+	public CsrfToken get() {
+		return this.csrfToken;
+	}
+
+	@Override
+	public boolean isGenerated() {
+		return this.isGenerated;
+	}
+
+}