浏览代码

Add support for allowedHostnames in StrictHttpFirewall

Introduce a new method `setAllowedHostnames` which perform the validation
against untrusted hostnames.

Fixes gh-4310
Eddú Meléndez 6 年之前
父节点
当前提交
3e5b65f647

+ 6 - 1
web/src/main/java/org/springframework/security/web/FilterInvocation.java

@@ -228,10 +228,15 @@ class DummyRequest extends HttpServletRequestWrapper {
 	public void setQueryString(String queryString) {
 		this.queryString = queryString;
 	}
+
+	@Override
+	public String getServerName() {
+		return null;
+	}
 }
 
 final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {
 	public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
 		throw new UnsupportedOperationException(method + " is not supported");
 	}
-}
+}

+ 24 - 1
web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2012-2017 the original author or authors.
+ * Copyright 2012-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import java.util.Collections;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
+import java.util.function.Predicate;
 
 /**
  * <p>
@@ -59,10 +60,15 @@ import java.util.Set;
  * Rejects URLs that contain a URL encoded percent. See
  * {@link #setAllowUrlEncodedPercent(boolean)}
  * </li>
+ * <li>
+ * Rejects hosts that are not allowed. See
+ * {@link #setAllowedHostnames(Predicate)}
+ * </li>
  * </ul>
  *
  * @see DefaultHttpFirewall
  * @author Rob Winch
+ * @author Eddú Meléndez
  * @since 5.0.1
  */
 public class StrictHttpFirewall implements HttpFirewall {
@@ -82,6 +88,8 @@ public class StrictHttpFirewall implements HttpFirewall {
 
 	private Set<String> decodedUrlBlacklist = new HashSet<String>();
 
+	private Predicate<String> allowedHostnames = hostname -> true;
+
 	public StrictHttpFirewall() {
 		urlBlacklistsAddAll(FORBIDDEN_SEMICOLON);
 		urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH);
@@ -230,6 +238,13 @@ public class StrictHttpFirewall implements HttpFirewall {
 		}
 	}
 
+	public void setAllowedHostnames(Predicate<String> allowedHostnames) {
+		if (allowedHostnames == null) {
+			throw new IllegalArgumentException("allowedHostnames cannot be null");
+		}
+		this.allowedHostnames = allowedHostnames;
+	}
+
 	private void urlBlacklistsAddAll(Collection<String> values) {
 		this.encodedUrlBlacklist.addAll(values);
 		this.decodedUrlBlacklist.addAll(values);
@@ -243,6 +258,7 @@ public class StrictHttpFirewall implements HttpFirewall {
 	@Override
 	public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
 		rejectedBlacklistedUrls(request);
+		rejectedUntrustedHosts(request);
 
 		if (!isNormalized(request)) {
 			throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
@@ -272,6 +288,13 @@ public class StrictHttpFirewall implements HttpFirewall {
 		}
 	}
 
+	private void rejectedUntrustedHosts(HttpServletRequest request) {
+		String serverName = request.getServerName();
+		if (serverName != null && !this.allowedHostnames.test(serverName)) {
+			throw new RequestRejectedException("The request was rejected because the domain " + serverName + " is untrusted.");
+		}
+	}
+
 	@Override
 	public HttpServletResponse getFirewalledResponse(HttpServletResponse response) {
 		return new FirewalledResponse(response);

+ 28 - 1
web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2012-2017 the original author or authors.
+ * Copyright 2012-2020 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ import static org.assertj.core.api.Assertions.fail;
 
 /**
  * @author Rob Winch
+ * @author Eddú Meléndez
  */
 public class StrictHttpFirewallTests {
 	public String[] unnormalizedPaths = { "/..", "/./path/", "/path/path/.", "/path/path//.", "./path/../path//.",
@@ -373,4 +374,30 @@ public class StrictHttpFirewallTests {
 
 		this.firewall.getFirewalledRequest(request);
 	}
+
+	@Test
+	public void getFirewalledRequestWhenTrustedDomainThenNoException() {
+		String host = "example.org";
+		this.request.addHeader("Host", host);
+		this.firewall.setAllowedHostnames(hostname -> hostname.equals("example.org"));
+
+		try {
+			this.firewall.getFirewalledRequest(this.request);
+		} catch (RequestRejectedException fail) {
+			fail("Host " + host + " was rejected");
+		}
+	}
+
+	@Test
+	public void getFirewalledRequestWhenUntrustedDomainThenException() {
+		String host = "example.org";
+		this.request.addHeader("Host", host);
+		this.firewall.setAllowedHostnames(hostname -> hostname.equals("myexample.org"));
+
+		try {
+			this.firewall.getFirewalledRequest(this.request);
+			fail("Host " + host + " was accepted");
+		} catch (RequestRejectedException expected) {
+		}
+	}
 }