瀏覽代碼

StrictFirewallHttpRequest.buid returns StrictFirewallHttpRequest

Closes gh-16069
Rob Winch 8 月之前
父節點
當前提交
6a0b683e60

+ 64 - 0
web/src/main/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewall.java

@@ -16,6 +16,8 @@
 
 package org.springframework.security.web.server.firewall;
 
+import java.net.InetSocketAddress;
+import java.net.URI;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
@@ -23,6 +25,7 @@ import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.Predicate;
 import java.util.regex.Pattern;
 
@@ -33,6 +36,7 @@ import org.springframework.http.HttpMethod;
 import org.springframework.http.server.reactive.ServerHttpRequest;
 import org.springframework.http.server.reactive.ServerHttpRequestDecorator;
 import org.springframework.http.server.reactive.ServerHttpResponse;
+import org.springframework.http.server.reactive.SslInfo;
 import org.springframework.util.Assert;
 import org.springframework.util.MultiValueMap;
 import org.springframework.web.server.ServerWebExchange;
@@ -743,6 +747,11 @@ public class StrictServerWebExchangeFirewall implements ServerWebExchangeFirewal
 				return queryParams;
 			}
 
+			@Override
+			public Builder mutate() {
+				return new StrictFirewallBuilder(super.mutate());
+			}
+
 			private final class StrictFirewallHttpHeaders extends HttpHeaders {
 
 				private StrictFirewallHttpHeaders(HttpHeaders delegate) {
@@ -783,6 +792,61 @@ public class StrictServerWebExchangeFirewall implements ServerWebExchangeFirewal
 
 			}
 
+			private final class StrictFirewallBuilder implements Builder {
+
+				private final Builder delegate;
+
+				private StrictFirewallBuilder(Builder delegate) {
+					this.delegate = delegate;
+				}
+
+				@Override
+				public Builder method(HttpMethod httpMethod) {
+					return this.delegate.method(httpMethod);
+				}
+
+				@Override
+				public Builder uri(URI uri) {
+					return this.delegate.uri(uri);
+				}
+
+				@Override
+				public Builder path(String path) {
+					return this.delegate.path(path);
+				}
+
+				@Override
+				public Builder contextPath(String contextPath) {
+					return this.delegate.contextPath(contextPath);
+				}
+
+				@Override
+				public Builder header(String headerName, String... headerValues) {
+					return this.delegate.header(headerName, headerValues);
+				}
+
+				@Override
+				public Builder headers(Consumer<HttpHeaders> headersConsumer) {
+					return this.delegate.headers(headersConsumer);
+				}
+
+				@Override
+				public Builder sslInfo(SslInfo sslInfo) {
+					return this.delegate.sslInfo(sslInfo);
+				}
+
+				@Override
+				public Builder remoteAddress(InetSocketAddress remoteAddress) {
+					return this.delegate.remoteAddress(remoteAddress);
+				}
+
+				@Override
+				public ServerHttpRequest build() {
+					return new StrictFirewallHttpRequest(this.delegate.build());
+				}
+
+			}
+
 		}
 
 	}

+ 21 - 0
web/src/test/java/org/springframework/security/web/server/firewall/StrictServerWebExchangeFirewallTests.java

@@ -513,4 +513,25 @@ class StrictServerWebExchangeFirewallTests {
 		assertThat(exchange.getRequest().getHeaders().get(null)).isNull();
 	}
 
+	@Test
+	void getFirewalledExchangeWhenMutateThenHeadersStillFirewalled() {
+		String invalidHeaderName = "bad name";
+		this.firewall.setAllowedHeaderNames((name) -> !name.equals(invalidHeaderName));
+		ServerWebExchange exchange = getFirewalledExchange();
+		ServerWebExchange mutatedExchange = exchange.mutate().request(exchange.getRequest().mutate().build()).build();
+		HttpHeaders headers = mutatedExchange.getRequest().getHeaders();
+		assertThatExceptionOfType(ServerExchangeRejectedException.class)
+			.isThrownBy(() -> headers.get(invalidHeaderName));
+	}
+
+	@Test
+	void getMutatedFirewalledExchangeGetHeaderWhenNotAllowedHeaderNameThenException() {
+		String invalidHeaderName = "bad name";
+		this.firewall.setAllowedHeaderNames((name) -> !name.equals(invalidHeaderName));
+		ServerWebExchange exchange = getFirewalledExchange();
+		HttpHeaders headers = exchange.getRequest().mutate().build().getHeaders();
+		assertThatExceptionOfType(ServerExchangeRejectedException.class)
+			.isThrownBy(() -> headers.get(invalidHeaderName));
+	}
+
 }