| 
					
				 | 
			
			
				@@ -1,5 +1,5 @@ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 /* 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				- * Copyright 2002-2021 the original author or authors. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ * Copyright 2002-2022 the original author or authors. 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  * 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  * Licensed under the Apache License, Version 2.0 (the "License"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  * you may not use this file except in compliance with the License. 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -34,13 +34,17 @@ import org.springframework.test.web.reactive.server.WebTestClient; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.springframework.web.bind.annotation.RequestMapping; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.springframework.web.bind.annotation.RestController; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.springframework.web.reactive.function.BodyInserters; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import org.springframework.web.server.ServerWebExchange; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.springframework.web.server.WebFilterChain; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import org.springframework.web.server.WebSession; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import static org.assertj.core.api.Assertions.assertThat; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import static org.mockito.ArgumentMatchers.any; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import static org.mockito.ArgumentMatchers.eq; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import static org.mockito.BDDMockito.given; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import static org.mockito.Mockito.mock; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+import static org.mockito.Mockito.verify; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 import static org.mockito.Mockito.verifyNoMoreInteractions; 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 /** 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -65,6 +69,15 @@ public class CsrfWebFilterTests { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/")); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	public void setRequestHandlerWhenNullThenThrowsIllegalArgumentException() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		// @formatter:off 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		assertThatIllegalArgumentException() 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+				.isThrownBy(() -> this.csrfFilter.setRequestHandler(null)) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+				.withMessage("requestHandler cannot be null"); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		// @formatter:on 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	public void filterWhenGetThenSessionNotCreatedAndChainContinues() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		PublisherProbe<Void> chainResult = PublisherProbe.empty(); 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -145,6 +158,66 @@ public class CsrfWebFilterTests { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		chainResult.assertWasSubscribed(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				  
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	public void filterWhenRequestHandlerSetThenUsed() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		ServerCsrfTokenRequestHandler requestHandler = mock(ServerCsrfTokenRequestHandler.class); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(requestHandler.resolveCsrfTokenValue(any(ServerWebExchange.class), any(CsrfToken.class))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+				.willReturn(Mono.just(this.token.getToken())); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setRequestHandler(requestHandler); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		PublisherProbe<Void> chainResult = PublisherProbe.empty(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(this.chain.filter(any())).willReturn(chainResult.mono()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setCsrfTokenRepository(this.repository); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.post = MockServerWebExchange 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+				.from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken())); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		StepVerifier.create(result).verifyComplete(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		chainResult.assertWasSubscribed(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		verify(requestHandler).handle(eq(this.post), any()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		verify(requestHandler).resolveCsrfTokenValue(this.post, this.token); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	public void filterWhenXorServerCsrfTokenRequestProcessorAndValidTokenThenSuccess() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		PublisherProbe<Void> chainResult = PublisherProbe.empty(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(this.chain.filter(any())).willReturn(chainResult.mono()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setCsrfTokenRepository(this.repository); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setRequestHandler(requestHandler); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		StepVerifier.create(this.csrfFilter.filter(this.get, this.chain)).verifyComplete(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		chainResult.assertWasSubscribed(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		Mono<CsrfToken> csrfTokenAttribute = this.get.getAttribute(CsrfToken.class.getName()); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		assertThat(csrfTokenAttribute).isNotNull(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		StepVerifier.create(csrfTokenAttribute) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+				.consumeNextWith((csrfToken) -> this.post = MockServerWebExchange 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+						.from(MockServerHttpRequest.post("/").header(csrfToken.getHeaderName(), csrfToken.getToken()))) 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+				.verifyComplete(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		StepVerifier.create(this.csrfFilter.filter(this.post, this.chain)).verifyComplete(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		chainResult.assertWasSubscribed(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	public void filterWhenXorServerCsrfTokenRequestProcessorAndRawTokenThenAccessDeniedException() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		PublisherProbe<Void> chainResult = PublisherProbe.empty(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setCsrfTokenRepository(this.repository); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		XorServerCsrfTokenRequestAttributeHandler requestHandler = new XorServerCsrfTokenRequestAttributeHandler(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setRequestHandler(requestHandler); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.post = MockServerWebExchange 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+				.from(MockServerHttpRequest.post("/").header(this.token.getHeaderName(), this.token.getToken())); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		StepVerifier.create(result).verifyComplete(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		chainResult.assertWasNotSubscribed(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+	} 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+ 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	// gh-8452 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	public void matchesRequireCsrfProtectionWhenNonStandardHTTPMethodIsUsed() { 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -180,7 +253,9 @@ public class CsrfWebFilterTests { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	public void filterWhenMultipartFormDataAndEnabledThenGranted() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		this.csrfFilter.setCsrfTokenRepository(this.repository); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-		this.csrfFilter.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		requestHandler.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setRequestHandler(requestHandler); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build(); 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -192,7 +267,9 @@ public class CsrfWebFilterTests { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	public void filterWhenPostAndMultipartFormDataEnabledAndNoBodyProvided() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		this.csrfFilter.setCsrfTokenRepository(this.repository); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-		this.csrfFilter.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		requestHandler.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setRequestHandler(requestHandler); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build(); 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -203,7 +280,9 @@ public class CsrfWebFilterTests { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	public void filterWhenFormDataAndEnabledThenGranted() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		this.csrfFilter.setCsrfTokenRepository(this.repository); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-		this.csrfFilter.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		requestHandler.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setRequestHandler(requestHandler); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		given(this.repository.generateToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build(); 
			 | 
		
	
	
		
			
				| 
					
				 | 
			
			
				@@ -215,7 +294,9 @@ public class CsrfWebFilterTests { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	@Test 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 	public void filterWhenMultipartMixedAndEnabledThenNotRead() { 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		this.csrfFilter.setCsrfTokenRepository(this.repository); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				-		this.csrfFilter.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		ServerCsrfTokenRequestAttributeHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		requestHandler.setTokenFromMultipartDataEnabled(true); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				+		this.csrfFilter.setRequestHandler(requestHandler); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		given(this.repository.loadToken(any())).willReturn(Mono.just(this.token)); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build(); 
			 | 
		
	
		
			
				 | 
				 | 
			
			
				 		client.post().uri("/").contentType(MediaType.MULTIPART_MIXED) 
			 |