Browse Source

CsrfTokenRequestHandler extends CsrfTokenRequestResolver

Closes gh-11896
Steve Riesenberg 2 years ago
parent
commit
46696a9226
18 changed files with 155 additions and 188 deletions
  1. 9 21
      config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java
  2. 10 12
      config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java
  3. 0 3
      config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc
  4. 1 7
      config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd
  5. 2 2
      config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java
  6. 5 10
      config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java
  7. 1 1
      config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml
  8. 1 1
      config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml
  9. 1 5
      docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc
  10. 3 3
      test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java
  11. 7 4
      test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java
  12. 1 4
      web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java
  13. 20 40
      web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java
  14. 21 25
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java
  15. 19 6
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java
  16. 1 1
      web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java
  17. 17 27
      web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java
  18. 36 16
      web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java

+ 9 - 21
config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java

@@ -36,8 +36,8 @@ import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfLogoutHandler;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
-import org.springframework.security.web.csrf.CsrfTokenRequestResolver;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
 import org.springframework.security.web.csrf.MissingCsrfTokenException;
@@ -93,8 +93,6 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 
 	private CsrfTokenRequestHandler requestHandler;
 
-	private CsrfTokenRequestResolver requestResolver;
-
 	private final ApplicationContext context;
 
 	/**
@@ -135,23 +133,13 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 	 * available as a request attribute.
 	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use
 	 * @return the {@link CsrfConfigurer} for further customizations
+	 * @since 5.8
 	 */
 	public CsrfConfigurer<H> csrfTokenRequestHandler(CsrfTokenRequestHandler requestHandler) {
 		this.requestHandler = requestHandler;
 		return this;
 	}
 
