|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2012-2017 the original author or authors.
|
|
|
+ * Copyright 2012-2020 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -24,6 +24,7 @@ import java.util.Collections;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.List;
|
|
|
import java.util.Set;
|
|
|
+import java.util.function.Predicate;
|
|
|
|
|
|
/**
|
|
|
* <p>
|
|
@@ -59,10 +60,15 @@ import java.util.Set;
|
|
|
* Rejects URLs that contain a URL encoded percent. See
|
|
|
* {@link #setAllowUrlEncodedPercent(boolean)}
|
|
|
* </li>
|
|
|
+ * <li>
|
|
|
+ * Rejects hosts that are not allowed. See
|
|
|
+ * {@link #setAllowedHostnames(Predicate)}
|
|
|
+ * </li>
|
|
|
* </ul>
|
|
|
*
|
|
|
* @see DefaultHttpFirewall
|
|
|
* @author Rob Winch
|
|
|
+ * @author Eddú Meléndez
|
|
|
* @since 5.0.1
|
|
|
*/
|
|
|
public class StrictHttpFirewall implements HttpFirewall {
|
|
@@ -82,6 +88,8 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
|
|
|
private Set<String> decodedUrlBlacklist = new HashSet<String>();
|
|
|
|
|
|
+ private Predicate<String> allowedHostnames = hostname -> true;
|
|
|
+
|
|
|
public StrictHttpFirewall() {
|
|
|
urlBlacklistsAddAll(FORBIDDEN_SEMICOLON);
|
|
|
urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH);
|
|
@@ -230,6 +238,13 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ public void setAllowedHostnames(Predicate<String> allowedHostnames) {
|
|
|
+ if (allowedHostnames == null) {
|
|
|
+ throw new IllegalArgumentException("allowedHostnames cannot be null");
|
|
|
+ }
|
|
|
+ this.allowedHostnames = allowedHostnames;
|
|
|
+ }
|
|
|
+
|
|
|
private void urlBlacklistsAddAll(Collection<String> values) {
|
|
|
this.encodedUrlBlacklist.addAll(values);
|
|
|
this.decodedUrlBlacklist.addAll(values);
|
|
@@ -243,6 +258,7 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
@Override
|
|
|
public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
|
|
|
rejectedBlacklistedUrls(request);
|
|
|
+ rejectedUntrustedHosts(request);
|
|
|
|
|
|
if (!isNormalized(request)) {
|
|
|
throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
|
|
@@ -272,6 +288,13 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ private void rejectedUntrustedHosts(HttpServletRequest request) {
|
|
|
+ String serverName = request.getServerName();
|
|
|
+ if (serverName != null && !this.allowedHostnames.test(serverName)) {
|
|
|
+ throw new RequestRejectedException("The request was rejected because the domain " + serverName + " is untrusted.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
@Override
|
|
|
public HttpServletResponse getFirewalledResponse(HttpServletResponse response) {
|
|
|
return new FirewalledResponse(response);
|