Sfoglia il codice sorgente

Support ServerWebExchangeFirewall @Bean

Closes gh-15987
Rob Winch 10 mesi fa
parent
commit
1ba6301afa

+ 4 - 1
config/src/main/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfiguration.java

@@ -21,6 +21,7 @@ import java.util.List;
 
 import io.micrometer.observation.ObservationRegistry;
 
+import org.springframework.beans.factory.ObjectProvider;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
 import org.springframework.context.ApplicationContext;
@@ -33,6 +34,7 @@ import org.springframework.security.web.reactive.result.view.CsrfRequestDataValu
 import org.springframework.security.web.server.ObservationWebFilterChainDecorator;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
+import org.springframework.security.web.server.firewall.ServerWebExchangeFirewall;
 import org.springframework.util.ClassUtils;
 import org.springframework.util.ObjectUtils;
 import org.springframework.web.reactive.result.view.AbstractView;
@@ -79,11 +81,12 @@ class WebFluxSecurityConfiguration {
 
 	@Bean(SPRING_SECURITY_WEBFILTERCHAINFILTER_BEAN_NAME)
 	@Order(WEB_FILTER_CHAIN_FILTER_ORDER)
-	WebFilterChainProxy springSecurityWebFilterChainFilter() {
+	WebFilterChainProxy springSecurityWebFilterChainFilter(ObjectProvider<ServerWebExchangeFirewall> firewall) {
 		WebFilterChainProxy proxy = new WebFilterChainProxy(getSecurityWebFilterChains());
 		if (!this.observationRegistry.isNoop()) {
 			proxy.setFilterChainDecorator(new ObservationWebFilterChainDecorator(this.observationRegistry));
 		}
+		firewall.ifUnique(proxy::setFirewall);
 		return proxy;
 	}
 

+ 50 - 0
config/src/test/java/org/springframework/security/config/annotation/web/reactive/WebFluxSecurityConfigurationTests.java

@@ -16,14 +16,24 @@
 
 package org.springframework.security.config.annotation.web.reactive;
 
+import java.util.Collections;
+
+import org.jetbrains.annotations.NotNull;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
+import reactor.core.publisher.Mono;
 
+import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.http.HttpStatus;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration;
 import org.springframework.security.web.server.WebFilterChainProxy;
+import org.springframework.security.web.server.firewall.ServerWebExchangeFirewall;
+import org.springframework.web.server.handler.DefaultWebFilterChain;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -47,6 +57,32 @@ public class WebFluxSecurityConfigurationTests {
 		assertThat(webFilterChainProxy).isNotNull();
 	}
 
+	@Test
+	void loadConfigWhenDefaultThenFirewalled() throws Exception {
+		this.spring
+			.register(ServerHttpSecurityConfiguration.class, ReactiveAuthenticationTestConfiguration.class,
+					WebFluxSecurityConfiguration.class)
+			.autowire();
+		WebFilterChainProxy webFilterChainProxy = this.spring.getContext().getBean(WebFilterChainProxy.class);
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/;/").build());
+		DefaultWebFilterChain chain = emptyChain();
+		webFilterChainProxy.filter(exchange, chain).block();
+		assertThat(exchange.getResponse().getStatusCode()).isEqualTo(HttpStatus.BAD_REQUEST);
+	}
+
+	@Test
+	void loadConfigWhenFirewallBeanThenCustomized() throws Exception {
+		this.spring
+			.register(ServerHttpSecurityConfiguration.class, ReactiveAuthenticationTestConfiguration.class,
+					WebFluxSecurityConfiguration.class, NoOpFirewallConfig.class)
+			.autowire();
+		WebFilterChainProxy webFilterChainProxy = this.spring.getContext().getBean(WebFilterChainProxy.class);
+		MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/;/").build());
+		DefaultWebFilterChain chain = emptyChain();
+		webFilterChainProxy.filter(exchange, chain).block();
+		assertThat(exchange.getResponse().getStatusCode()).isNotEqualTo(HttpStatus.BAD_REQUEST);
+	}
+
 	@Test
 	public void loadConfigWhenBeanProxyingEnabledAndSubclassThenWebFilterChainProxyExists() {
 		this.spring
@@ -57,6 +93,20 @@ public class WebFluxSecurityConfigurationTests {
 		assertThat(webFilterChainProxy).isNotNull();
 	}
 
+	private static @NotNull DefaultWebFilterChain emptyChain() {
+		return new DefaultWebFilterChain((webExchange) -> Mono.empty(), Collections.emptyList());
+	}
+
+	@Configuration
+	static class NoOpFirewallConfig {
+
+		@Bean
+		ServerWebExchangeFirewall noOpFirewall() {
+			return ServerWebExchangeFirewall.INSECURE_NOOP;
+		}
+
+	}
+
 	@Configuration
 	static class SubclassConfig extends WebFluxSecurityConfiguration {