-	/**
-	 * Specify a {@link CsrfTokenRequestResolver} to use for resolving the token value
-	 * from the request.
-	 * @param requestResolver the {@link CsrfTokenRequestResolver} to use
-	 * @return the {@link CsrfConfigurer} for further customizations
-	 */
-	public CsrfConfigurer<H> csrfTokenRequestResolver(CsrfTokenRequestResolver requestResolver) {
-		this.requestResolver = requestResolver;
-		return this;
-	}
-
 	/**
 	 * <p>
 	 * Allows specifying {@link HttpServletRequest} that should not use CSRF Protection
@@ -229,7 +217,13 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 	@SuppressWarnings("unchecked")
 	@Override
 	public void configure(H http) {
-		CsrfFilter filter = new CsrfFilter(this.csrfTokenRepository);
+		CsrfFilter filter;
+		if (this.requestHandler != null) {
+			filter = new CsrfFilter(this.requestHandler);
+		}
+		else {
+			filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.csrfTokenRepository));
+		}
 		RequestMatcher requireCsrfProtectionMatcher = getRequireCsrfProtectionMatcher();
 		if (requireCsrfProtectionMatcher != null) {
 			filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
@@ -246,12 +240,6 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 		if (sessionConfigurer != null) {
 			sessionConfigurer.addSessionAuthenticationStrategy(getSessionAuthenticationStrategy());
 		}
-		if (this.requestHandler != null) {
-			filter.setRequestHandler(this.requestHandler);
-		}
-		if (this.requestResolver != null) {
-			filter.setRequestResolver(this.requestResolver);
-		}
 		filter = postProcess(filter);
 		http.addFilter(filter);
 	}

+ 10 - 12
config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java

@@ -41,6 +41,7 @@ import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
 import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfLogoutHandler;
+import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
 import org.springframework.security.web.csrf.MissingCsrfTokenException;
@@ -73,8 +74,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 
 	private static final String ATT_REQUEST_HANDLER = "request-handler-ref";
 
-	private static final String ATT_REQUEST_RESOLVER = "request-resolver-ref";
-
 	private String csrfRepositoryRef;
 
 	private BeanDefinition csrfFilter;
@@ -83,8 +82,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 
 	private String requestHandlerRef;
 
-	private String requestResolverRef;
-
 	@Override
 	public BeanDefinition parse(Element element, ParserContext pc) {
 		boolean disabled = element != null && "true".equals(element.getAttribute("disabled"));
@@ -104,7 +101,6 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 			this.csrfRepositoryRef = element.getAttribute(ATT_REPOSITORY);
 			this.requestMatcherRef = element.getAttribute(ATT_MATCHER);
 			this.requestHandlerRef = element.getAttribute(ATT_REQUEST_HANDLER);
-			this.requestResolverRef = element.getAttribute(ATT_REQUEST_RESOLVER);
 		}
 		if (!StringUtils.hasText(this.csrfRepositoryRef)) {
 			RootBeanDefinition csrfTokenRepository = new RootBeanDefinition(HttpSessionCsrfTokenRepository.class);
@@ -116,15 +112,17 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 					new BeanComponentDefinition(lazyTokenRepository.getBeanDefinition(), this.csrfRepositoryRef));
 		}
 		BeanDefinitionBuilder builder = BeanDefinitionBuilder.rootBeanDefinition(CsrfFilter.class);
-		builder.addConstructorArgReference(this.csrfRepositoryRef);
-		if (StringUtils.hasText(this.requestMatcherRef)) {
-			builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
+		if (!StringUtils.hasText(this.requestHandlerRef)) {
+			BeanDefinition csrfTokenRequestHandler = BeanDefinitionBuilder
+					.rootBeanDefinition(CsrfTokenRepositoryRequestHandler.class)
+					.addConstructorArgReference(this.csrfRepositoryRef).getBeanDefinition();
+			builder.addConstructorArgValue(csrfTokenRequestHandler);
 		}
-		if (StringUtils.hasText(this.requestHandlerRef)) {
-			builder.addPropertyReference("requestHandler", this.requestHandlerRef);
+		else {
+			builder.addConstructorArgReference(this.requestHandlerRef);
 		}
-		if (StringUtils.hasText(this.requestResolverRef)) {
-			builder.addPropertyReference("requestResolver", this.requestResolverRef);
+		if (StringUtils.hasText(this.requestMatcherRef)) {
+			builder.addPropertyReference("requireCsrfProtectionMatcher", this.requestMatcherRef);
 		}
 		this.csrfFilter = builder.getBeanDefinition();
 		return this.csrfFilter;

+ 0 - 3
config/src/main/resources/org/springframework/security/config/spring-security-5.8.rnc

@@ -1154,9 +1154,6 @@ csrf-options.attlist &=
 csrf-options.attlist &=
 	## The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
 	attribute request-handler-ref { xsd:token }?
-csrf-options.attlist &=
-	## The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor.
-	attribute request-resolver-ref { xsd:token }?
 
 headers =
 ## Element for configuration of the HeaderWritersFilter. Enables easy setting for the X-Frame-Options, X-XSS-Protection and X-Content-Type-Options headers.

+ 1 - 7
config/src/main/resources/org/springframework/security/config/spring-security-5.8.xsd

@@ -3258,13 +3258,7 @@
       </xs:attribute>
       <xs:attribute name="request-handler-ref" type="xs:token">
          <xs:annotation>
-            <xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRequestProcessor.
-                </xs:documentation>
-         </xs:annotation>
-      </xs:attribute>
-      <xs:attribute name="request-resolver-ref" type="xs:token">
-         <xs:annotation>
-            <xs:documentation>The CsrfTokenRequestResolver to use. The default is CsrfTokenRequestProcessor.
+            <xs:documentation>The CsrfTokenRequestHandler to use. The default is CsrfTokenRepositoryRequestHandler.
                 </xs:documentation>
          </xs:annotation>
       </xs:attribute>

+ 2 - 2
config/src/test/java/org/springframework/security/config/annotation/web/configuration/DeferHttpSessionJavaConfigTests.java

@@ -33,7 +33,7 @@ import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.web.DefaultSecurityFilterChain;
 import org.springframework.security.web.FilterChainProxy;
-import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
+import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.LazyCsrfTokenRepository;
 import org.springframework.security.web.savedrequest.HttpSessionRequestCache;
@@ -85,7 +85,7 @@ public class DeferHttpSessionJavaConfigTests {
 			csrfRepository.setDeferLoadToken(true);
 			HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
 			requestCache.setMatchingRequestParameterName("continue");
-			CsrfTokenRequestProcessor requestHandler = new CsrfTokenRequestProcessor();
+			CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler();
 			requestHandler.setCsrfRequestAttributeName("_csrf");
 			// @formatter:off
 			http

+ 5 - 10
config/src/test/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.java

@@ -44,7 +44,7 @@ import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
+import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
 import org.springframework.security.web.firewall.StrictHttpFirewall;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -422,8 +422,7 @@ public class CsrfConfigurerTests {
 		CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
 		CsrfToken csrfToken = new DefaultCsrfToken("X-CSRF-TOKEN", "_csrf", "token");
 		given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
-		CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor();
-		CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository);
+		CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
 		this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
 		this.mvc.perform(get("/login")).andExpect(status().isOk())
 				.andExpect(content().string(containsString(csrfToken.getToken())));
@@ -440,8 +439,7 @@ public class CsrfConfigurerTests {
 		CsrfTokenRepository csrfTokenRepository = mock(CsrfTokenRepository.class);
 		given(csrfTokenRepository.loadToken(any(HttpServletRequest.class))).willReturn(null, csrfToken);
 		given(csrfTokenRepository.generateToken(any(HttpServletRequest.class))).willReturn(csrfToken);
-		CsrfTokenRequestProcessorConfig.PROCESSOR = new CsrfTokenRequestProcessor();
-		CsrfTokenRequestProcessorConfig.PROCESSOR.setTokenRepository(csrfTokenRepository);
+		CsrfTokenRequestProcessorConfig.HANDLER = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
 
 		this.spring.register(CsrfTokenRequestProcessorConfig.class, BasicController.class).autowire();
 		// @formatter:off
@@ -803,7 +801,7 @@ public class CsrfConfigurerTests {
 	@EnableWebSecurity
 	static class CsrfTokenRequestProcessorConfig {
 
-		static CsrfTokenRequestProcessor PROCESSOR;
+		static CsrfTokenRepositoryRequestHandler HANDLER;
 
 		@Bean
 		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
@@ -813,10 +811,7 @@ public class CsrfConfigurerTests {
 					.anyRequest().authenticated()
 				)
 				.formLogin(Customizer.withDefaults())
-				.csrf((csrf) -> csrf
-					.csrfTokenRequestHandler(PROCESSOR)
-					.csrfTokenRequestResolver(PROCESSOR)
-				);
+				.csrf((csrf) -> csrf.csrfTokenRequestHandler(HANDLER));
 			// @formatter:on
 
 			return http.build();

+ 1 - 1
config/src/test/resources/org/springframework/security/config/http/CsrfConfigTests-WithRequestAttrName.xml

@@ -26,7 +26,7 @@
 		<csrf request-handler-ref="requestHandler"/>
 	</http>
 
-	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestProcessor"
+	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
 		p:csrfRequestAttributeName="csrf-attribute-name"/>
 	<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
 </b:beans>

+ 1 - 1
config/src/test/resources/org/springframework/security/config/http/DeferHttpSessionTests-Explicit.xml

@@ -42,7 +42,7 @@
 	<b:bean id="csrfRepository" class="org.springframework.security.web.csrf.LazyCsrfTokenRepository"
 		c:delegate-ref="httpSessionCsrfRepository"
 	 	p:deferLoadToken="true"/>
-	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRequestProcessor"
+	<b:bean id="requestHandler" class="org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler"
 		p:csrfRequestAttributeName="_csrf"/>
 	<b:import resource="CsrfConfigTests-shared-userservice.xml"/>
 </b:beans>

+ 1 - 5
docs/modules/ROOT/pages/servlet/appendix/namespace/http.adoc

@@ -777,11 +777,7 @@ The default is `HttpSessionCsrfTokenRepository`.
 
 [[nsa-csrf-request-handler-ref]]
 * **request-handler-ref**
-The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRequestProcessor`.
-
-[[nsa-csrf-request-resolver-ref]]
-* **request-resolver-ref**
-The optional `CsrfTokenRequestResolver` to use. The default is `CsrfTokenRequestProcessor`.
+The optional `CsrfTokenRequestHandler` to use. The default is `CsrfTokenRepositoryRequestHandler`.
 
 [[nsa-csrf-request-matcher-ref]]
 * **request-matcher-ref**

+ 3 - 3
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@@ -31,8 +31,8 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
-import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.web.context.WebApplicationContext;
@@ -48,7 +48,7 @@ public abstract class WebTestUtils {
 
 	private static final SecurityContextRepository DEFAULT_CONTEXT_REPO = new HttpSessionSecurityContextRepository();
 
-	private static final CsrfTokenRequestProcessor DEFAULT_CSRF_PROCESSOR = new CsrfTokenRequestProcessor();
+	private static final CsrfTokenRepositoryRequestHandler DEFAULT_CSRF_HANDLER = new CsrfTokenRepositoryRequestHandler();
 
 	private WebTestUtils() {
 	}
@@ -104,7 +104,7 @@ public abstract class WebTestUtils {
 	public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
 		CsrfFilter filter = findFilter(request, CsrfFilter.class);
 		if (filter == null) {
-			return DEFAULT_CSRF_PROCESSOR;
+			return DEFAULT_CSRF_HANDLER;
 		}
 		return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
 	}

+ 7 - 4
test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java

@@ -39,7 +39,7 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.CsrfTokenRequestProcessor;
+import org.springframework.security.web.csrf.CsrfTokenRepositoryRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.web.context.WebApplicationContext;
@@ -75,19 +75,22 @@ public class WebTestUtilsTests {
 
 	@Test
 	public void getCsrfTokenRepositorytNoWac() {
-		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
+		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
+				.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
 	}
 
 	@Test
 	public void getCsrfTokenRepositorytNoSecurity() {
 		loadConfig(Config.class);
-		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
+		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
+				.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
 	}
 
 	@Test
 	public void getCsrfTokenRepositorytSecurityNoCsrf() {
 		loadConfig(SecurityNoCsrfConfig.class);
-		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request)).isInstanceOf(CsrfTokenRequestProcessor.class);
+		assertThat(WebTestUtils.getCsrfTokenRequestHandler(this.request))
+				.isInstanceOf(CsrfTokenRepositoryRequestHandler.class);
 	}
 
 	@Test

+ 1 - 4
web/src/main/java/org/springframework/security/web/csrf/CsrfAuthenticationStrategy.java

@@ -48,10 +48,7 @@ public final class CsrfAuthenticationStrategy implements SessionAuthenticationSt
 	 * @param csrfTokenRepository the {@link CsrfTokenRepository} to use
 	 */
 	public CsrfAuthenticationStrategy(CsrfTokenRepository csrfTokenRepository) {
-		Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
-		CsrfTokenRequestProcessor processor = new CsrfTokenRequestProcessor();
-		processor.setTokenRepository(csrfTokenRepository);
-		this.requestHandler = processor;
+		this.requestHandler = new CsrfTokenRepositoryRequestHandler(csrfTokenRepository);
 		this.csrfTokenRepository = csrfTokenRepository;
 	}
 

