소스 검색

CsrfWebFilter supports multipart/form-data

Fixes gh-7576
Rob Winch 5 년 전
부모
커밋
635f7e1edd

+ 13 - 0
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -2731,6 +2731,19 @@ public class ServerHttpSecurity {
 			return this;
 		}
 
+		/**
+		 * Specifies if {@link CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
+		 * data requests.
+		 *
+		 * @param enabled true if should read from multipart form body, else false. Default is false
+		 * @return the {@link CsrfSpec} for additional configuration
+		 */
+		public CsrfSpec tokenFromMultipartDataEnabled(boolean enabled) {
+			this.filter.setTokenFromMultipartDataEnabled(enabled);
+			return this;
+		}
+
+
 		/**
 		 * Allows method chaining to continue configuring the {@link ServerHttpSecurity}
 		 * @return the {@link ServerHttpSecurity} to continue configuring

+ 1 - 0
gradle/dependency-management.gradle

@@ -210,6 +210,7 @@ dependencyManagement {
 		dependency 'org.slf4j:slf4j-nop:1.7.28'
 		dependency 'org.sonatype.sisu.inject:cglib:2.2.1-v20090111'
 		dependency 'org.springframework.ldap:spring-ldap-core:2.3.2.RELEASE'
+		dependency 'org.synchronoss.cloud:nio-multipart-parser:1.1.0'
 		dependency 'org.thymeleaf:thymeleaf-spring5:3.0.11.RELEASE'
 		dependency 'org.unbescape:unbescape:1.1.5.RELEASE'
 		dependency 'org.w3c.css:sac:1.3'

+ 1 - 0
web/spring-security-web.gradle

@@ -25,6 +25,7 @@ dependencies {
 	testCompile 'org.codehaus.groovy:groovy-all'
 	testCompile 'org.skyscreamer:jsonassert'
 	testCompile 'org.springframework:spring-webflux'
+	testCompile 'org.synchronoss.cloud:nio-multipart-parser'
 	testCompile powerMock2Dependencies
 	testCompile spockDependencies
 

+ 37 - 6
web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

@@ -16,14 +16,12 @@
 
 package org.springframework.security.web.server.csrf;
 
-import java.util.Arrays;
-import java.util.HashSet;
-import java.util.Set;
-
-import reactor.core.publisher.Mono;
-
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.http.codec.multipart.FormFieldPart;
+import org.springframework.http.server.reactive.ServerHttpRequest;
 import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
 import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
@@ -31,6 +29,11 @@ import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
 
 import static java.lang.Boolean.TRUE;
 
@@ -78,6 +81,8 @@ public class CsrfWebFilter implements WebFilter {
 
 	private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
 
+	private boolean isTokenFromMultipartDataEnabled;
+
 	public void setAccessDeniedHandler(
 		ServerAccessDeniedHandler accessDeniedHandler) {
 		Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
@@ -96,6 +101,15 @@ public class CsrfWebFilter implements WebFilter {
 		this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
 	}
 
+	/**
+	 * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
+	 * data requests.
+	 * @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false
+	 */
+	public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
+		this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
+	}
+
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 		if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
@@ -128,9 +142,26 @@ public class CsrfWebFilter implements WebFilter {
 		return exchange.getFormData()
 			.flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
 			.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
+			.switchIfEmpty(tokenFromMultipartData(exchange, expected))
 			.map(actual -> actual.equals(expected.getToken()));
 	}
 
