2
0
Эх сурвалжийг харах

SEC-1950: Defensively invoke SecurityContextHolder.clearContext() in FilterChainProxy

Rob Winch 13 жил өмнө
parent
commit
bb8f3bae7c

+ 11 - 0
web/src/main/java/org/springframework/security/web/FilterChainProxy.java

@@ -17,6 +17,7 @@ package org.springframework.security.web;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.firewall.DefaultHttpFirewall;
 import org.springframework.security.web.firewall.FirewalledRequest;
 import org.springframework.security.web.firewall.HttpFirewall;
@@ -150,6 +151,16 @@ public class FilterChainProxy extends GenericFilterBean {
 
     public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
             throws IOException, ServletException {
+        try {
+            doFilterInternal(request, response, chain);
+        } finally {
+            // SEC-1950
+            SecurityContextHolder.clearContext();
+        }
+    }
+
+    private void doFilterInternal(ServletRequest request, ServletResponse response, FilterChain chain)
+            throws IOException, ServletException {
 
         FirewalledRequest fwRequest = firewall.getFirewalledRequest((HttpServletRequest) request);
         HttpServletResponse fwResponse = firewall.getFirewalledResponse((HttpServletResponse) response);

+ 42 - 0
web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java

@@ -3,18 +3,22 @@ package org.springframework.security.web;
 import static org.junit.Assert.*;
 import static org.mockito.Mockito.*;
 
+import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.firewall.FirewalledRequest;
 import org.springframework.security.web.firewall.HttpFirewall;
 import org.springframework.security.web.util.RequestMatcher;
 
 import javax.servlet.Filter;
 import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequestWrapper;
 import javax.servlet.http.HttpServletResponse;
@@ -55,6 +59,11 @@ public class FilterChainProxyTests {
         chain = mock(FilterChain.class);
     }
 
+    @After
+    public void teardown() {
+        SecurityContextHolder.clearContext();
+    }
+
     @Test
     public void toStringCallSucceeds() throws Exception {
         fcp.afterPropertiesSet();
@@ -155,4 +164,37 @@ public class FilterChainProxyTests {
         verify(firstFwr).reset();
         verify(fwr).reset();
     }
+
+    @Test
+    public void doFilterClearsSecurityContextHolder() throws Exception {
+        when(matcher.matches(any(HttpServletRequest.class))).thenReturn(true);
+        doAnswer(new Answer<Object>() {
+            public Object answer(InvocationOnMock inv) throws Throwable {
+                SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("username", "password"));
+                return null;
+            }
+        }).when(filter).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
+
+        fcp.doFilter(request, response, chain);
+
+        assertNull(SecurityContextHolder.getContext().getAuthentication());
+    }
+
+    @Test
+    public void doFilterClearsSecurityContextHolderWithException() throws Exception {
+        when(matcher.matches(any(HttpServletRequest.class))).thenReturn(true);
+        doAnswer(new Answer<Object>() {
+            public Object answer(InvocationOnMock inv) throws Throwable {
+                SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("username", "password"));
+                throw new ServletException("oops");
+            }
+        }).when(filter).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class), any(FilterChain.class));
+
+        try {
+            fcp.doFilter(request, response, chain);
+            fail("Expected Exception");
+        }catch(ServletException success) {}
+
+        assertNull(SecurityContextHolder.getContext().getAuthentication());
+    }
 }