|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2002-2021 the original author or authors.
|
|
|
+ * Copyright 2002-2022 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -34,13 +34,17 @@ import org.springframework.test.web.reactive.server.WebTestClient;
|
|
|
import org.springframework.web.bind.annotation.RequestMapping;
|
|
|
import org.springframework.web.bind.annotation.RestController;
|
|
|
import org.springframework.web.reactive.function.BodyInserters;
|
|
|
+import org.springframework.web.server.ServerWebExchange;
|
|
|
import org.springframework.web.server.WebFilterChain;
|
|
|
import org.springframework.web.server.WebSession;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.assertThat;
|
|
|
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
|
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
|
+import static org.mockito.ArgumentMatchers.eq;
|
|
|
import static org.mockito.BDDMockito.given;
|
|
|
import static org.mockito.Mockito.mock;
|
|
|
+import static org.mockito.Mockito.verify;
|
|
|
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
|
|
|
|
|
/**
|
|
@@ -65,6 +69,15 @@ public class CsrfWebFilterTests {
|
|
|
|
|
|
private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/"));
|
|
|
|
|
|
+ @Test
|
|
|
+ public void setRequestHandlerWhenNullThenThrowsIllegalArgumentException() {
|
|
|
+ // @formatter:off
|
|
|
+ assertThatIllegalArgumentException()
|
|
|
+ .isThrownBy(() -> this.csrfFilter.setRequestHandler(null))
|
|
|
+ .withMessage("requestHandler cannot be null");
|
|
|
+ // @formatter:on
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void filterWhenGetThenSessionNotCreatedAndChainContinues() {
|
|
|
PublisherProbe<Void> chainResult = PublisherProbe.empty();
|
|
@@ -145,6 +158,66 @@ public class CsrfWebFilterTests {
|
|
|
chainResult.assertWasSubscribed();
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void filterWhenRequestHandlerSetThenUsed() {
|
|
|
+ ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class);
|
|
|
+ given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class)))
|
|
|
+ .willReturn(Mono.just(this.token.getToken()));
|
|
|
+ this.csrfFilter.setRequestHandler(requestHandler);
|
|
|
+
|
|
|
+ PublisherProbe<Void> chainResult = PublisherProbe.empty();
|
|
|
+ given(this.chain.filter(any())).willReturn(chainResult.mono());
|
|
|
+ this.csrfFilter.setCsrfTokenRepository(this.repository);
|
|
|
+ given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
|
|
|
+ given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
|
|
|
+ this.post = MockServerWebExchange
|
|
|
+ .from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
|
|
|
+ Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
|
|
|
+ StepVerifier.create(result).verifyComplete();
|
|
|
+ chainResult.assertWasSubscribed();
|
|
|
+
|
|
|
+ verify(requestHandler).handle(eq(this.post), any());
|
|
|
+ verify(requestHandler).resolveCsrfTokenValue(this.post, this.token);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void filterWhenXorServerCsrfTokenRequestProcessorAndValidTokenThenSuccess() {
|
|
|
+ PublisherProbe<Void> chainResult = PublisherProbe.empty();
|
|
|
+ given(this.chain.filter(any())).willReturn(chainResult.mono());
|
|
|
+ this.csrfFilter.setCsrfTokenRepository(this.repository);
|
|
|
+ given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
|
|
|
+ given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
|
|
|
+ XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
|
|
|
+ this.csrfFilter.setRequestHandler(requestHandler);
|
|
|
+ StepVerifier.create(this.csrfFilter.filter(this.get, this.chain)).verifyComplete();
|
|
|
+ chainResult.assertWasSubscribed();
|
|
|
+
|
|
|
+ Mono<CsrfToken> csrfTokenAttribute = this.get.getAttribute(CsrfToken.class.getName());
|
|
|
+ assertThat(csrfTokenAttribute).isNotNull();
|
|
|
+ StepVerifier.create(csrfTokenAttribute)
|
|
|
+ .consumeNextWith((csrfToken) -> this.post = MockServerWebExchange
|
|
|
+ .from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken())))
|
|
|
+ .verifyComplete();
|
|
|
+
|
|
|
+ StepVerifier.create(this.csrfFilter.filter(this.post, this.chain)).verifyComplete();
|
|
|
+ chainResult.assertWasSubscribed();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void filterWhenXorServerCsrfTokenRequestProcessorAndRawTokenThenAccessDeniedException() {
|
|
|
+ PublisherProbe<Void> chainResult = PublisherProbe.empty();
|
|
|
+ this.csrfFilter.setCsrfTokenRepository(this.repository);
|
|
|
+ given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
|
|
|
+ XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
|
|
|
+ this.csrfFilter.setRequestHandler(requestHandler);
|
|
|
+ this.post = MockServerWebExchange
|
|
|
+ .from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
|
|
|
+ Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
|
|
|
+ StepVerifier.create(result).verifyComplete();
|
|
|
+ chainResult.assertWasNotSubscribed();
|
|
|
+ assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
// gh-8452
|
|
|
public void matchesRequireCsrfProtectionWhenNonStandardHTTPMethodIsUsed() {
|
|
@@ -180,7 +253,9 @@ public class CsrfWebFilterTests {
|
|
|
@Test
|
|
|
public void filterWhenMultipartFormDataAndEnabledThenGranted() {
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
|
|
- this.csrfFilter.setTokenFromMultipartDataEnabled(true);
|
|
|
+ ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
|
|
|
+ requestHandler.setTokenFromMultipartDataEnabled(true);
|
|
|
+ this.csrfFilter.setRequestHandler(requestHandler);
|
|
|
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
|
|
|
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
|
|
|
WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
|
|
@@ -192,7 +267,9 @@ public class CsrfWebFilterTests {
|
|
|
@Test
|
|
|
public void filterWhenPostAndMultipartFormDataEnabledAndNoBodyProvided() {
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
|
|
- this.csrfFilter.setTokenFromMultipartDataEnabled(true);
|
|
|
+ ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
|
|
|
+ requestHandler.setTokenFromMultipartDataEnabled(true);
|
|
|
+ this.csrfFilter.setRequestHandler(requestHandler);
|
|
|
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
|
|
|
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
|
|
|
WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
|
|
@@ -203,7 +280,9 @@ public class CsrfWebFilterTests {
|
|
|
@Test
|
|
|
public void filterWhenFormDataAndEnabledThenGranted() {
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
|
|
- this.csrfFilter.setTokenFromMultipartDataEnabled(true);
|
|
|
+ ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
|
|
|
+ requestHandler.setTokenFromMultipartDataEnabled(true);
|
|
|
+ this.csrfFilter.setRequestHandler(requestHandler);
|
|
|
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
|
|
|
given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
|
|
|
WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
|
|
@@ -215,7 +294,9 @@ public class CsrfWebFilterTests {
|
|
|
@Test
|
|
|
public void filterWhenMultipartMixedAndEnabledThenNotRead() {
|
|
|
this.csrfFilter.setCsrfTokenRepository(this.repository);
|
|
|
- this.csrfFilter.setTokenFromMultipartDataEnabled(true);
|
|
|
+ ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
|
|
|
+ requestHandler.setTokenFromMultipartDataEnabled(true);
|
|
|
+ this.csrfFilter.setRequestHandler(requestHandler);
|
|
|
given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
|
|
|
WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
|
|
|
client.post().uri("/").contentType(MediaType.MULTIPART_MIXED)
|