Преглед на файлове

SEC-481: Refactoring commence method of AuthenticationProcessingFilterEtryPoint to allow alternative redirect options. Extracted two methods, "buildRedirectUrlToLoginPage" and "buildHttpsRedirectUrlForRequest" and introduced a RedirectUrlBuilder class for assembling the URLs from schemes, ports etc.

Luke Taylor преди 17 години
родител
ревизия
99621a225d

+ 83 - 61
core/src/main/java/org/springframework/security/ui/webapp/AuthenticationProcessingFilterEntryPoint.java

@@ -23,6 +23,7 @@ import org.springframework.security.util.PortMapper;
 import org.springframework.security.util.PortMapperImpl;
 import org.springframework.security.util.PortResolver;
 import org.springframework.security.util.PortResolverImpl;
+import org.springframework.security.util.RedirectUrlBuilder;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -102,55 +103,26 @@ public class AuthenticationProcessingFilterEntryPoint implements AuthenticationE
         return getLoginFormUrl();
     }
 
+    /**
+     * Performs the redirect (or forward) to the login form URL.
+     */
     public void commence(ServletRequest request, ServletResponse response, AuthenticationException authException)
             throws IOException, ServletException {
 
         HttpServletRequest httpRequest = (HttpServletRequest) request;
         HttpServletResponse httpResponse = (HttpServletResponse) response;
-        String scheme = request.getScheme();
-        String serverName = request.getServerName();
-        int serverPort = portResolver.getServerPort(request);
-        String contextPath = httpRequest.getContextPath();
-
-        boolean inHttp = "http".equals(scheme.toLowerCase());
-        boolean inHttps = "https".equals(scheme.toLowerCase());
-        boolean includePort = true;
-        boolean doForceHttps = false;
-        Integer httpsPort = null;
-
-        if (inHttp && (serverPort == 80)) {
-            includePort = false;
-        } else if (inHttps && (serverPort == 443)) {
-            includePort = false;
-        }
-
-        if (forceHttps && inHttp) {
-            httpsPort = portMapper.lookupHttpsPort(new Integer(serverPort));
 
-            if (httpsPort != null) {
-                doForceHttps = true;
-                if (httpsPort.intValue() == 443) {
-                    includePort = false;
-                } else {
-                    includePort = true;
-                }
-            }
-        }
-
-        String loginForm = determineUrlToUseForThisRequest(httpRequest, httpResponse, authException);
         String redirectUrl = null;
 
         if (serverSideRedirect) {
-            if (doForceHttps) {
-                // before doing server side redirect, we need to do client redirect to https.
 
-                String servletPath = httpRequest.getServletPath();
-                String pathInfo = httpRequest.getPathInfo();
-                String query = httpRequest.getQueryString();
+            if (forceHttps && "http".equals(request.getScheme())) {
+                redirectUrl = buildHttpsRedirectUrlForRequest(httpRequest);
+            }
 
-                redirectUrl = "https://" + serverName + ((includePort) ? (":" + httpsPort) : "") + contextPath
-                        + servletPath + (pathInfo == null ? "" : pathInfo) + (query == null ? "" : "?" + query);
-            } else {
+            if (redirectUrl == null) {
+                String loginForm = determineUrlToUseForThisRequest(httpRequest, httpResponse, authException);
+                
                 if (logger.isDebugEnabled()) {
                     logger.debug("Server side forward to: " + loginForm);
                 }
@@ -162,40 +134,71 @@ public class AuthenticationProcessingFilterEntryPoint implements AuthenticationE
                 return;
             }
         } else {
-            if (doForceHttps) {
-                redirectUrl = "https://" + serverName + ((includePort) ? (":" + httpsPort) : "") + contextPath
-                        + loginForm;
-            } else {
-                redirectUrl = scheme + "://" + serverName + ((includePort) ? (":" + serverPort) : "") + contextPath
-                        + loginForm;
-            }
-        }
+            // redirect to login page. Use https if forceHttps true
+
+            redirectUrl = buildRedirectUrlToLoginPage(httpRequest, httpResponse, authException);
 
-        if (logger.isDebugEnabled()) {
-            logger.debug("Redirecting to: " + redirectUrl);
         }
 
         httpResponse.sendRedirect(httpResponse.encodeRedirectURL(redirectUrl));
     }
 
-    public boolean getForceHttps() {
-        return forceHttps;
-    }
+    protected String buildRedirectUrlToLoginPage(HttpServletRequest request, HttpServletResponse response,
+            AuthenticationException authException) {
 
-    public String getLoginFormUrl() {
-        return loginFormUrl;
-    }
+        String loginForm = determineUrlToUseForThisRequest(request, response, authException);
+        int serverPort = portResolver.getServerPort(request);
+        String scheme = request.getScheme();
 
-    public PortMapper getPortMapper() {
-        return portMapper;
-    }
+        RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder();
 
-    public PortResolver getPortResolver() {
-        return portResolver;
+        urlBuilder.setScheme(scheme);
+        urlBuilder.setServerName(request.getServerName());
+        urlBuilder.setPort(serverPort);
+        urlBuilder.setContextPath(request.getContextPath());
+        urlBuilder.setPathInfo(loginForm);
+
+        if (forceHttps && "http".equals(scheme)) {
+            Integer httpsPort = portMapper.lookupHttpsPort(new Integer(serverPort));
+
+            if (httpsPort != null) {
+                // Overwrite scheme and port in the redirect URL
+                urlBuilder.setScheme("https");
+                urlBuilder.setPort(httpsPort.intValue());
+            } else {
+                logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort);
+            }
+        }
+
+        return urlBuilder.getUrl();
     }
 
-    public boolean isServerSideRedirect() {
-        return serverSideRedirect;
+    /**
+     * Builds a URL to redirect the supplied request to HTTPS.
+     */
+    protected String buildHttpsRedirectUrlForRequest(HttpServletRequest request)
+            throws IOException, ServletException {
+
+        int serverPort = portResolver.getServerPort(request);
+        Integer httpsPort = portMapper.lookupHttpsPort(new Integer(serverPort));
+
+        if (httpsPort != null) {
+            RedirectUrlBuilder urlBuilder = new RedirectUrlBuilder();
+            urlBuilder.setScheme("https");
+            urlBuilder.setServerName(request.getServerName());
+            urlBuilder.setPort(httpsPort.intValue());
+            urlBuilder.setContextPath(request.getContextPath());
+            urlBuilder.setServletPath(request.getServletPath());
+            urlBuilder.setPathInfo(request.getPathInfo());
+            urlBuilder.setQuery(request.getQueryString());
+
+            return urlBuilder.getUrl();
+        }
+
+        // Fall through to server-side forward with warning message
+        logger.warn("Unable to redirect to HTTPS as no port mapping found for HTTP port " + serverPort);
+
+        return null;
     }
 
     /**
@@ -210,6 +213,10 @@ public class AuthenticationProcessingFilterEntryPoint implements AuthenticationE
         this.forceHttps = forceHttps;
     }
 
+    protected boolean isForceHttps() {
+        return forceHttps;
+    }
+
     /**
      * The URL where the <code>AuthenticationProcessingFilter</code> login
      * page can be found. Should be relative to the web-app context path, and
@@ -221,14 +228,26 @@ public class AuthenticationProcessingFilterEntryPoint implements AuthenticationE
         this.loginFormUrl = loginFormUrl;
     }
 
+    public String getLoginFormUrl() {
+        return loginFormUrl;
+    }
+
     public void setPortMapper(PortMapper portMapper) {
         this.portMapper = portMapper;
     }
 
+    protected PortMapper getPortMapper() {
+        return portMapper;
+    }
+
     public void setPortResolver(PortResolver portResolver) {
         this.portResolver = portResolver;
     }
 
+    protected PortResolver getPortResolver() {
+        return portResolver;
+    }
+
     /**
      * Tells if we are to do a server side include of the <code>loginFormUrl</code> instead of a 302 redirect.
      *
@@ -238,4 +257,7 @@ public class AuthenticationProcessingFilterEntryPoint implements AuthenticationE
         this.serverSideRedirect = serverSideRedirect;
 	}
 
+    protected boolean isServerSideRedirect() {
+        return serverSideRedirect;
+    }
 }

+ 85 - 0
core/src/main/java/org/springframework/security/util/RedirectUrlBuilder.java

@@ -0,0 +1,85 @@
+package org.springframework.security.util;
+
+import org.springframework.util.Assert;
+
+/**
+ * Internal class for building redirect URLs.
+ *
+ * Could probably make more use of the classes in java.net for this.
+ *
+ * @author Luke Taylor
+ * @version $Id$
+ * @since 2.0
+ */
+public class RedirectUrlBuilder {
+    private String scheme;
+    private String serverName;
+    private int port;
+    private String contextPath;
+    private String servletPath;
+    private String pathInfo;
+    private String query;
+
+    public void setScheme(String scheme) {
+        if(! ("http".equals(scheme) | "https".equals(scheme)) ) {
+            throw new IllegalArgumentException("Unsupported scheme '" + scheme + "'");
+        }
+        this.scheme = scheme;
+    }
+
+    public void setServerName(String serverName) {
+        this.serverName = serverName;
+    }
+
+    public void setPort(int port) {
+        this.port = port;
+    }
+
+    public void setContextPath(String contextPath) {
+        this.contextPath = contextPath;
+    }
+
+    public void setServletPath(String servletPath) {
+        this.servletPath = servletPath;
+    }
+
+    public void setPathInfo(String pathInfo) {
+        this.pathInfo = pathInfo;
+    }
+
+    public void setQuery(String query) {
+        this.query = query;
+    }
+
+    public String getUrl() {
+        StringBuffer sb = new StringBuffer();
+
+        Assert.notNull(scheme);
+        Assert.notNull(serverName);
+
+        sb.append(scheme).append("://").append(serverName);
+
+        // Append the port number if it's not standard for the scheme
+        if (port != (scheme.equals("http") ? 80 : 443)) {
+            sb.append(":").append(Integer.toString(port));
+        }
+
+        if (contextPath != null) {
+            sb.append(contextPath);
+        }
+
+        if (servletPath != null) {
+            sb.append(servletPath);
+        }
+
+        if (pathInfo != null) {
+            sb.append(pathInfo);
+        }
+
+        if (query != null) {
+            sb.append("?").append(query);
+        }
+        
+        return sb.toString();
+    }
+}

+ 47 - 19
core/src/test/java/org/springframework/security/ui/webapp/AuthenticationProcessingFilterEntryPointTests.java

@@ -38,14 +38,6 @@ import java.util.Map;
 public class AuthenticationProcessingFilterEntryPointTests extends TestCase {
     //~ Methods ========================================================================================================
 
-    public static void main(String[] args) {
-        junit.textui.TestRunner.run(AuthenticationProcessingFilterEntryPointTests.class);
-    }
-
-    public final void setUp() throws Exception {
-        super.setUp();
-    }
-
     public void testDetectsMissingLoginFormUrl() throws Exception {
         AuthenticationProcessingFilterEntryPoint ep = new AuthenticationProcessingFilterEntryPoint();
         ep.setPortMapper(new PortMapperImpl());
@@ -95,13 +87,12 @@ public class AuthenticationProcessingFilterEntryPointTests extends TestCase {
         assertTrue(ep.getPortResolver() != null);
 
         ep.setForceHttps(false);
-        assertFalse(ep.getForceHttps());
+        assertFalse(ep.isForceHttps());
         ep.setForceHttps(true);
-        assertTrue(ep.getForceHttps());
+        assertTrue(ep.isForceHttps());
     }
 
-    public void testHttpsOperationFromOriginalHttpUrl()
-        throws Exception {
+    public void testHttpsOperationFromOriginalHttpUrl() throws Exception {
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.setRequestURI("/some_path");
         request.setScheme("http");
@@ -152,8 +143,7 @@ public class AuthenticationProcessingFilterEntryPointTests extends TestCase {
         assertEquals("https://www.example.com:9999/bigWebApp/hello", response.getRedirectedUrl());
     }
 
-    public void testHttpsOperationFromOriginalHttpsUrl()
-        throws Exception {
+    public void testHttpsOperationFromOriginalHttpsUrl() throws Exception {
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.setRequestURI("/some_path");
         request.setScheme("https");
@@ -198,16 +188,13 @@ public class AuthenticationProcessingFilterEntryPointTests extends TestCase {
 
         MockHttpServletResponse response = new MockHttpServletResponse();
 
-        ep.afterPropertiesSet();
         ep.commence(request, response, null);
         assertEquals("http://www.example.com/bigWebApp/hello", response.getRedirectedUrl());
     }
 
-    public void testOperationWhenHttpsRequestsButHttpsPortUnknown()
-        throws Exception {
+    public void testOperationWhenHttpsRequestsButHttpsPortUnknown() throws Exception {
         AuthenticationProcessingFilterEntryPoint ep = new AuthenticationProcessingFilterEntryPoint();
         ep.setLoginFormUrl("/hello");
-        ep.setPortMapper(new PortMapperImpl());
         ep.setPortResolver(new MockPortResolver(8888, 1234));
         ep.setForceHttps(true);
         ep.afterPropertiesSet();
@@ -222,10 +209,51 @@ public class AuthenticationProcessingFilterEntryPointTests extends TestCase {
 
         MockHttpServletResponse response = new MockHttpServletResponse();
 
-        ep.afterPropertiesSet();
         ep.commence(request, response, null);
 
         // Response doesn't switch to HTTPS, as we didn't know HTTP port 8888 to HTTP port mapping
         assertEquals("http://www.example.com:8888/bigWebApp/hello", response.getRedirectedUrl());
     }
+
+    public void testServerSideRedirectWithoutForceHttpsForwardsToLoginPage() throws Exception {
+        AuthenticationProcessingFilterEntryPoint ep = new AuthenticationProcessingFilterEntryPoint();
+        ep.setLoginFormUrl("/hello");
+        ep.setServerSideRedirect(true);
+        ep.afterPropertiesSet();
+        MockHttpServletRequest request = new MockHttpServletRequest();
+        request.setRequestURI("/bigWebApp/some_path");
+        request.setServletPath("/some_path");
+        request.setContextPath("/bigWebApp");
+        request.setScheme("http");
+        request.setServerName("www.example.com");
+        request.setContextPath("/bigWebApp");
+        request.setServerPort(80);
+
+        MockHttpServletResponse response = new MockHttpServletResponse();
+
+        ep.commence(request, response, null);
+        assertEquals("/hello", response.getForwardedUrl());
+    }
+
+    public void testServerSideRedirectWithForceHttpsRedirectsCurrentRequest() throws Exception {
+        AuthenticationProcessingFilterEntryPoint ep = new AuthenticationProcessingFilterEntryPoint();
+        ep.setLoginFormUrl("/hello");
+        ep.setServerSideRedirect(true);
+        ep.setForceHttps(true);
+        ep.afterPropertiesSet();
+        MockHttpServletRequest request = new MockHttpServletRequest();
+        request.setRequestURI("/bigWebApp/some_path");
+        request.setServletPath("/some_path");
+        request.setContextPath("/bigWebApp");
+        request.setScheme("http");
+        request.setServerName("www.example.com");
+        request.setContextPath("/bigWebApp");
+        request.setServerPort(80);
+
+        MockHttpServletResponse response = new MockHttpServletResponse();
+
+        ep.commence(request, response, null);
+        assertEquals("https://www.example.com/bigWebApp/some_path", response.getRedirectedUrl());
+    }
+
 }