فهرست منبع

Validate headers and parameters in StrictHttpFirewall

Adds methods to configure validation of header names and values and
parameter names and values:
 * setAllowedHeaderNames(Predicate)
 * setAllowedHeaderValues(Predicate)
 * setAllowedParameterNames(Predicate)
 * setAllowedParameterValues(Predicate)

By default, header names, header values, and parameter names that
contain ISO control characters or unassigned unicode characters are
rejected. No parameter value validation is performed by default.

Issue gh-8644
Craig Andrews 5 سال پیش
والد
کامیت
c71352c548

+ 62 - 0
web/src/main/java/org/springframework/security/web/FilterInvocation.java

@@ -23,6 +23,10 @@ import java.lang.reflect.Constructor;
 import java.lang.reflect.InvocationHandler;
 import java.lang.reflect.InvocationHandler;
 import java.lang.reflect.Method;
 import java.lang.reflect.Method;
 import java.lang.reflect.Proxy;
 import java.lang.reflect.Proxy;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.LinkedHashMap;
+import java.util.Map;
 
 
 import javax.servlet.FilterChain;
 import javax.servlet.FilterChain;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletRequest;
@@ -31,6 +35,7 @@ import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpServletResponse;
 
 
+import org.springframework.http.HttpHeaders;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.UrlUtils;
 
 
 /**
 /**
@@ -161,6 +166,8 @@ class DummyRequest extends HttpServletRequestWrapper {
 	private String pathInfo;
 	private String pathInfo;
 	private String queryString;
 	private String queryString;
 	private String method;
 	private String method;
+	private final HttpHeaders headers = new HttpHeaders();
+	private final Map<String, String[]> parameters = new LinkedHashMap<>();
 
 
 	DummyRequest() {
 	DummyRequest() {
 		super(UNSUPPORTED_REQUEST);
 		super(UNSUPPORTED_REQUEST);
@@ -232,6 +239,61 @@ class DummyRequest extends HttpServletRequestWrapper {
 	public String getServerName() {
 	public String getServerName() {
 		return null;
 		return null;
 	}
 	}
+
+	@Override
+	public String getHeader(String name) {
+		return this.headers.getFirst(name);
+	}
+
+	@Override
+	public Enumeration<String> getHeaders(String name) {
+		return Collections.enumeration(this.headers.get(name));
+	}
+
+	@Override
+	public Enumeration<String> getHeaderNames() {
+		return Collections.enumeration(this.headers.keySet());
+	}
+
+	@Override
+	public int getIntHeader(String name) {
+		String value = this.headers.getFirst(name);
+		if (value == null ) {
+			return -1;
+		}
+		else {
+			return Integer.parseInt(value);
+		}
+	}
+
+	public void addHeader(String name, String value) {
+		this.headers.add(name, value);
+	}
+
+	@Override
+	public String getParameter(String name) {
+		String[] arr = this.parameters.get(name);
+		return (arr != null && arr.length > 0 ? arr[0] : null);
+	}
+
+	@Override
+	public Map<String, String[]> getParameterMap() {
+		return Collections.unmodifiableMap(this.parameters);
+	}
+
+	@Override
+	public Enumeration<String> getParameterNames() {
+		return Collections.enumeration(this.parameters.keySet());
+	}
+
+	@Override
+	public String[] getParameterValues(String name) {
+		return this.parameters.get(name);
+	}
+
+	public void setParameter(String name, String... values) {
+		this.parameters.put(name, values);
+	}
 }
 }
 
 
 final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {
 final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {

+ 240 - 0
web/src/main/java/org/springframework/security/web/firewall/StrictHttpFirewall.java

@@ -19,10 +19,13 @@ package org.springframework.security.web.firewall;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.Collections;
+import java.util.Enumeration;
 import java.util.HashSet;
 import java.util.HashSet;
 import java.util.List;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 import java.util.Set;
 import java.util.function.Predicate;
 import java.util.function.Predicate;
+import java.util.regex.Pattern;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpServletResponse;
 
 
@@ -74,6 +77,22 @@ import org.springframework.http.HttpMethod;
  * Rejects hosts that are not allowed. See
  * Rejects hosts that are not allowed. See
  * {@link #setAllowedHostnames(Predicate)}
  * {@link #setAllowedHostnames(Predicate)}
  * </li>
  * </li>
+ * <li>
+ * Reject headers names that are not allowed. See
+ * {@link #setAllowedHeaderNames(Predicate)}
+ * </li>
+ * <li>
+ * Reject headers values that are not allowed. See
+ * {@link #setAllowedHeaderValues(Predicate)}
+ * </li>
+ * <li>
+ * Reject parameter names that are not allowed. See
+ * {@link #setAllowedParameterNames(Predicate)}
+ * </li>
+ * <li>
+ * Reject parameter values that are not allowed. See
+ * {@link #setAllowedParameterValues(Predicate)}
+ * </li>
  * </ul>
  * </ul>
  *
  *
  * @see DefaultHttpFirewall
  * @see DefaultHttpFirewall
@@ -111,6 +130,18 @@ public class StrictHttpFirewall implements HttpFirewall {
 
 
 	private Predicate<String> allowedHostnames = hostname -> true;
 	private Predicate<String> allowedHostnames = hostname -> true;
 
 
+	private static final Pattern ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN = Pattern.compile("[\\p{IsAssigned}&&[^\\p{IsControl}]]*");
+
+	private static final Predicate<String> ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE = s -> ASSIGNED_AND_NOT_ISO_CONTROL_PATTERN.matcher(s).matches();
+
+	private Predicate<String> allowedHeaderNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
+
+	private Predicate<String> allowedHeaderValues = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
+
+	private Predicate<String> allowedParameterNames = ASSIGNED_AND_NOT_ISO_CONTROL_PREDICATE;
+
+	private Predicate<String> allowedParameterValues = value -> true;
+
 	public StrictHttpFirewall() {
 	public StrictHttpFirewall() {
 		urlBlocklistsAddAll(FORBIDDEN_SEMICOLON);
 		urlBlocklistsAddAll(FORBIDDEN_SEMICOLON);
 		urlBlocklistsAddAll(FORBIDDEN_FORWARDSLASH);
 		urlBlocklistsAddAll(FORBIDDEN_FORWARDSLASH);
@@ -330,6 +361,77 @@ public class StrictHttpFirewall implements HttpFirewall {
 		}
 		}
 	}
 	}
 
 
+	/**
+	 * <p>
+	 * Determines which header names should be allowed.
+	 * The default is to reject header names that contain ISO control characters
+	 * and characters that are not defined.
+	 * </p>
+	 *
+	 * @param allowedHeaderNames the predicate for testing header names
+	 * @see Character#isISOControl(int)
+	 * @see Character#isDefined(int)
+	 * @since 5.4
+	 */
+	public void setAllowedHeaderNames(Predicate<String> allowedHeaderNames) {
+		if (allowedHeaderNames == null) {
+			throw new IllegalArgumentException("allowedHeaderNames cannot be null");
+		}
+		this.allowedHeaderNames = allowedHeaderNames;
+	}
+
+	/**
+	 * <p>
+	 * Determines which header values should be allowed.
+	 * The default is to reject header values that contain ISO control characters
+	 * and characters that are not defined.
+	 * </p>
+	 *
+	 * @param allowedHeaderValues the predicate for testing hostnames
+	 * @see Character#isISOControl(int)
+	 * @see Character#isDefined(int)
+	 * @since 5.4
+	 */
+	public void setAllowedHeaderValues(Predicate<String> allowedHeaderValues) {
+		if (allowedHeaderValues == null) {
+			throw new IllegalArgumentException("allowedHeaderValues cannot be null");
+		}
+		this.allowedHeaderValues = allowedHeaderValues;
+	}
+	/*
+	 * Determines which parameter names should be allowed.
+	 * The default is to reject header names that contain ISO control characters
+	 * and characters that are not defined.
+	 * </p>
+	 *
+	 * @param allowedParameterNames the predicate for testing parameter names
+	 * @see Character#isISOControl(int)
+	 * @see Character#isDefined(int)
+	 * @since 5.4
+	 */
+	public void setAllowedParameterNames(Predicate<String> allowedParameterNames) {
+		if (allowedParameterNames == null) {
+			throw new IllegalArgumentException("allowedParameterNames cannot be null");
+		}
+		this.allowedParameterNames = allowedParameterNames;
+	}
+
+	/**
+	 * <p>
+	 * Determines which parameter values should be allowed.
+	 * The default is to allow any parameter value.
+	 * </p>
+	 *
+	 * @param allowedParameterValues the predicate for testing parameter values
+	 * @since 5.4
+	 */
+	public void setAllowedParameterValues(Predicate<String> allowedParameterValues) {
+		if (allowedParameterValues == null) {
+			throw new IllegalArgumentException("allowedParameterValues cannot be null");
+		}
+		this.allowedParameterValues = allowedParameterValues;
+	}
+
 	/**
 	/**
 	 * <p>
 	 * <p>
 	 * Determines which hostnames should be allowed. The default is to allow any hostname.
 	 * Determines which hostnames should be allowed. The default is to allow any hostname.
@@ -370,6 +472,144 @@ public class StrictHttpFirewall implements HttpFirewall {
 			throw new RequestRejectedException("The requestURI was rejected because it can only contain printable ASCII characters.");
 			throw new RequestRejectedException("The requestURI was rejected because it can only contain printable ASCII characters.");
 		}
 		}
 		return new FirewalledRequest(request) {
 		return new FirewalledRequest(request) {
+			@Override
+			public long getDateHeader(String name) {
+				if (!allowedHeaderNames.test(name)) {
+					throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+				}
+				return super.getDateHeader(name);
+			}
+
+			@Override
+			public int getIntHeader(String name) {
+				if (!allowedHeaderNames.test(name)) {
+					throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+				}
+				return super.getIntHeader(name);
+			}
+
+			@Override
+			public String getHeader(String name) {
+				if (!allowedHeaderNames.test(name)) {
+					throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+				}
+				String value = super.getHeader(name);
+				if (value != null && !allowedHeaderValues.test(value)) {
+					throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
+				}
+				return value;
+			}
+
+			@Override
+			public Enumeration<String> getHeaders(String name) {
+				if (!allowedHeaderNames.test(name)) {
+					throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+				}
+
+				Enumeration<String> valuesEnumeration = super.getHeaders(name);
+				return new Enumeration<String>() {
+					@Override
+					public boolean hasMoreElements() {
+						return valuesEnumeration.hasMoreElements();
+					}
+
+					@Override
+					public String nextElement() {
+						String value = valuesEnumeration.nextElement();
+						if (!allowedHeaderValues.test(value)) {
+							throw new RequestRejectedException("The request was rejected because the header value \"" + value + "\" is not allowed.");
+						}
+						return value;
+					}
+				};
+			}
+
+			@Override
+			public Enumeration<String> getHeaderNames() {
+				Enumeration<String> namesEnumeration = super.getHeaderNames();
+				return new Enumeration<String>() {
+					@Override
+					public boolean hasMoreElements() {
+						return namesEnumeration.hasMoreElements();
+					}
+
+					@Override
+					public String nextElement() {
+						String name = namesEnumeration.nextElement();
+						if (!allowedHeaderNames.test(name)) {
+							throw new RequestRejectedException("The request was rejected because the header name \"" + name + "\" is not allowed.");
+						}
+						return name;
+					}
+				};
+			}
+
+			@Override
+			public String getParameter(String name) {
+				if (!allowedParameterNames.test(name)) {
+					throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+				}
+				String value = super.getParameter(name);
+				if (value != null && !allowedParameterValues.test(value)) {
+					throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
+				}
+				return value;
+			}
+
+			@Override
+			public Map<String, String[]> getParameterMap() {
+				Map<String, String[]> parameterMap = super.getParameterMap();
+				for (Map.Entry<String, String[]> entry : parameterMap.entrySet()) {
+					String name = entry.getKey();
+					String[] values = entry.getValue();
+					if (!allowedParameterNames.test(name)) {
+						throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+					}
+					for (String value: values) {
+						if (!allowedParameterValues.test(value)) {
+							throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
+						}
+					}
+				}
+				return parameterMap;
+			}
+
+			@Override
+			public Enumeration<String> getParameterNames() {
+				Enumeration<String> namesEnumeration = super.getParameterNames();
+				return new Enumeration<String>() {
+					@Override
+					public boolean hasMoreElements() {
+						return namesEnumeration.hasMoreElements();
+					}
+
+					@Override
+					public String nextElement() {
+						String name = namesEnumeration.nextElement();
+						if (!allowedParameterNames.test(name)) {
+							throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+						}
+						return name;
+					}
+				};
+			}
+
+			@Override
+			public String[] getParameterValues(String name) {
+				if (!allowedParameterNames.test(name)) {
+					throw new RequestRejectedException("The request was rejected because the parameter name \"" + name + "\" is not allowed.");
+				}
+				String[] values = super.getParameterValues(name);
+				if (values != null) {
+					for (String value: values) {
+						if (!allowedParameterValues.test(value)) {
+							throw new RequestRejectedException("The request was rejected because the parameter value \"" + value + "\" is not allowed.");
+						}
+					}
+				}
+				return values;
+			}
+
 			@Override
 			@Override
 			public void reset() {
 			public void reset() {
 			}
 			}

+ 143 - 0
web/src/test/java/org/springframework/security/web/firewall/StrictHttpFirewallTests.java

@@ -23,6 +23,8 @@ import static org.assertj.core.api.Assertions.fail;
 import java.util.Arrays;
 import java.util.Arrays;
 import java.util.List;
 import java.util.List;
 
 
+import javax.servlet.http.HttpServletRequest;
+
 import org.junit.Test;
 import org.junit.Test;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpMethod;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -595,4 +597,145 @@ public class StrictHttpFirewallTests {
 
 
 		this.firewall.getFirewalledRequest(this.request);
 		this.firewall.getFirewalledRequest(this.request);
 	}
 	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeaderWhenNotAllowedHeaderNameThenException() {
+		this.firewall.setAllowedHeaderNames(name -> !name.equals("bad name"));
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeader("bad name");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeaderWhenNotAllowedHeaderValueThenException() {
+		this.request.addHeader("good name", "bad value");
+		this.firewall.setAllowedHeaderValues(value -> !value.equals("bad value"));
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeader("good name");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetDateHeaderWhenControlCharacterInHeaderNameThenException() {
+		this.request.addHeader("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getDateHeader("Bad\0Name");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetIntHeaderWhenControlCharacterInHeaderNameThenException() {
+		this.request.addHeader("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getIntHeader("Bad\0Name");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeaderWhenControlCharacterInHeaderNameThenException() {
+		this.request.addHeader("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeader("Bad\0Name");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeaderWhenUndefinedCharacterInHeaderNameThenException() {
+		this.request.addHeader("Bad\uFFFEName", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeader("Bad\uFFFEName");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeadersWhenControlCharacterInHeaderNameThenException() {
+		this.request.addHeader("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeaders("Bad\0Name");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeaderNamesWhenControlCharacterInHeaderNameThenException() {
+		this.request.addHeader("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeaderNames().nextElement();
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeaderWhenControlCharacterInHeaderValueThenException() {
+		this.request.addHeader("Something", "bad\0value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeader("Something");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeaderWhenUndefinedCharacterInHeaderValueThenException() {
+		this.request.addHeader("Something", "bad\uFFFEvalue");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeader("Something");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetHeadersWhenControlCharacterInHeaderValueThenException() {
+		this.request.addHeader("Something", "bad\0value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getHeaders("Something").nextElement();
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetParameterWhenControlCharacterInParameterNameThenException() {
+		this.request.addParameter("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getParameter("Bad\0Name");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetParameterMapWhenControlCharacterInParameterNameThenException() {
+		this.request.addParameter("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getParameterMap();
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetParameterNamesWhenControlCharacterInParameterNameThenException() {
+		this.request.addParameter("Bad\0Name", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getParameterNames().nextElement();
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetParameterNamesWhenUndefinedCharacterInParameterNameThenException() {
+		this.request.addParameter("Bad\uFFFEName", "some value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getParameterNames().nextElement();
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetParameterValuesWhenNotAllowedInParameterValueThenException() {
+		this.firewall.setAllowedParameterValues(value -> !value.equals("bad value"));
+
+		this.request.addParameter("Something", "bad value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getParameterValues("Something");
+	}
+
+	@Test(expected = RequestRejectedException.class)
+	public void getFirewalledRequestGetParameterValuesWhenNotAllowedInParameterNameThenException() {
+		this.firewall.setAllowedParameterNames(value -> !value.equals("bad name"));
+
+		this.request.addParameter("bad name", "good value");
+
+		HttpServletRequest request = this.firewall.getFirewalledRequest(this.request);
+		request.getParameterValues("bad name");
+	}
 }
 }