2
0
Эх сурвалжийг харах

SEC-1606: Added a FirewalledRequestAwareRequestDispatcher that will call FirewalledRequest.reset() before a forward

Rob Winch 15 жил өмнө
parent
commit
54ffc98bb4

+ 42 - 0
core/src/main/java/org/springframework/security/firewall/RequestWrapper.java

@@ -1,6 +1,12 @@
 package org.springframework.security.firewall;
 
+import javax.servlet.RequestDispatcher;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
+
+import java.io.IOException;
 import java.util.*;
 
 /**
@@ -92,7 +98,43 @@ final class RequestWrapper extends FirewalledRequest {
         return stripPaths ? strippedServletPath : super.getServletPath();
     }
 
+    public RequestDispatcher getRequestDispatcher(String path) {
+        return this.stripPaths ? new FirewalledRequestAwareRequestDispatcher(path) : super.getRequestDispatcher(path);
+    }
+
     public void reset() {
         this.stripPaths = false;
     }
+
+    /**
+     * Ensures {@link FirewalledRequest#reset()} is called prior to performing a forward. It then delegates work to the
+     * {@link RequestDispatcher} from the original {@link HttpServletRequest}.
+     *
+     * @author Rob Winch
+     */
+    private class FirewalledRequestAwareRequestDispatcher implements RequestDispatcher {
+        private final String path;
+
+        /**
+         *
+         * @param path the {@code path} that will be used to obtain the delegate {@link RequestDispatcher} from the
+         * original {@link HttpServletRequest}.
+         */
+        public FirewalledRequestAwareRequestDispatcher(String path) {
+            this.path = path;
+        }
+
+        public void forward(ServletRequest request, ServletResponse response) throws ServletException, IOException {
+            reset();
+            getDelegateDispatcher().forward(request, response);
+        }
+
+        public void include(ServletRequest request, ServletResponse response) throws ServletException, IOException {
+            getDelegateDispatcher().include(request, response);
+        }
+
+        private RequestDispatcher getDelegateDispatcher() {
+            return RequestWrapper.super.getRequestDispatcher(path);
+        }
+    }
 }

+ 45 - 3
core/src/test/java/org/springframework/security/firewall/RequestWrapperTests.java

@@ -1,13 +1,19 @@
 package org.springframework.security.firewall;
 
-import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.*;
+import static org.mockito.Mockito.*;
+
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import javax.servlet.RequestDispatcher;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 import org.junit.BeforeClass;
 import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 
-import java.util.*;
-
 /**
  * @author Luke Taylor
  */
@@ -59,4 +65,40 @@ public class RequestWrapperTests {
         }
     }
 
+    @Test
+    public void resetWhenForward() throws Exception {
+        String denormalizedPath = testPaths.keySet().iterator().next();
+        String forwardPath = "/forward/path";
+        HttpServletRequest mockRequest = mock(HttpServletRequest.class);
+        HttpServletResponse mockResponse = mock(HttpServletResponse.class);
+        RequestDispatcher mockDispatcher = mock(RequestDispatcher.class);
+        when(mockRequest.getServletPath()).thenReturn("");
+        when(mockRequest.getPathInfo()).thenReturn(denormalizedPath);
+        when(mockRequest.getRequestDispatcher(forwardPath)).thenReturn(mockDispatcher);
+
+        RequestWrapper wrapper = new RequestWrapper(mockRequest);
+        RequestDispatcher dispatcher = wrapper.getRequestDispatcher(forwardPath);
+        dispatcher.forward(mockRequest, mockResponse);
+
+        verify(mockRequest).getRequestDispatcher(forwardPath);
+        verify(mockDispatcher).forward(mockRequest, mockResponse);
+        assertEquals(denormalizedPath,wrapper.getPathInfo());
+        verify(mockRequest,times(2)).getPathInfo();
+        // validate wrapper.getServletPath() delegates to the mock
+        wrapper.getServletPath();
+        verify(mockRequest,times(2)).getServletPath();
+        verifyNoMoreInteractions(mockRequest,mockResponse,mockDispatcher);
+    }
+
+    @Test
+    public void requestDispatcherNotWrappedAfterReset() {
+        String path = "/forward/path";
+        HttpServletRequest request = mock(HttpServletRequest.class);
+        RequestDispatcher dispatcher = mock(RequestDispatcher.class);
+        when(request.getRequestDispatcher(path)).thenReturn(dispatcher);
+        RequestWrapper wrapper = new RequestWrapper(request);
+        wrapper.reset();
+        assertSame(dispatcher, wrapper.getRequestDispatcher(path));
+    }
+
 }