Explorar o código

SEC-2054: BasicAuthenticationFilter not invoked on ERROR dispatch

Rob Winch %!s(int64=10) %!d(string=hai) anos
pai
achega
e2f7b38b87

+ 5 - 4
web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java

@@ -39,6 +39,7 @@ import org.springframework.security.web.authentication.RememberMeServices;
 import org.springframework.security.web.authentication.WebAuthenticationDetailsSource;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.GenericFilterBean;
+import org.springframework.web.filter.OncePerRequestFilter;
 
 
 /**
@@ -85,7 +86,7 @@ import org.springframework.web.filter.GenericFilterBean;
  *
  * @author Ben Alex
  */
-public class BasicAuthenticationFilter extends GenericFilterBean {
+public class BasicAuthenticationFilter extends OncePerRequestFilter {
 
     //~ Instance fields ================================================================================================
 
@@ -138,11 +139,9 @@ public class BasicAuthenticationFilter extends GenericFilterBean {
         }
     }
 
-    public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
+    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain chain)
             throws IOException, ServletException {
         final boolean debug = logger.isDebugEnabled();
-        final HttpServletRequest request = (HttpServletRequest) req;
-        final HttpServletResponse response = (HttpServletResponse) res;
 
         String header = request.getHeader("Authorization");
 
@@ -201,6 +200,8 @@ public class BasicAuthenticationFilter extends GenericFilterBean {
         chain.doFilter(request, response);
     }
 
+
+
     /**
      * Decodes the header into a username and password.
      *

+ 40 - 21
web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java

@@ -15,7 +15,7 @@
 
 package org.springframework.security.web.authentication.www;
 
-import static org.junit.Assert.*;
+import static org.fest.assertions.Assertions.assertThat;
 import static org.mockito.AdditionalMatchers.not;
 import static org.mockito.Matchers.*;
 import static org.mockito.Mockito.*;
@@ -39,6 +39,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.authentication.WebAuthenticationDetails;
+import org.springframework.web.util.WebUtils;
 
 
 /**
@@ -89,17 +90,17 @@ public class BasicAuthenticationFilterTests {
         verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
 
         // Test
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
     }
 
     @Test
     public void testGettersSetters() {
         BasicAuthenticationFilter filter = new BasicAuthenticationFilter();
         filter.setAuthenticationManager(manager);
-        assertTrue(filter.getAuthenticationManager() != null);
+        assertThat(filter.getAuthenticationManager()).isNotNull();
 
         filter.setAuthenticationEntryPoint(mock(AuthenticationEntryPoint.class));
-        assertTrue(filter.getAuthenticationEntryPoint() != null);
+        assertThat(filter.getAuthenticationEntryPoint()).isNotNull();
     }
 
     @Test
@@ -115,8 +116,8 @@ public class BasicAuthenticationFilterTests {
         filter.doFilter(request, response, chain);
 
         verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
-        assertEquals(401, response.getStatus());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
+        assertThat(response.getStatus()).isEqualTo(401);
     }
 
     @Test
@@ -131,8 +132,8 @@ public class BasicAuthenticationFilterTests {
         filter.doFilter(request, response, chain);
         // The filter chain shouldn't proceed
         verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
-        assertEquals(401, response.getStatus());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
+        assertThat(response.getStatus()).isEqualTo(401);
     }
 
     @Test
@@ -143,13 +144,13 @@ public class BasicAuthenticationFilterTests {
         request.setServletPath("/some_file.html");
 
         // Test
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
         FilterChain chain = mock(FilterChain.class);
         filter.doFilter(request, new MockHttpServletResponse(), chain);
 
         verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
-        assertNotNull(SecurityContextHolder.getContext().getAuthentication());
-        assertEquals("rod", SecurityContextHolder.getContext().getAuthentication().getName());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isEqualTo("rod");
     }
 
     @Test
@@ -162,7 +163,7 @@ public class BasicAuthenticationFilterTests {
         filter.doFilter(request, new MockHttpServletResponse(), chain);
 
         verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
     }
 
     @Test(expected=IllegalArgumentException.class)
@@ -193,8 +194,8 @@ public class BasicAuthenticationFilterTests {
         verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
 
         // Test
-        assertNotNull(SecurityContextHolder.getContext().getAuthentication());
-        assertEquals("rod", SecurityContextHolder.getContext().getAuthentication().getName());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNotNull();
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isEqualTo("rod");
 
         // NOW PERFORM FAILED AUTHENTICATION
 
@@ -212,8 +213,8 @@ public class BasicAuthenticationFilterTests {
         // Test - the filter chain will not be invoked, as we get a 401 forbidden response
         MockHttpServletResponse response = response2;
 
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
-        assertEquals(401, response.getStatus());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
+        assertThat(response.getStatus()).isEqualTo(401);
     }
 
     @Test
@@ -225,14 +226,14 @@ public class BasicAuthenticationFilterTests {
         request.setSession(new MockHttpSession());
 
         filter.setIgnoreFailure(true);
-        assertTrue(filter.isIgnoreFailure());
+        assertThat(filter.isIgnoreFailure()).isTrue();
         FilterChain chain = mock(FilterChain.class);
         filter.doFilter(request, new MockHttpServletResponse(), chain);
 
         verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
 
         // Test - the filter chain will be invoked, as we've set ignoreFailure = true
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
     }
 
     @Test
@@ -242,7 +243,7 @@ public class BasicAuthenticationFilterTests {
         request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
         request.setServletPath("/some_file.html");
         request.setSession(new MockHttpSession());
-        assertFalse(filter.isIgnoreFailure());
+        assertThat(filter.isIgnoreFailure()).isFalse();
         final MockHttpServletResponse response = new MockHttpServletResponse();
 
         FilterChain chain = mock(FilterChain.class);
@@ -250,7 +251,25 @@ public class BasicAuthenticationFilterTests {
 
         // Test - the filter chain will not be invoked, as we get a 401 forbidden response
         verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
-        assertNull(SecurityContextHolder.getContext().getAuthentication());
-        assertEquals(401, response.getStatus());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
+        assertThat(response.getStatus()).isEqualTo(401);
+    }
+
+    // SEC-2054
+    @Test
+    public void skippedOnErrorDispatch() throws Exception {
+
+        String token = "bad:credentials";
+        MockHttpServletRequest request = new MockHttpServletRequest();
+        request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
+        request.setServletPath("/some_file.html");
+        request.setAttribute(WebUtils.ERROR_REQUEST_URI_ATTRIBUTE, "/error");
+        MockHttpServletResponse response = new MockHttpServletResponse();
+
+        FilterChain chain = mock(FilterChain.class);
+
+        filter.doFilter(request, response, chain);
+
+        assertThat(response.getStatus()).isEqualTo(200);
     }
 }