Przeglądaj źródła

Refactored channel entry points to use a common base clase since the functionality is almost exactlythe same (apart from the function called on the PortMapper).

Luke Taylor 17 lat temu
rodzic
commit
60b7e2d4f2

+ 92 - 0
core/src/main/java/org/springframework/security/securechannel/AbstractRetryEntryPoint.java

@@ -0,0 +1,92 @@
+package org.springframework.security.securechannel;
+
+import org.springframework.security.util.PortMapper;
+import org.springframework.security.util.PortResolver;
+import org.springframework.security.util.PortMapperImpl;
+import org.springframework.security.util.PortResolverImpl;
+import org.springframework.util.Assert;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import java.io.IOException;
+
+/**
+ * @author Luke Taylor
+ * @version $Id$
+ */
+public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint {
+    //~ Static fields/initializers =====================================================================================
+    private static final Log logger = LogFactory.getLog(RetryWithHttpEntryPoint.class);
+
+    //~ Instance fields ================================================================================================
+
+    private PortMapper portMapper = new PortMapperImpl();
+    private PortResolver portResolver = new PortResolverImpl();
+    /** The scheme ("http://" or "https://") */
+    private String scheme;
+    /** The standard port for the scheme (80 for http, 443 for https) */
+    private int standardPort;
+
+    //~ Constructors ===================================================================================================
+
+    public AbstractRetryEntryPoint(String scheme, int standardPort) {
+        this.scheme = scheme;
+        this.standardPort = standardPort;
+    }
+
+    //~ Methods ========================================================================================================
+
+    public void commence(ServletRequest req, ServletResponse res) throws IOException, ServletException {
+        HttpServletRequest request = (HttpServletRequest) req;
+
+        String pathInfo = request.getPathInfo();
+        String queryString = request.getQueryString();
+        String contextPath = request.getContextPath();
+        String destination = request.getServletPath() + ((pathInfo == null) ? "" : pathInfo)
+            + ((queryString == null) ? "" : ("?" + queryString));
+
+        String redirectUrl = contextPath;
+
+        Integer currentPort = new Integer(portResolver.getServerPort(request));
+        Integer redirectPort = getMappedPort(currentPort);
+
+        if (redirectPort != null) {
+            boolean includePort = redirectPort.intValue() != standardPort;
+
+            redirectUrl = scheme + request.getServerName() + ((includePort) ? (":" + redirectPort) : "") + contextPath
+                + destination;
+        }
+
+        if (logger.isDebugEnabled()) {
+            logger.debug("Redirecting to: " + redirectUrl);
+        }
+
+        ((HttpServletResponse) res).sendRedirect(((HttpServletResponse) res).encodeRedirectURL(redirectUrl));
+    }
+
+    protected abstract Integer getMappedPort(Integer mapFromPort);
+
+    protected PortMapper getPortMapper() {
+        return portMapper;
+    }
+
+    protected PortResolver getPortResolver() {
+        return portResolver;
+    }
+
+    public void setPortMapper(PortMapper portMapper) {
+        Assert.notNull(portMapper, "portMapper cannot be null");
+        this.portMapper = portMapper;
+    }
+
+    public void setPortResolver(PortResolver portResolver) {
+        Assert.notNull(portResolver, "portResolver cannot be null");
+        this.portResolver = portResolver;
+    }
+}

+ 9 - 84
core/src/main/java/org/springframework/security/securechannel/RetryWithHttpEntryPoint.java

@@ -15,98 +15,23 @@
 
 package org.springframework.security.securechannel;
 
