Browse Source

Default to Xor CSRF tokens in CsrfWebFilter

Closes gh-11960
Steve Riesenberg 2 years ago
parent
commit
2407d07890

+ 13 - 1
config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt

@@ -27,6 +27,8 @@ import org.springframework.context.annotation.Bean
 import org.springframework.context.annotation.Configuration
 import org.springframework.http.HttpStatus
 import org.springframework.http.MediaType
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest
+import org.springframework.mock.web.server.MockServerWebExchange
 import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity
 import org.springframework.security.config.test.SpringTestContext
 import org.springframework.security.config.test.SpringTestContextExtension
@@ -39,6 +41,7 @@ import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestAttributeHandler
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRequestHandler
 import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository
+import org.springframework.security.web.server.csrf.XorServerCsrfTokenRequestAttributeHandler
 import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.bind.annotation.PostMapping
@@ -278,14 +281,23 @@ class ServerCsrfDslTests {
             MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any())
         } returns Mono.just(this.token)
 
+        val csrfToken = createXorCsrfToken()
         this.client.post()
                 .uri("/")
                 .contentType(MediaType.MULTIPART_FORM_DATA)
-                .body(fromMultipartData(this.token.parameterName, this.token.token))
+                .body(fromMultipartData(csrfToken.parameterName, csrfToken.token))
                 .exchange()
                 .expectStatus().isOk
     }
 
+    private fun createXorCsrfToken(): CsrfToken {
+        val handler = XorServerCsrfTokenRequestAttributeHandler()
+        val exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"))
+        handler.handle(exchange, Mono.just(this.token))
+        val deferredCsrfToken: Mono<CsrfToken>? = exchange.getAttribute(CsrfToken::class.java.name)
+        return deferredCsrfToken?.block()!!
+    }
+
     @Configuration
     @EnableWebFluxSecurity
     @EnableWebFlux

+ 1 - 1
web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

@@ -83,7 +83,7 @@ public class CsrfWebFilter implements WebFilter {
 	private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
 			HttpStatus.FORBIDDEN);
 
-	private ServerCsrfTokenRequestHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
+	private ServerCsrfTokenRequestHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
 
 	public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
 		Assert.notNull(accessDeniedHandler, "accessDeniedHandler");

+ 19 - 16
web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

@@ -125,9 +125,10 @@ public class CsrfWebFilterTests {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
+		CsrfToken csrfToken = createXorCsrfToken();
 		this.post = MockServerWebExchange
 				.from(MockServerHttpRequest.post("/").contentType(MediaType.APPLICATION_FORM_URLENCODED)
-						.body(this.token.getParameterName() + "=" + this.token.getToken()));
+						.body(csrfToken.getParameterName() + "=" + csrfToken.getToken()));
 		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
 		StepVerifier.create(result).verifyComplete();
 		chainResult.assertWasSubscribed();
@@ -151,8 +152,9 @@ public class CsrfWebFilterTests {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
+		CsrfToken csrfToken = createXorCsrfToken();
 		this.post = MockServerWebExchange
-				.from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
+				.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken()));
 		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
 		StepVerifier.create(result).verifyComplete();
 		chainResult.assertWasSubscribed();
@@ -181,30 +183,22 @@ public class CsrfWebFilterTests {
 	}
 
 	@Test
-	public void filterWhenXorServerCsrfTokenRequestProcessorAndValidTokenThenSuccess() {
+	public void filterWhenXorServerCsrfTokenRequestAttributeHandlerAndValidTokenThenSuccess() {
 		PublisherProbe<Void> chainResult = PublisherProbe.empty();
 		given(this.chain.filter(any())).willReturn(chainResult.mono());
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
-		XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
-		this.csrfFilter.setRequestHandler(requestHandler);
-		StepVerifier.create(this.csrfFilter.filter(this.get, this.chain)).verifyComplete();
-		chainResult.assertWasSubscribed();
-
-		Mono<CsrfToken> csrfTokenAttribute = this.get.getAttribute(CsrfToken.class.getName());
-		assertThat(csrfTokenAttribute).isNotNull();
-		StepVerifier.create(csrfTokenAttribute)
-				.consumeNextWith((csrfToken) -> this.post = MockServerWebExchange
-						.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken())))
-				.verifyComplete();
 
+		CsrfToken csrfToken = createXorCsrfToken();
+		this.post = MockServerWebExchange
+				.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken()));
 		StepVerifier.create(this.csrfFilter.filter(this.post, this.chain)).verifyComplete();
 		chainResult.assertWasSubscribed();
 	}
 
 	@Test
-	public void filterWhenXorServerCsrfTokenRequestProcessorAndRawTokenThenAccessDeniedException() {
+	public void filterWhenXorServerCsrfTokenRequestAttributeHandlerAndRawTokenThenAccessDeniedException() {
 		PublisherProbe<Void> chainResult = PublisherProbe.empty();
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
@@ -305,6 +299,7 @@ public class CsrfWebFilterTests {
 	}
 
 	// gh-9561
+
 	@Test
 	public void doFilterWhenTokenIsNullThenNoNullPointer() {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
@@ -318,8 +313,8 @@ public class CsrfWebFilterTests {
 				.bodyValue(this.token.getParameterName() + "=" + this.token.getToken()).exchange().expectStatus()
 				.isForbidden();
 	}
-
 	// gh-9113
+
 	@Test
 	public void filterWhenSubscribingCsrfTokenMultipleTimesThenGenerateOnlyOnce() {
 		PublisherProbe<CsrfToken> chainResult = PublisherProbe.empty();
@@ -334,6 +329,14 @@ public class CsrfWebFilterTests {
 		assertThat(chainResult.subscribeCount()).isEqualTo(1);
 	}
 
+	private CsrfToken createXorCsrfToken() {
+		ServerCsrfTokenRequestHandler handler = new XorServerCsrfTokenRequestAttributeHandler();
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
+		handler.handle(exchange, Mono.just(this.token));
+		Mono<CsrfToken> csrfToken = exchange.getAttribute(CsrfToken.class.getName());
+		return csrfToken.block();
+	}
+
 	@RestController
 	static class OkController {