|
@@ -31,6 +31,7 @@ import javax.servlet.http.HttpServletRequest;
|
|
|
import javax.servlet.http.HttpServletResponse;
|
|
|
|
|
|
import org.springframework.http.HttpMethod;
|
|
|
+import org.springframework.util.Assert;
|
|
|
|
|
|
/**
|
|
|
* <p>
|
|
@@ -83,7 +84,7 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
* Used to specify to {@link #setAllowedHttpMethods(Collection)} that any HTTP method
|
|
|
* should be allowed.
|
|
|
*/
|
|
|
- private static final Set<String> ALLOW_ANY_HTTP_METHOD = Collections.unmodifiableSet(Collections.emptySet());
|
|
|
+ private static final Set<String> ALLOW_ANY_HTTP_METHOD = Collections.emptySet();
|
|
|
|
|
|
private static final String ENCODED_PERCENT = "%25";
|
|
|
|
|
@@ -165,15 +166,9 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
* @see #setUnsafeAllowAnyHttpMethod(boolean)
|
|
|
*/
|
|
|
public void setAllowedHttpMethods(Collection<String> allowedHttpMethods) {
|
|
|
- if (allowedHttpMethods == null) {
|
|
|
- throw new IllegalArgumentException("allowedHttpMethods cannot be null");
|
|
|
- }
|
|
|
- if (allowedHttpMethods == ALLOW_ANY_HTTP_METHOD) {
|
|
|
- this.allowedHttpMethods = ALLOW_ANY_HTTP_METHOD;
|
|
|
- }
|
|
|
- else {
|
|
|
- this.allowedHttpMethods = new HashSet<>(allowedHttpMethods);
|
|
|
- }
|
|
|
+ Assert.notNull(allowedHttpMethods, "allowedHttpMethods cannot be null");
|
|
|
+ this.allowedHttpMethods = (allowedHttpMethods != ALLOW_ANY_HTTP_METHOD) ? new HashSet<>(allowedHttpMethods)
|
|
|
+ : ALLOW_ANY_HTTP_METHOD;
|
|
|
}
|
|
|
|
|
|
/**
|
|
@@ -361,9 +356,7 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
* @see Character#isDefined(int)
|
|
|
*/
|
|
|
public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) {
|
|
|
- if (allowedHeaderNames == null) {
|
|
|
- throw new IllegalArgumentException("allowedHeaderNames cannot be null");
|
|
|
- }
|
|
|
+ Assert.notNull(allowedHeaderNames, "allowedHeaderNames cannot be null");
|
|
|
this.allowedHeaderNames = allowedHeaderNames;
|
|
|
}
|
|
|
|
|
@@ -378,28 +371,20 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
* @see Character#isDefined(int)
|
|
|
*/
|
|
|
public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) {
|
|
|
- if (allowedHeaderValues == null) {
|
|
|
- throw new IllegalArgumentException("allowedHeaderValues cannot be null");
|
|
|
- }
|
|
|
+ Assert.notNull(allowedHeaderValues, "allowedHeaderValues cannot be null");
|
|
|
this.allowedHeaderValues = allowedHeaderValues;
|
|
|
}
|
|
|
|
|
|
- /*
|
|
|
+ /**
|
|
|
* Determines which parameter names should be allowed. The default is to reject header
|
|
|
- * names that contain ISO control characters and characters that are not defined. </p>
|
|
|
- *
|
|
|
+ * names that contain ISO control characters and characters that are not defined.
|
|
|
* @param allowedParameterNames the predicate for testing parameter names
|
|
|
- *
|
|
|
+ * @since 5.4
|
|
|
* @see Character#isISOControl(int)
|
|
|
- *
|
|
|
* @see Character#isDefined(int)
|
|
|
- *
|
|
|
- * @since 5.4
|
|
|
*/
|
|
|
public void setAllowedParameterNames(Predicate<String> allowedParameterNames) {
|
|
|
- if (allowedParameterNames == null) {
|
|
|
- throw new IllegalArgumentException("allowedParameterNames cannot be null");
|
|
|
- }
|
|
|
+ Assert.notNull(allowedParameterNames, "allowedParameterNames cannot be null");
|
|
|
this.allowedParameterNames = allowedParameterNames;
|
|
|
}
|
|
|
|
|
@@ -412,9 +397,7 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
* @since 5.4
|
|
|
*/
|
|
|
public void setAllowedParameterValues(Predicate<String> allowedParameterValues) {
|
|
|
- if (allowedParameterValues == null) {
|
|
|
- throw new IllegalArgumentException("allowedParameterValues cannot be null");
|
|
|
- }
|
|
|
+ Assert.notNull(allowedParameterValues, "allowedParameterValues cannot be null");
|
|
|
this.allowedParameterValues = allowedParameterValues;
|
|
|
}
|
|
|
|
|
@@ -426,9 +409,7 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
* @since 5.2
|
|
|
*/
|
|
|
public void setAllowedHostnames(Predicate<String> allowedHostnames) {
|
|
|
- if (allowedHostnames == null) {
|
|
|
- throw new IllegalArgumentException("allowedHostnames cannot be null");
|
|
|
- }
|
|
|
+ Assert.notNull(allowedHostnames, "allowedHostnames cannot be null");
|
|
|
this.allowedHostnames = allowedHostnames;
|
|
|
}
|
|
|
|
|
@@ -447,173 +428,15 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
rejectForbiddenHttpMethod(request);
|
|
|
rejectedBlocklistedUrls(request);
|
|
|
rejectedUntrustedHosts(request);
|
|
|
-
|
|
|
if (!isNormalized(request)) {
|
|
|
throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
|
|
|
}
|
|
|
-
|
|
|
String requestUri = request.getRequestURI();
|
|
|
if (!containsOnlyPrintableAsciiCharacters(requestUri)) {
|
|
|
throw new RequestRejectedException(
|
|
|
"The requestURI was rejected because it can only contain printable ASCII characters.");
|
|
|
}
|
|
|
- return new FirewalledRequest(request) {
|
|
|
- @Override
|
|
|
- public long getDateHeader(String name) {
|
|
|
- if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the header name \"" + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- return super.getDateHeader(name);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public int getIntHeader(String name) {
|
|
|
- if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the header name \"" + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- return super.getIntHeader(name);
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String getHeader(String name) {
|
|
|
- if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the header name \"" + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- String value = super.getHeader(name);
|
|
|
- if (value != null && !StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the header value \"" + value + "\" is not allowed.");
|
|
|
- }
|
|
|
- return value;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Enumeration<String> getHeaders(String name) {
|
|
|
- if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the header name \"" + name + "\" is not allowed.");
|
|
|
- }
|
|
|
-
|
|
|
- Enumeration<String> valuesEnumeration = super.getHeaders(name);
|
|
|
- return new Enumeration<String>() {
|
|
|
- @Override
|
|
|
- public boolean hasMoreElements() {
|
|
|
- return valuesEnumeration.hasMoreElements();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String nextElement() {
|
|
|
- String value = valuesEnumeration.nextElement();
|
|
|
- if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
|
|
|
- throw new RequestRejectedException("The request was rejected because the header value \""
|
|
|
- + value + "\" is not allowed.");
|
|
|
- }
|
|
|
- return value;
|
|
|
- }
|
|
|
- };
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Enumeration<String> getHeaderNames() {
|
|
|
- Enumeration<String> namesEnumeration = super.getHeaderNames();
|
|
|
- return new Enumeration<String>() {
|
|
|
- @Override
|
|
|
- public boolean hasMoreElements() {
|
|
|
- return namesEnumeration.hasMoreElements();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String nextElement() {
|
|
|
- String name = namesEnumeration.nextElement();
|
|
|
- if (!StrictHttpFirewall.this.allowedHeaderNames.test(name)) {
|
|
|
- throw new RequestRejectedException("The request was rejected because the header name \""
|
|
|
- + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- return name;
|
|
|
- }
|
|
|
- };
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String getParameter(String name) {
|
|
|
- if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the parameter name \"" + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- String value = super.getParameter(name);
|
|
|
- if (value != null && !StrictHttpFirewall.this.allowedParameterValues.test(value)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the parameter value \"" + value + "\" is not allowed.");
|
|
|
- }
|
|
|
- return value;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Map<String, String[]> getParameterMap() {
|
|
|
- Map<String, String[]> parameterMap = super.getParameterMap();
|
|
|
- for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
|
|
|
- String name = entry.getKey();
|
|
|
- String[] values = entry.getValue();
|
|
|
- if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the parameter name \"" + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- for (String value : values) {
|
|
|
- if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
|
|
|
- throw new RequestRejectedException("The request was rejected because the parameter value \""
|
|
|
- + value + "\" is not allowed.");
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return parameterMap;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public Enumeration<String> getParameterNames() {
|
|
|
- Enumeration<String> namesEnumeration = super.getParameterNames();
|
|
|
- return new Enumeration<String>() {
|
|
|
- @Override
|
|
|
- public boolean hasMoreElements() {
|
|
|
- return namesEnumeration.hasMoreElements();
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String nextElement() {
|
|
|
- String name = namesEnumeration.nextElement();
|
|
|
- if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
|
|
|
- throw new RequestRejectedException("The request was rejected because the parameter name \""
|
|
|
- + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- return name;
|
|
|
- }
|
|
|
- };
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public String[] getParameterValues(String name) {
|
|
|
- if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
|
|
|
- throw new RequestRejectedException(
|
|
|
- "The request was rejected because the parameter name \"" + name + "\" is not allowed.");
|
|
|
- }
|
|
|
- String[] values = super.getParameterValues(name);
|
|
|
- if (values != null) {
|
|
|
- for (String value : values) {
|
|
|
- if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
|
|
|
- throw new RequestRejectedException("The request was rejected because the parameter value \""
|
|
|
- + value + "\" is not allowed.");
|
|
|
- }
|
|
|
- }
|
|
|
- }
|
|
|
- return values;
|
|
|
- }
|
|
|
-
|
|
|
- @Override
|
|
|
- public void reset() {
|
|
|
- }
|
|
|
- };
|
|
|
+ return new StrictFirewalledRequest(request);
|
|
|
}
|
|
|
|
|
|
private void rejectForbiddenHttpMethod(HttpServletRequest request) {
|
|
@@ -705,12 +528,11 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
private static boolean containsOnlyPrintableAsciiCharacters(String uri) {
|
|
|
int length = uri.length();
|
|
|
for (int i = 0; i < length; i++) {
|
|
|
- char c = uri.charAt(i);
|
|
|
- if (c < '\u0020' || c > '\u007e') {
|
|
|
+ char ch = uri.charAt(i);
|
|
|
+ if (ch < '\u0020' || ch > '\u007e') {
|
|
|
return false;
|
|
|
}
|
|
|
}
|
|
|
-
|
|
|
return true;
|
|
|
}
|
|
|
|
|
@@ -728,22 +550,17 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
if (path == null) {
|
|
|
return true;
|
|
|
}
|
|
|
-
|
|
|
- for (int j = path.length(); j > 0;) {
|
|
|
- int i = path.lastIndexOf('/', j - 1);
|
|
|
- int gap = j - i;
|
|
|
-
|
|
|
- if (gap == 2 && path.charAt(i + 1) == '.') {
|
|
|
- // ".", "/./" or "/."
|
|
|
- return false;
|
|
|
+ for (int i = path.length(); i > 0;) {
|
|
|
+ int slashIndex = path.lastIndexOf('/', i - 1);
|
|
|
+ int gap = i - slashIndex;
|
|
|
+ if (gap == 2 && path.charAt(slashIndex + 1) == '.') {
|
|
|
+ return false; // ".", "/./" or "/."
|
|
|
}
|
|
|
- else if (gap == 3 && path.charAt(i + 1) == '.' && path.charAt(i + 2) == '.') {
|
|
|
+ if (gap == 3 && path.charAt(slashIndex + 1) == '.' && path.charAt(slashIndex + 2) == '.') {
|
|
|
return false;
|
|
|
}
|
|
|
-
|
|
|
- j = i;
|
|
|
+ i = slashIndex;
|
|
|
}
|
|
|
-
|
|
|
return true;
|
|
|
}
|
|
|
|
|
@@ -782,4 +599,166 @@ public class StrictHttpFirewall implements HttpFirewall {
|
|
|
return getDecodedUrlBlocklist();
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * Strict {@link FirewalledRequest}.
|
|
|
+ */
|
|
|
+ private class StrictFirewalledRequest extends FirewalledRequest {
|
|
|
+
|
|
|
+ StrictFirewalledRequest(HttpServletRequest request) {
|
|
|
+ super(request);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public long getDateHeader(String name) {
|
|
|
+ validateAllowedHeaderName(name);
|
|
|
+ return super.getDateHeader(name);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public int getIntHeader(String name) {
|
|
|
+ validateAllowedHeaderName(name);
|
|
|
+ return super.getIntHeader(name);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String getHeader(String name) {
|
|
|
+ validateAllowedHeaderName(name);
|
|
|
+ String value = super.getHeader(name);
|
|
|
+ if (value != null) {
|
|
|
+ validateAllowedHeaderValue(value);
|
|
|
+ }
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Enumeration<String> getHeaders(String name) {
|
|
|
+ validateAllowedHeaderName(name);
|
|
|
+ Enumeration<String> headers = super.getHeaders(name);
|
|
|
+ return new Enumeration<String>() {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean hasMoreElements() {
|
|
|
+ return headers.hasMoreElements();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String nextElement() {
|
|
|
+ String value = headers.nextElement();
|
|
|
+ validateAllowedHeaderValue(value);
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Enumeration<String> getHeaderNames() {
|
|
|
+ Enumeration<String> names = super.getHeaderNames();
|
|
|
+ return new Enumeration<String>() {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean hasMoreElements() {
|
|
|
+ return names.hasMoreElements();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String nextElement() {
|
|
|
+ String headerNames = names.nextElement();
|
|
|
+ validateAllowedHeaderName(headerNames);
|
|
|
+ return headerNames;
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String getParameter(String name) {
|
|
|
+ validateAllowedParameterName(name);
|
|
|
+ String value = super.getParameter(name);
|
|
|
+ if (value != null) {
|
|
|
+ validateAllowedParameterValue(value);
|
|
|
+ }
|
|
|
+ return value;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Map<String, String[]> getParameterMap() {
|
|
|
+ Map<String, String[]> parameterMap = super.getParameterMap();
|
|
|
+ for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
|
|
|
+ String name = entry.getKey();
|
|
|
+ String[] values = entry.getValue();
|
|
|
+ validateAllowedParameterName(name);
|
|
|
+ for (String value : values) {
|
|
|
+ validateAllowedParameterValue(value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return parameterMap;
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public Enumeration<String> getParameterNames() {
|
|
|
+ Enumeration<String> paramaterNames = super.getParameterNames();
|
|
|
+ return new Enumeration<String>() {
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public boolean hasMoreElements() {
|
|
|
+ return paramaterNames.hasMoreElements();
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String nextElement() {
|
|
|
+ String name = paramaterNames.nextElement();
|
|
|
+ validateAllowedParameterName(name);
|
|
|
+ return name;
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public String[] getParameterValues(String name) {
|
|
|
+ validateAllowedParameterName(name);
|
|
|
+ String[] values = super.getParameterValues(name);
|
|
|
+ if (values != null) {
|
|
|
+ for (String value : values) {
|
|
|
+ validateAllowedParameterValue(value);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return values;
|
|
|
+ }
|
|
|
+
|
|
|
+ private void validateAllowedHeaderName(String headerNames) {
|
|
|
+ if (!StrictHttpFirewall.this.allowedHeaderNames.test(headerNames)) {
|
|
|
+ throw new RequestRejectedException(
|
|
|
+ "The request was rejected because the header name \"" + headerNames + "\" is not allowed.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void validateAllowedHeaderValue(String value) {
|
|
|
+ if (!StrictHttpFirewall.this.allowedHeaderValues.test(value)) {
|
|
|
+ throw new RequestRejectedException(
|
|
|
+ "The request was rejected because the header value \"" + value + "\" is not allowed.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void validateAllowedParameterName(String name) {
|
|
|
+ if (!StrictHttpFirewall.this.allowedParameterNames.test(name)) {
|
|
|
+ throw new RequestRejectedException(
|
|
|
+ "The request was rejected because the parameter name \"" + name + "\" is not allowed.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private void validateAllowedParameterValue(String value) {
|
|
|
+ if (!StrictHttpFirewall.this.allowedParameterValues.test(value)) {
|
|
|
+ throw new RequestRejectedException(
|
|
|
+ "The request was rejected because the parameter value \"" + value + "\" is not allowed.");
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Override
|
|
|
+ public void reset() {
|
|
|
+ }
|
|
|
+
|
|
|
+ };
|
|
|
+
|
|
|
}
|