2
0
Эх сурвалжийг харах

SEC-1446: Modified BasicAuthenticationFilter to treat invalid base64 and invalid Basic authentication tokens as a failed authentication (raising a BadCredentialsException, without calling the AuthenticationManager).

This solves the problem in this issue (invalid Base64 not resulting in a 401) and also prevents unnecessary calls to the AuthenticationManager.
Luke Taylor 15 жил өмнө
parent
commit
2e2625873c

+ 56 - 38
web/src/main/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilter.java

@@ -27,6 +27,7 @@ import javax.servlet.http.HttpServletResponse;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.authentication.AuthenticationDetailsSource;
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.BadCredentialsException;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
@@ -114,18 +115,16 @@ public class BasicAuthenticationFilter extends GenericFilterBean {
 
         String header = request.getHeader("Authorization");
 
-        if ((header != null) && header.startsWith("Basic ")) {
-            byte[] base64Token = header.substring(6).getBytes("UTF-8");
-            String token = new String(Base64.decode(base64Token), getCredentialsCharset(request));
+        if (header == null || !header.startsWith("Basic ")) {
+            chain.doFilter(request, response);
+            return;
+        }
 
-            String username = "";
-            String password = "";
-            int delim = token.indexOf(":");
+        try {
+            String[] tokens = extractAndDecodeHeader(header, request);
+            assert tokens.length == 2;
 
-            if (delim != -1) {
-                username = token.substring(0, delim);
-                password = token.substring(delim + 1);
-            }
+            String username = tokens[0];
 
             if (debug) {
                 logger.debug("Basic Authentication Authorization header found for user '" + username + "'");
@@ -133,37 +132,12 @@ public class BasicAuthenticationFilter extends GenericFilterBean {
 
             if (authenticationIsRequired(username)) {
                 UsernamePasswordAuthenticationToken authRequest =
-                        new UsernamePasswordAuthenticationToken(username, password);
+                        new UsernamePasswordAuthenticationToken(username, tokens[1]);
                 authRequest.setDetails(authenticationDetailsSource.buildDetails(request));
+                Authentication authResult = authenticationManager.authenticate(authRequest);
 
-                Authentication authResult;
-
-                try {
-                    authResult = authenticationManager.authenticate(authRequest);
-                } catch (AuthenticationException failed) {
-                    // Authentication failed
-                    if (debug) {
-                        logger.debug("Authentication request for user: " + username + " failed: " + failed.toString());
-                    }
-
-                    SecurityContextHolder.getContext().setAuthentication(null);
-
-                    rememberMeServices.loginFail(request, response);
-
-                    onUnsuccessfulAuthentication(request, response, failed);
-
-                    if (ignoreFailure) {
-                        chain.doFilter(request, response);
-                    } else {
-                        authenticationEntryPoint.commence(request, response, failed);
-                    }
-
-                    return;
-                }
-
-                // Authentication success
                 if (debug) {
-                    logger.debug("Authentication success: " + authResult.toString());
+                    logger.debug("Authentication success: " + authResult);
                 }
 
                 SecurityContextHolder.getContext().setAuthentication(authResult);
@@ -172,11 +146,55 @@ public class BasicAuthenticationFilter extends GenericFilterBean {
 
                 onSuccessfulAuthentication(request, response, authResult);
             }
+
+        } catch (AuthenticationException failed) {
+            SecurityContextHolder.clearContext();
+
+            if (debug) {
+                logger.debug("Authentication request for failed: " + failed);
+            }
+
+            rememberMeServices.loginFail(request, response);
+
+            onUnsuccessfulAuthentication(request, response, failed);
+
+            if (ignoreFailure) {
+                chain.doFilter(request, response);
+            } else {
+                authenticationEntryPoint.commence(request, response, failed);
+            }
+
+            return;
         }
 
         chain.doFilter(request, response);
     }
 
+    /**
+     * Decodes the header into a username and password.
+     * <p>
+     * @throws BadCredentialsException if the Basic header is not present or is not valid Base64
+     */
+    private String[] extractAndDecodeHeader(String header, HttpServletRequest request) throws IOException {
+
+        byte[] base64Token = header.substring(6).getBytes("UTF-8");
+        byte[] decoded;
+        try {
+            decoded = Base64.decode(base64Token);
+        } catch (IllegalArgumentException e) {
+            throw new BadCredentialsException("Failed to decode basic authentication token");
+        }
+
+        String token = new String(decoded, getCredentialsCharset(request));
+
+        int delim = token.indexOf(":");
+
+        if (delim == -1) {
+            throw new BadCredentialsException("Invalid basic authentication token");
+        }
+        return new String[] {token.substring(0, delim), token.substring(delim + 1)};
+    }
+
     private boolean authenticationIsRequired(String username) {
         // Only reauthenticate if username doesn't match SecurityContextHolder and user isn't authenticated
         // (see SEC-53)

+ 59 - 45
web/src/test/java/org/springframework/security/web/authentication/www/BasicAuthenticationFilterTests.java

@@ -20,11 +20,7 @@ import static org.mockito.AdditionalMatchers.not;
 import static org.mockito.Matchers.*;
 import static org.mockito.Mockito.*;
 
-import java.io.IOException;
-
-import javax.servlet.Filter;
 import javax.servlet.FilterChain;
-import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 
@@ -55,24 +51,9 @@ public class BasicAuthenticationFilterTests {
 
     private BasicAuthenticationFilter filter;
     private AuthenticationManager manager;
-//    private Mockery jmock = new JUnit4Mockery();
 
     //~ Methods ========================================================================================================
 
-    private MockHttpServletResponse executeFilterInContainerSimulator(Filter filter, final ServletRequest request,
-                    final boolean expectChainToProceed) throws ServletException, IOException {
-//        filter.init(mock(FilterConfig.class));
-
-        final MockHttpServletResponse response = new MockHttpServletResponse();
-
-        FilterChain chain = mock(FilterChain.class);
-        filter.doFilter(request, response, chain);
-//        filter.destroy();
-
-        verify(chain, expectChainToProceed ? times(1) : never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
-        return response;
-    }
-
     @Before
     public void setUp() throws Exception {
         SecurityContextHolder.clearContext();
@@ -97,13 +78,17 @@ public class BasicAuthenticationFilterTests {
 
     @Test
     public void testFilterIgnoresRequestsContainingNoAuthorizationHeader() throws Exception {
-        // Setup our HTTP request
+
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.setServletPath("/some_file.html");
+        final MockHttpServletResponse response = new MockHttpServletResponse();
 
-        // Test
-        executeFilterInContainerSimulator(filter, request, true);
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, response, chain);
 
+        verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+
+        // Test
         assertNull(SecurityContextHolder.getContext().getAuthentication());
     }
 
@@ -119,47 +104,64 @@ public class BasicAuthenticationFilterTests {
 
     @Test
     public void testInvalidBasicAuthorizationTokenIsIgnored() throws Exception {
-        // Setup our HTTP request
         String token = "NOT_A_VALID_TOKEN_AS_MISSING_COLON";
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
         request.setServletPath("/some_file.html");
         request.setSession(new MockHttpSession());
+        final MockHttpServletResponse response = new MockHttpServletResponse();
 
-        // The filter chain shouldn't proceed
-        executeFilterInContainerSimulator(filter, request, false);
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, response, chain);
 
+        verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
         assertNull(SecurityContextHolder.getContext().getAuthentication());
+        assertEquals(401, response.getStatus());
+    }
+
+    @Test
+    public void invalidBase64IsIgnored() throws Exception {
+        MockHttpServletRequest request = new MockHttpServletRequest();
+        request.addHeader("Authorization", "Basic NOT_VALID_BASE64");
+        request.setServletPath("/some_file.html");
+        request.setSession(new MockHttpSession());
+        final MockHttpServletResponse response = new MockHttpServletResponse();
+
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, response, chain);
+        // The filter chain shouldn't proceed
+        verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+        assertNull(SecurityContextHolder.getContext().getAuthentication());
+        assertEquals(401, response.getStatus());
     }
 
     @Test
     public void testNormalOperation() throws Exception {
-        // Setup our HTTP request
         String token = "rod:koala";
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
         request.setServletPath("/some_file.html");
-//        request.setSession(new MockHttpSession());
 
         // Test
         assertNull(SecurityContextHolder.getContext().getAuthentication());
-        executeFilterInContainerSimulator(filter, request, true);
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, new MockHttpServletResponse(), chain);
 
+        verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
         assertNotNull(SecurityContextHolder.getContext().getAuthentication());
         assertEquals("rod", SecurityContextHolder.getContext().getAuthentication().getName());
-
     }
 
     @Test
     public void testOtherAuthorizationSchemeIsIgnored() throws Exception {
-        // Setup our HTTP request
+
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("Authorization", "SOME_OTHER_AUTHENTICATION_SCHEME");
         request.setServletPath("/some_file.html");
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, new MockHttpServletResponse(), chain);
 
-        // Test
-        executeFilterInContainerSimulator(filter, request, true);
-
+        verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
         assertNull(SecurityContextHolder.getContext().getAuthentication());
     }
 
@@ -179,27 +181,36 @@ public class BasicAuthenticationFilterTests {
 
     @Test
     public void testSuccessLoginThenFailureLoginResultsInSessionLosingToken() throws Exception {
-        // Setup our HTTP request
         String token = "rod:koala";
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
         request.setServletPath("/some_file.html");
+        final MockHttpServletResponse response1 = new MockHttpServletResponse();
 
-        // Test
-        executeFilterInContainerSimulator(filter, request, true);
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, response1, chain);
 
+        verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
+
+        // Test
         assertNotNull(SecurityContextHolder.getContext().getAuthentication());
         assertEquals("rod", SecurityContextHolder.getContext().getAuthentication().getName());
 
         // NOW PERFORM FAILED AUTHENTICATION
-        // Setup our HTTP request
+
         token = "otherUser:WRONG_PASSWORD";
         request = new MockHttpServletRequest();
         request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
+        final MockHttpServletResponse response2 = new MockHttpServletResponse();
+
+        chain = mock(FilterChain.class);
+        filter.doFilter(request, response2, chain);
+
+        verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
         request.setServletPath("/some_file.html");
 
-        // Test - the filter chain will not be invoked, as we get a 403 forbidden response
-        MockHttpServletResponse response = executeFilterInContainerSimulator(filter, request, false);
+        // Test - the filter chain will not be invoked, as we get a 401 forbidden response
+        MockHttpServletResponse response = response2;
 
         assertNull(SecurityContextHolder.getContext().getAuthentication());
         assertEquals(401, response.getStatus());
@@ -207,7 +218,6 @@ public class BasicAuthenticationFilterTests {
 
     @Test
     public void testWrongPasswordContinuesFilterChainIfIgnoreFailureIsTrue() throws Exception {
-        // Setup our HTTP request
         String token = "rod:WRONG_PASSWORD";
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
@@ -216,26 +226,30 @@ public class BasicAuthenticationFilterTests {
 
         filter.setIgnoreFailure(true);
         assertTrue(filter.isIgnoreFailure());
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, new MockHttpServletResponse(), chain);
 
-        // Test - the filter chain will be invoked, as we've set ignoreFailure = true
-        executeFilterInContainerSimulator(filter, request, true);
+        verify(chain).doFilter(any(ServletRequest.class), any(ServletResponse.class));
 
+        // Test - the filter chain will be invoked, as we've set ignoreFailure = true
         assertNull(SecurityContextHolder.getContext().getAuthentication());
     }
 
     @Test
     public void testWrongPasswordReturnsForbiddenIfIgnoreFailureIsFalse() throws Exception {
-        // Setup our HTTP request
         String token = "rod:WRONG_PASSWORD";
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("Authorization", "Basic " + new String(Base64.encodeBase64(token.getBytes())));
         request.setServletPath("/some_file.html");
         request.setSession(new MockHttpSession());
         assertFalse(filter.isIgnoreFailure());
+        final MockHttpServletResponse response = new MockHttpServletResponse();
 
-        // Test - the filter chain will not be invoked, as we get a 403 forbidden response
-        MockHttpServletResponse response = executeFilterInContainerSimulator(filter, request, false);
+        FilterChain chain = mock(FilterChain.class);
+        filter.doFilter(request, response, chain);
 
+        // Test - the filter chain will not be invoked, as we get a 401 forbidden response
+        verify(chain, never()).doFilter(any(ServletRequest.class), any(ServletResponse.class));
         assertNull(SecurityContextHolder.getContext().getAuthentication());
         assertEquals(401, response.getStatus());
     }