-import org.springframework.security.util.PortMapper;
-import org.springframework.security.util.PortMapperImpl;
-import org.springframework.security.util.PortResolver;
-import org.springframework.security.util.PortResolverImpl;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-import org.springframework.beans.factory.InitializingBean;
-
-import org.springframework.util.Assert;
-
-import java.io.IOException;
-
-import javax.servlet.ServletException;
-import javax.servlet.ServletRequest;
-import javax.servlet.ServletResponse;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
 
 /**
- * Commences an insecure channel by retrying the original request using HTTP.<P>This entry point should suffice in
- * most circumstances. However, it is not intended to properly handle HTTP POSTs or other usage where a standard
- * redirect would cause an issue.</p>
+ * Commences an insecure channel by retrying the original request using HTTP.
+ * <p>
+ * This entry point should suffice in most circumstances. However, it is not intended to properly handle HTTP POSTs or
+ * other usage where a standard redirect would cause an issue.
  *
  * @author Ben Alex
  * @version $Id$
  */
-public class RetryWithHttpEntryPoint implements InitializingBean, ChannelEntryPoint {
-    //~ Static fields/initializers =====================================================================================
-
-    private static final Log logger = LogFactory.getLog(RetryWithHttpEntryPoint.class);
-
-    //~ Instance fields ================================================================================================
-
-    private PortMapper portMapper = new PortMapperImpl();
-    private PortResolver portResolver = new PortResolverImpl();
-
-    //~ Methods ========================================================================================================
-
-    public void afterPropertiesSet() throws Exception {
-        Assert.notNull(portMapper, "portMapper is required");
-        Assert.notNull(portResolver, "portResolver is required");
-    }
-
-    public void commence(ServletRequest request, ServletResponse response)
-        throws IOException, ServletException {
-        HttpServletRequest req = (HttpServletRequest) request;
-
-        String pathInfo = req.getPathInfo();
-        String queryString = req.getQueryString();
-        String contextPath = req.getContextPath();
-        String destination = req.getServletPath() + ((pathInfo == null) ? "" : pathInfo)
-            + ((queryString == null) ? "" : ("?" + queryString));
-
-        String redirectUrl = contextPath;
-
-        Integer httpsPort = new Integer(portResolver.getServerPort(req));
-        Integer httpPort = portMapper.lookupHttpPort(httpsPort);
-
-        if (httpPort != null) {
-            boolean includePort = true;
-
-            if (httpPort.intValue() == 80) {
-                includePort = false;
-            }
-
-            redirectUrl = "http://" + req.getServerName() + ((includePort) ? (":" + httpPort) : "") + contextPath
-                + destination;
-        }
-
-        if (logger.isDebugEnabled()) {
-            logger.debug("Redirecting to: " + redirectUrl);
-        }
-
-        ((HttpServletResponse) response).sendRedirect(((HttpServletResponse) response).encodeRedirectURL(redirectUrl));
-    }
-
-    public PortMapper getPortMapper() {
-        return portMapper;
-    }
-
-    public PortResolver getPortResolver() {
-        return portResolver;
-    }
+public class RetryWithHttpEntryPoint extends AbstractRetryEntryPoint {
 
-    public void setPortMapper(PortMapper portMapper) {
-        this.portMapper = portMapper;
+    public RetryWithHttpEntryPoint() {
+        super("http://", 80);
     }
 
-    public void setPortResolver(PortResolver portResolver) {
-        this.portResolver = portResolver;
+    protected Integer getMappedPort(Integer mapFromPort) {
+        return getPortMapper().lookupHttpPort(mapFromPort);
     }
 }

+ 9 - 85
core/src/main/java/org/springframework/security/securechannel/RetryWithHttpsEntryPoint.java

@@ -15,98 +15,22 @@
 
 package org.springframework.security.securechannel;
 
-import org.springframework.security.util.PortMapper;
-import org.springframework.security.util.PortMapperImpl;
-import org.springframework.security.util.PortResolver;
-import org.springframework.security.util.PortResolverImpl;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-
-import org.springframework.beans.factory.InitializingBean;
-
-import org.springframework.util.Assert;
-
-import java.io.IOException;
-
-import javax.servlet.ServletException;
-import javax.servlet.ServletRequest;
-import javax.servlet.ServletResponse;
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
-
 /**
- * Commences a secure channel by retrying the original request using HTTPS.<P>This entry point should suffice in
- * most circumstances. However, it is not intended to properly handle HTTP POSTs or other usage where a standard
- * redirect would cause an issue.</p>
+ * Commences a secure channel by retrying the original request using HTTPS.
+ * <p>
+ * This entry point should suffice in most circumstances. However, it is not intended to properly handle HTTP POSTs
+ * or other usage where a standard redirect would cause an issue.</p>
  *
  * @author Ben Alex
  * @version $Id$
  */
-public class RetryWithHttpsEntryPoint implements InitializingBean, ChannelEntryPoint {
-    //~ Static fields/initializers =====================================================================================
-
-    private static final Log logger = LogFactory.getLog(RetryWithHttpsEntryPoint.class);
-
-    //~ Instance fields ================================================================================================
-
-    private PortMapper portMapper = new PortMapperImpl();
-    private PortResolver portResolver = new PortResolverImpl();
-
-    //~ Methods ========================================================================================================
-
-    public void afterPropertiesSet() throws Exception {
-        Assert.notNull(portMapper, "portMapper is required");
-        Assert.notNull(portResolver, "portResolver is required");
-    }
-
-    public void commence(ServletRequest request, ServletResponse response)
-        throws IOException, ServletException {
-        HttpServletRequest req = (HttpServletRequest) request;
-
-        String pathInfo = req.getPathInfo();
-        String queryString = req.getQueryString();
-        String contextPath = req.getContextPath();
-        String destination = req.getServletPath() + ((pathInfo == null) ? "" : pathInfo)
-            + ((queryString == null) ? "" : ("?" + queryString));
-
-        String redirectUrl = contextPath;
-
-        Integer httpPort = new Integer(portResolver.getServerPort(req));
-        Integer httpsPort = portMapper.lookupHttpsPort(httpPort);
-
-        if (httpsPort != null) {
-            boolean includePort = true;
-
-            if (httpsPort.intValue() == 443) {
-                includePort = false;
-            }
-
-            redirectUrl = "https://" + req.getServerName() + ((includePort) ? (":" + httpsPort) : "") + contextPath
-                + destination;
-        }
-
-        if (logger.isDebugEnabled()) {
-            logger.debug("Redirecting to: " + redirectUrl);
-        }
-
-        ((HttpServletResponse) response).sendRedirect(((HttpServletResponse) response).encodeRedirectURL(redirectUrl));
-    }
-
-    public PortMapper getPortMapper() {
-        return portMapper;
-    }
-
-    public PortResolver getPortResolver() {
-        return portResolver;
-    }
+public class RetryWithHttpsEntryPoint extends AbstractRetryEntryPoint {
 
-    public void setPortMapper(PortMapper portMapper) {
-        this.portMapper = portMapper;
+    public RetryWithHttpsEntryPoint() {
+        super("https://", 443);
     }
 
-    public void setPortResolver(PortResolver portResolver) {
-        this.portResolver = portResolver;
+    protected Integer getMappedPort(Integer mapFromPort) {
+        return getPortMapper().lookupHttpsPort(mapFromPort);
     }
 }

+ 2 - 18
core/src/test/java/org/springframework/security/securechannel/RetryWithHttpEntryPointTests.java

@@ -37,35 +37,23 @@ import java.util.Map;
 public class RetryWithHttpEntryPointTests extends TestCase {
     //~ Methods ========================================================================================================
 
-    public static void main(String[] args) {
-        junit.textui.TestRunner.run(RetryWithHttpEntryPointTests.class);
-    }
-
-    public final void setUp() throws Exception {
-        super.setUp();
-    }
-
     public void testDetectsMissingPortMapper() throws Exception {
         RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
-        ep.setPortMapper(null);
 
         try {
-            ep.afterPropertiesSet();
+            ep.setPortMapper(null);
             fail("Should have thrown IllegalArgumentException");
         } catch (IllegalArgumentException expected) {
-            assertEquals("portMapper is required", expected.getMessage());
         }
     }
 
     public void testDetectsMissingPortResolver() throws Exception {
         RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
-        ep.setPortResolver(null);
 
         try {
-            ep.afterPropertiesSet();
+            ep.setPortResolver(null);
             fail("Should have thrown IllegalArgumentException");
         } catch (IllegalArgumentException expected) {
-            assertEquals("portResolver is required", expected.getMessage());
         }
     }
 
@@ -92,7 +80,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
         RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
         ep.setPortMapper(new PortMapperImpl());
         ep.setPortResolver(new MockPortResolver(80, 443));
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("http://www.example.com/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());
@@ -113,7 +100,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
         RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
         ep.setPortMapper(new PortMapperImpl());
         ep.setPortResolver(new MockPortResolver(80, 443));
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("http://www.example.com/bigWebApp/hello", response.getRedirectedUrl());
@@ -135,7 +121,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
         RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
         ep.setPortMapper(new PortMapperImpl());
         ep.setPortResolver(new MockPortResolver(8768, 1234));
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("/bigWebApp", response.getRedirectedUrl());
@@ -161,7 +146,6 @@ public class RetryWithHttpEntryPointTests extends TestCase {
         RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint();
         ep.setPortResolver(new MockPortResolver(8888, 9999));
         ep.setPortMapper(portMapper);
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("http://www.example.com:8888/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());

+ 3 - 12
core/src/test/java/org/springframework/security/securechannel/RetryWithHttpsEntryPointTests.java

@@ -47,25 +47,21 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
 
     public void testDetectsMissingPortMapper() throws Exception {
         RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
-        ep.setPortMapper(null);
 
         try {
-            ep.afterPropertiesSet();
+            ep.setPortMapper(null);
             fail("Should have thrown IllegalArgumentException");
         } catch (IllegalArgumentException expected) {
-            assertEquals("portMapper is required", expected.getMessage());
         }
     }
 
     public void testDetectsMissingPortResolver() throws Exception {
         RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
-        ep.setPortResolver(null);
 
         try {
-            ep.afterPropertiesSet();
+            ep.setPortResolver(null);
             fail("Should have thrown IllegalArgumentException");
         } catch (IllegalArgumentException expected) {
-            assertEquals("portResolver is required", expected.getMessage());
         }
     }
 
@@ -92,7 +88,6 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
         RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
         ep.setPortMapper(new PortMapperImpl());
         ep.setPortResolver(new MockPortResolver(80, 443));
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("https://www.example.com/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());
@@ -113,14 +108,12 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
         RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
         ep.setPortMapper(new PortMapperImpl());
         ep.setPortResolver(new MockPortResolver(80, 443));
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("https://www.example.com/bigWebApp/hello", response.getRedirectedUrl());
     }
 
-    public void testOperationWhenTargetPortIsUnknown()
-        throws Exception {
+    public void testOperationWhenTargetPortIsUnknown() throws Exception {
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.setQueryString("open=true");
         request.setScheme("http");
@@ -135,7 +128,6 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
         RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
         ep.setPortMapper(new PortMapperImpl());
         ep.setPortResolver(new MockPortResolver(8768, 1234));
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("/bigWebApp", response.getRedirectedUrl());
@@ -161,7 +153,6 @@ public class RetryWithHttpsEntryPointTests extends TestCase {
         RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();
         ep.setPortResolver(new MockPortResolver(8888, 9999));
         ep.setPortMapper(portMapper);
-        ep.afterPropertiesSet();
 
         ep.commence(request, response);
         assertEquals("https://www.example.com:9999/bigWebApp/hello/pathInfo.html?open=true", response.getRedirectedUrl());