+ 20 - 40
web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

@@ -82,20 +82,30 @@ public final class CsrfFilter extends OncePerRequestFilter {
 
 	private final Log logger = LogFactory.getLog(getClass());
 
+	private final CsrfTokenRequestHandler requestHandler;
+
 	private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
 
 	private AccessDeniedHandler accessDeniedHandler = new AccessDeniedHandlerImpl();
 
-	private CsrfTokenRequestHandler requestHandler;
-
-	private CsrfTokenRequestResolver requestResolver;
-
+	/**
+	 * Creates a new instance.
+	 * @param csrfTokenRepository the {@link CsrfTokenRepository} to use
+	 * @deprecated Use {@link CsrfFilter#CsrfFilter(CsrfTokenRequestHandler)} instead
+	 */
+	@Deprecated
 	public CsrfFilter(CsrfTokenRepository csrfTokenRepository) {
-		Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
-		CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
-		csrfTokenRequestProcessor.setTokenRepository(csrfTokenRepository);
-		this.requestHandler = csrfTokenRequestProcessor;
-		this.requestResolver = csrfTokenRequestProcessor;
+		this(new CsrfTokenRepositoryRequestHandler(csrfTokenRepository));
+	}
+
+	/**
+	 * Creates a new instance.
+	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use. Default is
+	 * {@link CsrfTokenRepositoryRequestHandler}.
+	 */
+	public CsrfFilter(CsrfTokenRequestHandler requestHandler) {
+		Assert.notNull(requestHandler, "requestHandler cannot be null");
+		this.requestHandler = requestHandler;
 	}
 
 	@Override
@@ -116,7 +126,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
 			return;
 		}
 		CsrfToken csrfToken = deferredCsrfToken.get();
