浏览代码

SEC-1901: Changed DebugFilter to no longer extend OncePerRequesetFilter so that the FilterChainProxy is invoked on forwards

Rob Winch 13 年之前
父节点
当前提交
488efbc97e

+ 39 - 5
config/src/main/java/org/springframework/security/config/debug/DebugFilter.java

@@ -3,11 +3,13 @@ package org.springframework.security.config.debug;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.util.UrlUtils;
-import org.springframework.web.filter.OncePerRequestFilter;
 
 import javax.servlet.Filter;
 import javax.servlet.FilterChain;
+import javax.servlet.FilterConfig;
 import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletResponse;
@@ -23,9 +25,12 @@ import java.util.*;
  *
  *
  * @author Luke Taylor
+ * @author Rob Winch
  * @since 3.1
  */
-class DebugFilter extends OncePerRequestFilter {
+class DebugFilter implements Filter {
+    private static final String ALREADY_FILTERED_ATTR_NAME = DebugFilter.class.getName().concat(".FILTERED");
+
     private final FilterChainProxy fcp;
     private final Logger logger = new Logger();
 
@@ -33,8 +38,15 @@ class DebugFilter extends OncePerRequestFilter {
         this.fcp = fcp;
     }
 
-    @Override
-    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
+    public final void doFilter(ServletRequest srvltRequest, ServletResponse srvltResponse, FilterChain filterChain)
+            throws ServletException, IOException {
+
+        if (!(srvltRequest instanceof HttpServletRequest) || !(srvltResponse instanceof HttpServletResponse)) {
+            throw new ServletException("DebugFilter just supports HTTP requests");
+        }
+        HttpServletRequest request = (HttpServletRequest) srvltRequest;
+        HttpServletResponse response = (HttpServletResponse) srvltResponse;
+
         List<Filter> filters = getFilters(request);
         logger.log("Request received for '" + UrlUtils.buildRequestUrl(request) + "':\n\n" +
                 request + "\n\n" +
@@ -42,7 +54,23 @@ class DebugFilter extends OncePerRequestFilter {
                 "pathInfo:" + request.getPathInfo() + "\n\n" +
                 formatFilters(filters));
 
-        fcp.doFilter(new DebugRequestWrapper(request), response, filterChain);
+        if (request.getAttribute(ALREADY_FILTERED_ATTR_NAME) == null) {
+            invokeWithWrappedRequest(request, response, filterChain);
+        } else {
+            fcp.doFilter(request, response, filterChain);
+        }
+    }
+
+    private void invokeWithWrappedRequest(HttpServletRequest request,
+            HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException {
+        request.setAttribute(ALREADY_FILTERED_ATTR_NAME, Boolean.TRUE);
+        request = new DebugRequestWrapper(request);
+        try {
+            fcp.doFilter(request, response, filterChain);
+        }
+        finally {
+            request.removeAttribute(ALREADY_FILTERED_ATTR_NAME);
+        }
     }
 
     String formatFilters(List<Filter> filters) {
@@ -72,6 +100,12 @@ class DebugFilter extends OncePerRequestFilter {
 
         return null;
     }
+
+    public void init(FilterConfig filterConfig) throws ServletException {
+    }
+
+    public void destroy() {
+    }
 }
 
 class DebugRequestWrapper extends HttpServletRequestWrapper {

+ 92 - 0
config/src/test/java/org/springframework/security/config/debug/DebugFilterTest.java

@@ -0,0 +1,92 @@
+package org.springframework.security.config.debug;
+
+import static org.junit.Assert.assertEquals;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.never;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+import javax.servlet.FilterChain;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
+import javax.servlet.http.HttpServletResponse;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.powermock.reflect.internal.WhiteboxImpl;
+import org.springframework.security.web.FilterChainProxy;
+
+/**
+ *
+ * @author Rob Winch
+ *
+ */
+@RunWith(PowerMockRunner.class)
+@PrepareOnlyThisForTest(Logger.class)
+public class DebugFilterTest {
+    @Captor
+    private ArgumentCaptor<HttpServletRequest> requestCaptor;
+    @Mock
+    private HttpServletRequest request;
+    @Mock
+    private HttpServletResponse response;
+    @Mock
+    private FilterChain filterChain;
+    @Mock
+    private FilterChainProxy fcp;
+    @Mock
+    private Logger logger;
+
+    private String requestAttr;
+
+    private DebugFilter filter;
+
+    @Before
+    public void setUp() {
+        when(request.getServletPath()).thenReturn("/login");
+        filter = new DebugFilter(fcp);
+        WhiteboxImpl.setInternalState(filter, Logger.class, logger);
+        requestAttr = WhiteboxImpl.getInternalState(filter, "ALREADY_FILTERED_ATTR_NAME", filter.getClass());
+    }
+
+    @Test
+    public void doFilterProcessesRequests() throws Exception {
+        filter.doFilter(request, response, filterChain);
+
+        verify(logger).log(anyString());
+        verify(request).setAttribute(requestAttr, Boolean.TRUE);
+        verify(fcp).doFilter(requestCaptor.capture(), eq(response), eq(filterChain));
+        assertEquals(DebugRequestWrapper.class,requestCaptor.getValue().getClass());
+        verify(request).removeAttribute(requestAttr);
+    }
+
+    // SEC-1901
+    @Test
+    public void doFilterProcessesForwardedRequests() throws Exception {
+        when(request.getAttribute(requestAttr)).thenReturn(Boolean.TRUE);
+        HttpServletRequest request = new DebugRequestWrapper(this.request);
+
+        filter.doFilter(request, response, filterChain);
+
+        verify(logger).log(anyString());
+        verify(fcp).doFilter(request, response, filterChain);
+        verify(this.request,never()).removeAttribute(requestAttr);
+    }
+
+    @Test
+    public void doFilterDoesNotWrapWithDebugRequestWrapperAgain() throws Exception {
+        when(request.getAttribute(requestAttr)).thenReturn(Boolean.TRUE);
+        HttpServletRequest fireWalledRequest = new HttpServletRequestWrapper(new DebugRequestWrapper(this.request));
+
+        filter.doFilter(fireWalledRequest, response, filterChain);
+
+        verify(fcp).doFilter(fireWalledRequest, response, filterChain);
+    }
+}