+	private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
+		if (!this.isTokenFromMultipartDataEnabled) {
+			return Mono.empty();
+		}
+		ServerHttpRequest request = exchange.getRequest();
+		HttpHeaders headers = request.getHeaders();
+		MediaType contentType = headers.getContentType();
+		if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
+			return Mono.empty();
+		}
+		return exchange.getMultipartData()
+			.map(d -> d.getFirst(expected.getParameterName()))
+			.cast(FormFieldPart.class)
+			.map(FormFieldPart::value);
+	}
+
 	private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
 		return Mono.defer(() ->{
 			Mono<CsrfToken> csrfToken = csrfToken(exchange);

+ 96 - 5
web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

@@ -20,17 +20,20 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
-import reactor.core.publisher.Mono;
-import reactor.test.StepVerifier;
-import reactor.test.publisher.PublisherProbe;
-
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
+import org.springframework.test.web.reactive.server.WebTestClient;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilterChain;
 import org.springframework.web.server.WebSession;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.test.publisher.PublisherProbe;
 
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
 import static org.mockito.ArgumentMatchers.any;
@@ -38,6 +41,7 @@ import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
 import static org.springframework.mock.web.server.MockServerWebExchange.from;
+import static org.springframework.web.reactive.function.BodyInserters.fromMultipartData;
 
 /**
  * @author Rob Winch
@@ -57,7 +61,7 @@ public class CsrfWebFilterTests {
 	private MockServerWebExchange get = from(
 		MockServerHttpRequest.get("/"));
 
-	private MockServerWebExchange post = from(
+	private ServerWebExchange post = from(
 		MockServerHttpRequest.post("/"));
 
 	@Test
@@ -193,4 +197,91 @@ public class CsrfWebFilterTests {
 
 		verifyZeroInteractions(matcher);
 	}
+
+	@Test
+	public void filterWhenMultipartFormDataAndNotEnabledThenDenied() {
+		this.csrfFilter.setCsrfTokenRepository(this.repository);
+		when(this.repository.loadToken(any()))
+				.thenReturn(Mono.just(this.token));
+
+		WebTestClient client = WebTestClient.bindToController(new OkController())
+				.webFilter(this.csrfFilter)
+				.build();
+
+		client.post()
+				.uri("/")
+				.contentType(MediaType.MULTIPART_FORM_DATA)
+				.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
+				.exchange()
+				.expectStatus().isForbidden();
+	}
+
+	@Test
+	public void filterWhenMultipartFormDataAndEnabledThenGranted() {
+		this.csrfFilter.setCsrfTokenRepository(this.repository);
+		this.csrfFilter.setTokenFromMultipartDataEnabled(true);
+		when(this.repository.loadToken(any()))
+				.thenReturn(Mono.just(this.token));
+		when(this.repository.generateToken(any()))
+				.thenReturn(Mono.just(this.token));
+
+		WebTestClient client = WebTestClient.bindToController(new OkController())
+			.webFilter(this.csrfFilter)
+			.build();
+
+		client.post()
+			.uri("/")
+			.contentType(MediaType.MULTIPART_FORM_DATA)
+			.body(fromMultipartData(this.token.getParameterName(), this.token.getToken()))
+			.exchange()
+				.expectStatus().is2xxSuccessful();
+	}
+
+	@Test
+	public void filterWhenFormDataAndEnabledThenGranted() {
+		this.csrfFilter.setCsrfTokenRepository(this.repository);
+		this.csrfFilter.setTokenFromMultipartDataEnabled(true);
+		when(this.repository.loadToken(any()))
+				.thenReturn(Mono.just(this.token));
+		when(this.repository.generateToken(any()))
+				.thenReturn(Mono.just(this.token));
+
+		WebTestClient client = WebTestClient.bindToController(new OkController())
+				.webFilter(this.csrfFilter)
+				.build();
+
+		client.post()
+				.uri("/")
+				.contentType(MediaType.APPLICATION_FORM_URLENCODED)
+				.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
+				.exchange()
+				.expectStatus().is2xxSuccessful();
+	}
+
+	@Test
+	public void filterWhenMultipartMixedAndEnabledThenNotRead() {
+		this.csrfFilter.setCsrfTokenRepository(this.repository);
+		this.csrfFilter.setTokenFromMultipartDataEnabled(true);
+		when(this.repository.loadToken(any()))
+				.thenReturn(Mono.just(this.token));
+
+		WebTestClient client = WebTestClient.bindToController(new OkController())
+				.webFilter(this.csrfFilter)
+				.build();
+
+		client.post()
+				.uri("/")
+				.contentType(MediaType.MULTIPART_MIXED)
+				.bodyValue(this.token.getParameterName() + "="+this.token.getToken())
+				.exchange()
+				.expectStatus().isForbidden();
+	}
+
+	@RestController
+	static class OkController {
+		@RequestMapping("/**")
+		String ok() {
+			return "ok";
+		}
+	}
 }