-		String actualToken = this.requestResolver.resolveCsrfTokenValue(request, csrfToken);
+		String actualToken = this.requestHandler.resolveCsrfTokenValue(request, csrfToken);
 		if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
 			boolean missingToken = deferredCsrfToken.isGenerated();
 			this.logger.debug(
@@ -164,36 +174,6 @@ public final class CsrfFilter extends OncePerRequestFilter {
 		this.accessDeniedHandler = accessDeniedHandler;
 	}
 
-	/**
-	 * Specifies a {@link CsrfTokenRequestHandler} that is used to make the
-	 * {@link CsrfToken} available as a request attribute.
-	 *
-	 * <p>
-	 * The default is {@link CsrfTokenRequestProcessor}.
-	 * </p>
-	 * @param requestHandler the {@link CsrfTokenRequestHandler} to use
-	 * @since 5.8
-	 */
-	public void setRequestHandler(CsrfTokenRequestHandler requestHandler) {
-		Assert.notNull(requestHandler, "requestHandler cannot be null");
-		this.requestHandler = requestHandler;
-	}
-
-	/**
-	 * Specifies a {@link CsrfTokenRequestResolver} that is used to resolve the token
-	 * value from the request.
-	 *
-	 * <p>
-	 * The default is {@link CsrfTokenRequestProcessor}.
-	 * </p>
-	 * @param requestResolver the {@link CsrfTokenRequestResolver} to use
-	 * @since 5.8
-	 */
-	public void setRequestResolver(CsrfTokenRequestResolver requestResolver) {
-		Assert.notNull(requestResolver, "requestResolver cannot be null");
-		this.requestResolver = requestResolver;
-	}
-
 	/**
 	 * Constant time comparison to prevent against timing attacks.
 	 * @param expected

+ 21 - 25
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessor.java → web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandler.java

@@ -24,28 +24,34 @@ import javax.servlet.http.HttpServletResponse;
 import org.springframework.util.Assert;
 
 /**
- * An implementation of the {@link CsrfTokenRequestHandler} and
- * {@link CsrfTokenRequestResolver} interfaces that is capable of making the
- * {@link CsrfToken} available as a request attribute and resolving the token value as
- * either a header or parameter value of the request.
+ * An implementation of the {@link CsrfTokenRequestHandler} interface that is capable of
+ * making the {@link CsrfToken} available as a request attribute and resolving the token
+ * value as either a header or parameter value of the request.
  *
  * @author Steve Riesenberg
  * @since 5.8
  */
-public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfTokenRequestResolver {
+public class CsrfTokenRepositoryRequestHandler implements CsrfTokenRequestHandler {
+
+	private final CsrfTokenRepository csrfTokenRepository;
 
 	private String csrfRequestAttributeName;
 
-	private CsrfTokenRepository tokenRepository = new HttpSessionCsrfTokenRepository();
+	/**
+	 * Creates a new instance.
+	 */
+	public CsrfTokenRepositoryRequestHandler() {
+		this(new HttpSessionCsrfTokenRepository());
+	}
 
 	/**
-	 * Sets the {@link CsrfTokenRepository} to use.
-	 * @param tokenRepository the {@link CsrfTokenRepository} to use. Default
+	 * Creates a new instance.
+	 * @param csrfTokenRepository the {@link CsrfTokenRepository} to use. Default
 	 * {@link HttpSessionCsrfTokenRepository}
 	 */
-	public void setTokenRepository(CsrfTokenRepository tokenRepository) {
-		Assert.notNull(tokenRepository, "tokenRepository cannot be null");
-		this.tokenRepository = tokenRepository;
+	public CsrfTokenRepositoryRequestHandler(CsrfTokenRepository csrfTokenRepository) {
+		Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
+		this.csrfTokenRepository = csrfTokenRepository;
 	}
 
 	/**
@@ -75,17 +81,6 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT
 		return deferredCsrfToken;
 	}
 
-	@Override
-	public String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
-		Assert.notNull(request, "request cannot be null");
-		Assert.notNull(csrfToken, "csrfToken cannot be null");
-		String actualToken = request.getHeader(csrfToken.getHeaderName());
-		if (actualToken == null) {
-			actualToken = request.getParameter(csrfToken.getParameterName());
-		}
-		return actualToken;
-	}
-
 	private static final class SupplierCsrfToken implements CsrfToken {
 
 		private final Supplier<CsrfToken> csrfTokenSupplier;
@@ -150,11 +145,12 @@ public class CsrfTokenRequestProcessor implements CsrfTokenRequestHandler, CsrfT
 			if (this.csrfToken != null) {
 				return;
 			}
-			this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.loadToken(this.request);
+			this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.loadToken(this.request);
 			this.missingToken = (this.csrfToken == null);
 			if (this.missingToken) {
-				this.csrfToken = CsrfTokenRequestProcessor.this.tokenRepository.generateToken(this.request);
-				CsrfTokenRequestProcessor.this.tokenRepository.saveToken(this.csrfToken, this.request, this.response);
+				this.csrfToken = CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.generateToken(this.request);
+				CsrfTokenRepositoryRequestHandler.this.csrfTokenRepository.saveToken(this.csrfToken, this.request,
+						this.response);
 			}
 		}
 

+ 19 - 6
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestHandler.java

@@ -19,18 +19,20 @@ package org.springframework.security.web.csrf;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import org.springframework.util.Assert;
+
 /**
- * A callback interface that is used to determine the {@link CsrfToken} to use and make
- * the {@link CsrfToken} available as a request attribute. Implementations of this
- * interface may choose to perform additional tasks or customize how the token is made
- * available to the application through request attributes.
+ * An interface that is used to determine the {@link CsrfToken} to use and make the
+ * {@link CsrfToken} available as a request attribute. Implementations of this interface
+ * may choose to perform additional tasks or customize how the token is made available to
+ * the application through request attributes.
  *
  * @author Steve Riesenberg
  * @since 5.8
- * @see CsrfTokenRequestProcessor
+ * @see CsrfTokenRepositoryRequestHandler
  */
 @FunctionalInterface
-public interface CsrfTokenRequestHandler {
+public interface CsrfTokenRequestHandler extends CsrfTokenRequestResolver {
 
 	/**
 	 * Handles a request using a {@link CsrfToken}.
@@ -39,4 +41,15 @@ public interface CsrfTokenRequestHandler {
 	 */
 	DeferredCsrfToken handle(HttpServletRequest request, HttpServletResponse response);
 
+	@Override
+	default String resolveCsrfTokenValue(HttpServletRequest request, CsrfToken csrfToken) {
+		Assert.notNull(request, "request cannot be null");
+		Assert.notNull(csrfToken, "csrfToken cannot be null");
+		String actualToken = request.getHeader(csrfToken.getHeaderName());
+		if (actualToken == null) {
+			actualToken = request.getParameter(csrfToken.getParameterName());
+		}
+		return actualToken;
+	}
+
 }

+ 1 - 1
web/src/main/java/org/springframework/security/web/csrf/CsrfTokenRequestResolver.java

@@ -25,7 +25,7 @@ import javax.servlet.http.HttpServletRequest;
  *
  * @author Steve Riesenberg
  * @since 5.8
- * @see CsrfTokenRequestProcessor
+ * @see CsrfTokenRepositoryRequestHandler
  */
 @FunctionalInterface
 public interface CsrfTokenRequestResolver {

+ 17 - 27
web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

@@ -86,7 +86,11 @@ public class CsrfFilterTests {
 	}
 
 	private CsrfFilter createCsrfFilter(CsrfTokenRepository repository) {
-		CsrfFilter filter = new CsrfFilter(repository);
+		return createCsrfFilter(new CsrfTokenRepositoryRequestHandler(repository));
+	}
+
+	private CsrfFilter createCsrfFilter(CsrfTokenRequestHandler requestHandler) {
+		CsrfFilter filter = new CsrfFilter(requestHandler);
 		filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
 		filter.setAccessDeniedHandler(this.deniedHandler);
 		return filter;
@@ -99,7 +103,7 @@ public class CsrfFilterTests {
 
 	@Test
 	public void constructorNullRepository() {
-		assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
+		assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter((CsrfTokenRequestHandler) null));
 	}
 
 	// SEC-2276
@@ -249,7 +253,7 @@ public class CsrfFilterTests {
 
 	@Test
 	public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethods() throws ServletException, IOException {
-		this.filter = new CsrfFilter(this.tokenRepository);
+		this.filter = createCsrfFilter(this.tokenRepository);
 		this.filter.setAccessDeniedHandler(this.deniedHandler);
 		for (String method : Arrays.asList("GET", "TRACE", "OPTIONS", "HEAD")) {
 			resetRequestResponse();
@@ -269,7 +273,7 @@ public class CsrfFilterTests {
 	 */
 	@Test
 	public void doFilterDefaultRequireCsrfProtectionMatcherAllowedMethodsCaseSensitive() throws Exception {
-		this.filter = new CsrfFilter(this.tokenRepository);
+		this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
 		this.filter.setAccessDeniedHandler(this.deniedHandler);
 		for (String method : Arrays.asList("get", "TrAcE", "oPTIOnS", "hEaD")) {
 			resetRequestResponse();
@@ -284,7 +288,7 @@ public class CsrfFilterTests {
 
 	@Test
 	public void doFilterDefaultRequireCsrfProtectionMatcherDeniedMethods() throws ServletException, IOException {
-		this.filter = new CsrfFilter(this.tokenRepository);
+		this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
 		this.filter.setAccessDeniedHandler(this.deniedHandler);
 		for (String method : Arrays.asList("POST", "PUT", "PATCH", "DELETE", "INVALID")) {
 			resetRequestResponse();
@@ -299,7 +303,7 @@ public class CsrfFilterTests {
 
 	@Test
 	public void doFilterDefaultAccessDenied() throws ServletException, IOException {
-		this.filter = new CsrfFilter(this.tokenRepository);
+		this.filter = new CsrfFilter(new CsrfTokenRepositoryRequestHandler(this.tokenRepository));
 		this.filter.setRequireCsrfProtectionMatcher(this.requestMatcher);
 		given(this.requestMatcher.matches(this.request)).willReturn(true);
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
@@ -313,7 +317,7 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterWhenSkipRequestInvokedThenSkips() throws Exception {
 		CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
-		CsrfFilter filter = new CsrfFilter(repository);
+		CsrfFilter filter = createCsrfFilter(repository);
 		lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
 		MockHttpServletRequest request = new MockHttpServletRequest();
 		CsrfFilter.skipRequest(request);
@@ -340,25 +344,13 @@ public class CsrfFilterTests {
 		CsrfTokenRequestHandler requestHandler = mock(CsrfTokenRequestHandler.class);
 		given(requestHandler.handle(this.request, this.response))
 				.willReturn(new TestDeferredCsrfToken(this.token, false));
-		this.filter.setRequestHandler(requestHandler);
+		this.filter = createCsrfFilter(requestHandler);
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
 		this.filter.doFilter(this.request, this.response, this.filterChain);
 		verify(requestHandler).handle(eq(this.request), eq(this.response));
 		verify(this.filterChain).doFilter(this.request, this.response);
 	}
 
-	@Test
-	public void doFilterWhenRequestResolverThenUsed() throws Exception {
-		given(this.requestMatcher.matches(this.request)).willReturn(true);
-		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
-		CsrfTokenRequestResolver requestResolver = mock(CsrfTokenRequestResolver.class);
-		given(requestResolver.resolveCsrfTokenValue(this.request, this.token)).willReturn(this.token.getToken());
-		this.filter.setRequestResolver(requestResolver);
-		this.filter.doFilter(this.request, this.response, this.filterChain);
-		verify(requestResolver).resolveCsrfTokenValue(this.request, this.token);
-		verify(this.filterChain).doFilter(this.request, this.response);
-	}
-
 	@Test
 	public void setRequireCsrfProtectionMatcherNull() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null));
@@ -373,16 +365,14 @@ public class CsrfFilterTests {
 	@Test
 	public void doFilterWhenCsrfRequestAttributeNameThenNoCsrfTokenMethodInvokedOnGet()
 			throws ServletException, IOException {
-		CsrfFilter filter = createCsrfFilter(this.tokenRepository);
 		String csrfAttrName = "_csrf";
-		CsrfTokenRequestProcessor csrfTokenRequestProcessor = new CsrfTokenRequestProcessor();
-		csrfTokenRequestProcessor.setTokenRepository(this.tokenRepository);
-		csrfTokenRequestProcessor.setCsrfRequestAttributeName(csrfAttrName);
-		filter.setRequestHandler(csrfTokenRequestProcessor);
+		CsrfTokenRepositoryRequestHandler requestHandler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
+		requestHandler.setCsrfRequestAttributeName(csrfAttrName);
+		this.filter = createCsrfFilter(requestHandler);
 		CsrfToken expectedCsrfToken = spy(this.token);
 		given(this.tokenRepository.loadToken(this.request)).willReturn(expectedCsrfToken);
 
-		filter.doFilter(this.request, this.response, this.filterChain);
+		this.filter.doFilter(this.request, this.response, this.filterChain);
 
 		verifyNoInteractions(expectedCsrfToken);
 		CsrfToken tokenFromRequest = (CsrfToken) this.request.getAttribute(csrfAttrName);
@@ -410,6 +400,6 @@ public class CsrfFilterTests {
 			return this.isGenerated;
 		}
 
-	};
+	}
 
 }

+ 36 - 16
web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRequestProcessorTests.java → web/src/test/java/org/springframework/security/web/csrf/CsrfTokenRepositoryRequestHandlerTests.java

@@ -31,13 +31,13 @@ import static org.mockito.BDDMockito.given;
 import static org.springframework.security.web.csrf.CsrfTokenAssert.assertThatCsrfToken;
 
 /**
- * Tests for {@link CsrfTokenRequestProcessor}.
+ * Tests for {@link CsrfTokenRepositoryRequestHandler}.
  *
  * @author Steve Riesenberg
  * @since 5.8
  */
 @ExtendWith(MockitoExtension.class)
-public class CsrfTokenRequestProcessorTests {
+public class CsrfTokenRepositoryRequestHandlerTests {
 
 	@Mock
 	CsrfTokenRepository tokenRepository;
@@ -48,34 +48,48 @@ public class CsrfTokenRequestProcessorTests {
 
 	private CsrfToken token;
 
-	private CsrfTokenRequestProcessor processor;
+	private CsrfTokenRepositoryRequestHandler handler;
 
 	@BeforeEach
 	public void setup() {
 		this.request = new MockHttpServletRequest();
 		this.response = new MockHttpServletResponse();
 		this.token = new DefaultCsrfToken("headerName", "paramName", "csrfTokenValue");
-		this.processor = new CsrfTokenRequestProcessor();
-		this.processor.setTokenRepository(this.tokenRepository);
+		this.handler = new CsrfTokenRepositoryRequestHandler(this.tokenRepository);
+	}
+
+	@Test
+	public void constructorWhenCsrfTokenRepositoryIsNullThenThrowsIllegalArgumentException() {
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> new CsrfTokenRepositoryRequestHandler(null))
+				.withMessage("csrfTokenRepository cannot be null");
+		// @formatter:on
 	}
 
 	@Test
 	public void handleWhenRequestIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(null, this.response))
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.handler.handle(null, this.response))
 				.withMessage("request cannot be null");
+		// @formatter:on
 	}
 
 	@Test
 	public void handleWhenResponseIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.processor.handle(this.request, null))
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.handler.handle(this.request, null))
 				.withMessage("response cannot be null");
+		// @formatter:on
 	}
 
 	@Test
 	public void handleWhenCsrfRequestAttributeSetThenUsed() {
 		given(this.tokenRepository.generateToken(this.request)).willReturn(this.token);
-		this.processor.setCsrfRequestAttributeName("_csrf");
-		this.processor.handle(this.request, this.response);
+		this.handler.setCsrfRequestAttributeName("_csrf");
+		this.handler.handle(this.request, this.response);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute("_csrf")).isEqualTo(this.token);
 	}
@@ -83,40 +97,46 @@ public class CsrfTokenRequestProcessorTests {
 	@Test
 	public void handleWhenValidParametersThenRequestAttributesSet() {
 		given(this.tokenRepository.loadToken(this.request)).willReturn(this.token);
-		this.processor.handle(this.request, this.response);
+		this.handler.handle(this.request, this.response);
 		assertThatCsrfToken(this.request.getAttribute(CsrfToken.class.getName())).isEqualTo(this.token);
 		assertThatCsrfToken(this.request.getAttribute(this.token.getParameterName())).isEqualTo(this.token);
 	}
 
 	@Test
 	public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(null, this.token))
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))
 				.withMessage("request cannot be null");
+		// @formatter:on
 	}
 
 	@Test
 	public void resolveCsrfTokenValueWhenCsrfTokenIsNullThenThrowsIllegalArgumentException() {
-		assertThatIllegalArgumentException().isThrownBy(() -> this.processor.resolveCsrfTokenValue(this.request, null))
+		// @formatter:off
+		assertThatIllegalArgumentException()
+				.isThrownBy(() -> this.handler.resolveCsrfTokenValue(this.request, null))
 				.withMessage("csrfToken cannot be null");
+		// @formatter:on
 	}
 
 	@Test
 	public void resolveCsrfTokenValueWhenTokenNotSetThenReturnsNull() {
-		String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
+		String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
 		assertThat(tokenValue).isNull();
 	}
 
 	@Test
 	public void resolveCsrfTokenValueWhenParameterSetThenReturnsTokenValue() {
 		this.request.setParameter(this.token.getParameterName(), this.token.getToken());
-		String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
+		String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
 		assertThat(tokenValue).isEqualTo(this.token.getToken());
 	}
 
 	@Test
 	public void resolveCsrfTokenValueWhenHeaderSetThenReturnsTokenValue() {
 		this.request.addHeader(this.token.getHeaderName(), this.token.getToken());
-		String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
+		String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
 		assertThat(tokenValue).isEqualTo(this.token.getToken());
 	}
 
@@ -124,7 +144,7 @@ public class CsrfTokenRequestProcessorTests {
 	public void resolveCsrfTokenValueWhenHeaderAndParameterSetThenHeaderIsPreferred() {
 		this.request.addHeader(this.token.getHeaderName(), "header");
 		this.request.setParameter(this.token.getParameterName(), "parameter");
-		String tokenValue = this.processor.resolveCsrfTokenValue(this.request, this.token);
+		String tokenValue = this.handler.resolveCsrfTokenValue(this.request, this.token);
 		assertThat(tokenValue).isEqualTo("header");
 	}