|  | @@ -1,5 +1,5 @@
 | 
	
		
			
				|  |  |  /*
 | 
	
		
			
				|  |  | - * Copyright 2012-2017 the original author or authors.
 | 
	
		
			
				|  |  | + * Copyright 2012-2019 the original author or authors.
 | 
	
		
			
				|  |  |   *
 | 
	
		
			
				|  |  |   * Licensed under the Apache License, Version 2.0 (the "License");
 | 
	
		
			
				|  |  |   * you may not use this file except in compliance with the License.
 | 
	
	
		
			
				|  | @@ -26,6 +26,7 @@ import java.util.Collections;
 | 
	
		
			
				|  |  |  import java.util.HashSet;
 | 
	
		
			
				|  |  |  import java.util.List;
 | 
	
		
			
				|  |  |  import java.util.Set;
 | 
	
		
			
				|  |  | +import java.util.function.Predicate;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  /**
 | 
	
		
			
				|  |  |   * <p>
 | 
	
	
		
			
				|  | @@ -66,10 +67,15 @@ import java.util.Set;
 | 
	
		
			
				|  |  |   * Rejects URLs that contain a URL encoded percent. See
 | 
	
		
			
				|  |  |   * {@link #setAllowUrlEncodedPercent(boolean)}
 | 
	
		
			
				|  |  |   * </li>
 | 
	
		
			
				|  |  | + * <li>
 | 
	
		
			
				|  |  | + * Rejects hosts that are not allowed. See
 | 
	
		
			
				|  |  | + * {@link #setAllowedHostnames(Predicate)}
 | 
	
		
			
				|  |  | + * </li>
 | 
	
		
			
				|  |  |   * </ul>
 | 
	
		
			
				|  |  |   *
 | 
	
		
			
				|  |  |   * @see DefaultHttpFirewall
 | 
	
		
			
				|  |  |   * @author Rob Winch
 | 
	
		
			
				|  |  | + * @author Eddú Meléndez
 | 
	
		
			
				|  |  |   * @since 4.2.4
 | 
	
		
			
				|  |  |   */
 | 
	
		
			
				|  |  |  public class StrictHttpFirewall implements HttpFirewall {
 | 
	
	
		
			
				|  | @@ -98,6 +104,8 @@ public class StrictHttpFirewall implements HttpFirewall {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	private Set<String> allowedHttpMethods = createDefaultAllowedHttpMethods();
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	private Predicate<String> allowedHostnames = hostname -> true;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	public StrictHttpFirewall() {
 | 
	
		
			
				|  |  |  		urlBlacklistsAddAll(FORBIDDEN_SEMICOLON);
 | 
	
		
			
				|  |  |  		urlBlacklistsAddAll(FORBIDDEN_FORWARDSLASH);
 | 
	
	
		
			
				|  | @@ -297,6 +305,13 @@ public class StrictHttpFirewall implements HttpFirewall {
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	public void setAllowedHostnames(Predicate<String> allowedHostnames) {
 | 
	
		
			
				|  |  | +		if (allowedHostnames == null) {
 | 
	
		
			
				|  |  | +			throw new IllegalArgumentException("allowedHostnames cannot be null");
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +		this.allowedHostnames = allowedHostnames;
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	private void urlBlacklistsAddAll(Collection<String> values) {
 | 
	
		
			
				|  |  |  		this.encodedUrlBlacklist.addAll(values);
 | 
	
		
			
				|  |  |  		this.decodedUrlBlacklist.addAll(values);
 | 
	
	
		
			
				|  | @@ -311,6 +326,7 @@ public class StrictHttpFirewall implements HttpFirewall {
 | 
	
		
			
				|  |  |  	public FirewalledRequest getFirewalledRequest(HttpServletRequest request) throws RequestRejectedException {
 | 
	
		
			
				|  |  |  		rejectForbiddenHttpMethod(request);
 | 
	
		
			
				|  |  |  		rejectedBlacklistedUrls(request);
 | 
	
		
			
				|  |  | +		rejectedUntrustedHosts(request);
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  		if (!isNormalized(request)) {
 | 
	
		
			
				|  |  |  			throw new RequestRejectedException("The request was rejected because the URL was not normalized.");
 | 
	
	
		
			
				|  | @@ -352,6 +368,13 @@ public class StrictHttpFirewall implements HttpFirewall {
 | 
	
		
			
				|  |  |  		}
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	private void rejectedUntrustedHosts(HttpServletRequest request) {
 | 
	
		
			
				|  |  | +		String serverName = request.getServerName();
 | 
	
		
			
				|  |  | +		if (serverName != null && !this.allowedHostnames.test(serverName)) {
 | 
	
		
			
				|  |  | +			throw new RequestRejectedException("The request was rejected because the domain " + serverName + " is untrusted.");
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	@Override
 | 
	
		
			
				|  |  |  	public HttpServletResponse getFirewalledResponse(HttpServletResponse response) {
 | 
	
		
			
				|  |  |  		return new FirewalledResponse(response);
 |