|  | @@ -1,5 +1,5 @@
 | 
	
		
			
				|  |  |  /*
 | 
	
		
			
				|  |  | - * Copyright 2002-2021 the original author or authors.
 | 
	
		
			
				|  |  | + * Copyright 2002-2022 the original author or authors.
 | 
	
		
			
				|  |  |   *
 | 
	
		
			
				|  |  |   * Licensed under the Apache License, Version 2.0 (the "License");
 | 
	
		
			
				|  |  |   * you may not use this file except in compliance with the License.
 | 
	
	
		
			
				|  | @@ -34,13 +34,17 @@ import org.springframework.test.web.reactive.server.WebTestClient;
 | 
	
		
			
				|  |  |  import org.springframework.web.bind.annotation.RequestMapping;
 | 
	
		
			
				|  |  |  import org.springframework.web.bind.annotation.RestController;
 | 
	
		
			
				|  |  |  import org.springframework.web.reactive.function.BodyInserters;
 | 
	
		
			
				|  |  | +import org.springframework.web.server.ServerWebExchange;
 | 
	
		
			
				|  |  |  import org.springframework.web.server.WebFilterChain;
 | 
	
		
			
				|  |  |  import org.springframework.web.server.WebSession;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.assertj.core.api.Assertions.assertThat;
 | 
	
		
			
				|  |  | +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 | 
	
		
			
				|  |  |  import static org.mockito.ArgumentMatchers.any;
 | 
	
		
			
				|  |  | +import static org.mockito.ArgumentMatchers.eq;
 | 
	
		
			
				|  |  |  import static org.mockito.BDDMockito.given;
 | 
	
		
			
				|  |  |  import static org.mockito.Mockito.mock;
 | 
	
		
			
				|  |  | +import static org.mockito.Mockito.verify;
 | 
	
		
			
				|  |  |  import static org.mockito.Mockito.verifyNoMoreInteractions;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  /**
 | 
	
	
		
			
				|  | @@ -65,6 +69,15 @@ public class CsrfWebFilterTests {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/"));
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void setRequestHandlerWhenNullThenThrowsIllegalArgumentException() {
 | 
	
		
			
				|  |  | +		// @formatter:off
 | 
	
		
			
				|  |  | +		assertThatIllegalArgumentException()
 | 
	
		
			
				|  |  | +				.isThrownBy(() -> this.csrfFilter.setRequestHandler(null))
 | 
	
		
			
				|  |  | +				.withMessage("requestHandler cannot be null");
 | 
	
		
			
				|  |  | +		// @formatter:on
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	public void filterWhenGetThenSessionNotCreatedAndChainContinues() {
 | 
	
		
			
				|  |  |  		PublisherProbe<Void> chainResult = PublisherProbe.empty();
 | 
	
	
		
			
				|  | @@ -145,6 +158,66 @@ public class CsrfWebFilterTests {
 | 
	
		
			
				|  |  |  		chainResult.assertWasSubscribed();
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void filterWhenRequestHandlerSetThenUsed() {
 | 
	
		
			
				|  |  | +		ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class);
 | 
	
		
			
				|  |  | +		given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class)))
 | 
	
		
			
				|  |  | +				.willReturn(Mono.just(this.token.getToken()));
 | 
	
		
			
				|  |  | +		this.csrfFilter.setRequestHandler(requestHandler);
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		PublisherProbe<Void> chainResult = PublisherProbe.empty();
 | 
	
		
			
				|  |  | +		given(this.chain.filter(any())).willReturn(chainResult.mono());
 | 
	
		
			
				|  |  | +		this.csrfFilter.setCsrfTokenRepository(this.repository);
 | 
	
		
			
				|  |  | +		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  | +		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  | +		this.post = MockServerWebExchange
 | 
	
		
			
				|  |  | +				.from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
 | 
	
		
			
				|  |  | +		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
 | 
	
		
			
				|  |  | +		StepVerifier.create(result).verifyComplete();
 | 
	
		
			
				|  |  | +		chainResult.assertWasSubscribed();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		verify(requestHandler).handle(eq(this.post), any());
 | 
	
		
			
				|  |  | +		verify(requestHandler).resolveCsrfTokenValue(this.post, this.token);
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void filterWhenXorServerCsrfTokenRequestProcessorAndValidTokenThenSuccess() {
 | 
	
		
			
				|  |  | +		PublisherProbe<Void> chainResult = PublisherProbe.empty();
 | 
	
		
			
				|  |  | +		given(this.chain.filter(any())).willReturn(chainResult.mono());
 | 
	
		
			
				|  |  | +		this.csrfFilter.setCsrfTokenRepository(this.repository);
 | 
	
		
			
				|  |  | +		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  | +		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  | +		XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
 | 
	
		
			
				|  |  | +		this.csrfFilter.setRequestHandler(requestHandler);
 | 
	
		
			
				|  |  | +		StepVerifier.create(this.csrfFilter.filter(this.get, this.chain)).verifyComplete();
 | 
	
		
			
				|  |  | +		chainResult.assertWasSubscribed();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		Mono<CsrfToken> csrfTokenAttribute = this.get.getAttribute(CsrfToken.class.getName());
 | 
	
		
			
				|  |  | +		assertThat(csrfTokenAttribute).isNotNull();
 | 
	
		
			
				|  |  | +		StepVerifier.create(csrfTokenAttribute)
 | 
	
		
			
				|  |  | +				.consumeNextWith((csrfToken) -> this.post = MockServerWebExchange
 | 
	
		
			
				|  |  | +						.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken())))
 | 
	
		
			
				|  |  | +				.verifyComplete();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		StepVerifier.create(this.csrfFilter.filter(this.post, this.chain)).verifyComplete();
 | 
	
		
			
				|  |  | +		chainResult.assertWasSubscribed();
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void filterWhenXorServerCsrfTokenRequestProcessorAndRawTokenThenAccessDeniedException() {
 | 
	
		
			
				|  |  | +		PublisherProbe<Void> chainResult = PublisherProbe.empty();
 | 
	
		
			
				|  |  | +		this.csrfFilter.setCsrfTokenRepository(this.repository);
 | 
	
		
			
				|  |  | +		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  | +		XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler();
 | 
	
		
			
				|  |  | +		this.csrfFilter.setRequestHandler(requestHandler);
 | 
	
		
			
				|  |  | +		this.post = MockServerWebExchange
 | 
	
		
			
				|  |  | +				.from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken()));
 | 
	
		
			
				|  |  | +		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
 | 
	
		
			
				|  |  | +		StepVerifier.create(result).verifyComplete();
 | 
	
		
			
				|  |  | +		chainResult.assertWasNotSubscribed();
 | 
	
		
			
				|  |  | +		assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	// gh-8452
 | 
	
		
			
				|  |  |  	public void matchesRequireCsrfProtectionWhenNonStandardHTTPMethodIsUsed() {
 | 
	
	
		
			
				|  | @@ -180,7 +253,9 @@ public class CsrfWebFilterTests {
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	public void filterWhenMultipartFormDataAndEnabledThenGranted() {
 | 
	
		
			
				|  |  |  		this.csrfFilter.setCsrfTokenRepository(this.repository);
 | 
	
		
			
				|  |  | -		this.csrfFilter.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
 | 
	
		
			
				|  |  | +		requestHandler.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		this.csrfFilter.setRequestHandler(requestHandler);
 | 
	
		
			
				|  |  |  		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  |  		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  |  		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
 | 
	
	
		
			
				|  | @@ -192,7 +267,9 @@ public class CsrfWebFilterTests {
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	public void filterWhenPostAndMultipartFormDataEnabledAndNoBodyProvided() {
 | 
	
		
			
				|  |  |  		this.csrfFilter.setCsrfTokenRepository(this.repository);
 | 
	
		
			
				|  |  | -		this.csrfFilter.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
 | 
	
		
			
				|  |  | +		requestHandler.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		this.csrfFilter.setRequestHandler(requestHandler);
 | 
	
		
			
				|  |  |  		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  |  		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  |  		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
 | 
	
	
		
			
				|  | @@ -203,7 +280,9 @@ public class CsrfWebFilterTests {
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	public void filterWhenFormDataAndEnabledThenGranted() {
 | 
	
		
			
				|  |  |  		this.csrfFilter.setCsrfTokenRepository(this.repository);
 | 
	
		
			
				|  |  | -		this.csrfFilter.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
 | 
	
		
			
				|  |  | +		requestHandler.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		this.csrfFilter.setRequestHandler(requestHandler);
 | 
	
		
			
				|  |  |  		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  |  		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  |  		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
 | 
	
	
		
			
				|  | @@ -215,7 +294,9 @@ public class CsrfWebFilterTests {
 | 
	
		
			
				|  |  |  	@Test
 | 
	
		
			
				|  |  |  	public void filterWhenMultipartMixedAndEnabledThenNotRead() {
 | 
	
		
			
				|  |  |  		this.csrfFilter.setCsrfTokenRepository(this.repository);
 | 
	
		
			
				|  |  | -		this.csrfFilter.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
 | 
	
		
			
				|  |  | +		requestHandler.setTokenFromMultipartDataEnabled(true);
 | 
	
		
			
				|  |  | +		this.csrfFilter.setRequestHandler(requestHandler);
 | 
	
		
			
				|  |  |  		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token));
 | 
	
		
			
				|  |  |  		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
 | 
	
		
			
				|  |  |  		client.post().uri("/").contentType(MediaType.MULTIPART_MIXED)
 |