Jelajahi Sumber

SEC-29: Save POST parameters on AuthenticationEntryPoint redirect.

Ben Alex 19 tahun lalu
induk
melakukan
d125569bd6

+ 10 - 17
core/src/main/java/org/acegisecurity/intercept/web/FilterInvocation.java

@@ -1,4 +1,4 @@
-/* Copyright 2004, 2005 Acegi Technology Pty Limited
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,6 +15,8 @@
 
 package org.acegisecurity.intercept.web;
 
+import org.acegisecurity.util.UrlUtils;
+
 import javax.servlet.FilterChain;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
@@ -88,10 +90,7 @@ public class FilterInvocation {
      * @return the full URL of this request
      */
     public String getFullRequestUrl() {
-        return getHttpRequest().getScheme() + "://"
-        + getHttpRequest().getServerName() + ":"
-        + getHttpRequest().getServerPort() + getHttpRequest().getContextPath()
-        + getRequestUrl();
+        return UrlUtils.getFullRequestUrl(this);
     }
 
     public HttpServletRequest getHttpRequest() {
@@ -106,19 +105,13 @@ public class FilterInvocation {
         return request;
     }
 
+    /**
+     * Obtains the web application-specific fragment of the URL.
+     *
+     * @return the URL, excluding any server name, context path or servlet path
+     */
     public String getRequestUrl() {
-        String pathInfo = getHttpRequest().getPathInfo();
-        String queryString = getHttpRequest().getQueryString();
-
-        String uri = getHttpRequest().getServletPath();
-
-        if (uri == null) {
-            uri = getHttpRequest().getRequestURI();
-            uri = uri.substring(getHttpRequest().getContextPath().length());
-        }
-
-        return uri + ((pathInfo == null) ? "" : pathInfo)
-        + ((queryString == null) ? "" : ("?" + queryString));
+        return UrlUtils.getRequestUrl(this);
     }
 
     public ServletResponse getResponse() {

+ 32 - 26
core/src/main/java/org/acegisecurity/securechannel/ChannelProcessingFilter.java

@@ -1,4 +1,4 @@
-/* Copyright 2004 Acegi Technology Pty Limited
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,6 +17,7 @@ package org.acegisecurity.securechannel;
 
 import org.acegisecurity.ConfigAttribute;
 import org.acegisecurity.ConfigAttributeDefinition;
+
 import org.acegisecurity.intercept.web.FilterInvocation;
 import org.acegisecurity.intercept.web.FilterInvocationDefinitionSource;
 
@@ -24,6 +25,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 import org.springframework.beans.factory.InitializingBean;
+
 import org.springframework.util.Assert;
 
 import java.io.IOException;
@@ -78,34 +80,19 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
 
     //~ Methods ================================================================
 
-    public void setChannelDecisionManager(
-        ChannelDecisionManager channelDecisionManager) {
-        this.channelDecisionManager = channelDecisionManager;
-    }
-
-    public ChannelDecisionManager getChannelDecisionManager() {
-        return channelDecisionManager;
-    }
-
-    public void setFilterInvocationDefinitionSource(
-        FilterInvocationDefinitionSource filterInvocationDefinitionSource) {
-        this.filterInvocationDefinitionSource = filterInvocationDefinitionSource;
-    }
-
-    public FilterInvocationDefinitionSource getFilterInvocationDefinitionSource() {
-        return filterInvocationDefinitionSource;
-    }
-
     public void afterPropertiesSet() throws Exception {
-        Assert.notNull(filterInvocationDefinitionSource, "filterInvocationDefinitionSource must be specified");
-        Assert.notNull(channelDecisionManager, "channelDecisionManager must be specified");
+        Assert.notNull(filterInvocationDefinitionSource,
+            "filterInvocationDefinitionSource must be specified");
+        Assert.notNull(channelDecisionManager,
+            "channelDecisionManager must be specified");
 
         Iterator iter = this.filterInvocationDefinitionSource
-                .getConfigAttributeDefinitions();
+            .getConfigAttributeDefinitions();
 
         if (iter == null) {
             if (logger.isWarnEnabled()) {
-                logger.warn("Could not validate configuration attributes as the FilterInvocationDefinitionSource did not return a ConfigAttributeDefinition Iterator");
+                logger.warn(
+                    "Could not validate configuration attributes as the FilterInvocationDefinitionSource did not return a ConfigAttributeDefinition Iterator");
             }
 
             return;
@@ -115,7 +102,7 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
 
         while (iter.hasNext()) {
             ConfigAttributeDefinition def = (ConfigAttributeDefinition) iter
-                    .next();
+                .next();
             Iterator attributes = def.getConfigAttributes();
 
             while (attributes.hasNext()) {
@@ -132,7 +119,8 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
                 logger.info("Validated configuration attributes");
             }
         } else {
-            throw new IllegalArgumentException("Unsupported configuration attributes: " + set.toString());
+            throw new IllegalArgumentException(
+                "Unsupported configuration attributes: " + set.toString());
         }
     }
 
@@ -154,7 +142,7 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
 
         if (attr != null) {
             if (logger.isDebugEnabled()) {
-                logger.debug("Request: " + fi.getFullRequestUrl()
+                logger.debug("Request: " + fi.toString()
                     + "; ConfigAttributes: " + attr.toString());
             }
 
@@ -168,5 +156,23 @@ public class ChannelProcessingFilter implements InitializingBean, Filter {
         chain.doFilter(request, response);
     }
 
+    public ChannelDecisionManager getChannelDecisionManager() {
+        return channelDecisionManager;
+    }
+
+    public FilterInvocationDefinitionSource getFilterInvocationDefinitionSource() {
+        return filterInvocationDefinitionSource;
+    }
+
     public void init(FilterConfig filterConfig) throws ServletException {}
+
+    public void setChannelDecisionManager(
+        ChannelDecisionManager channelDecisionManager) {
+        this.channelDecisionManager = channelDecisionManager;
+    }
+
+    public void setFilterInvocationDefinitionSource(
+        FilterInvocationDefinitionSource filterInvocationDefinitionSource) {
+        this.filterInvocationDefinitionSource = filterInvocationDefinitionSource;
+    }
 }

+ 21 - 10
core/src/main/java/org/acegisecurity/ui/AbstractProcessingFilter.java

@@ -26,6 +26,7 @@ import org.acegisecurity.event.authentication.InteractiveAuthenticationSuccessEv
 
 import org.acegisecurity.ui.rememberme.NullRememberMeServices;
 import org.acegisecurity.ui.rememberme.RememberMeServices;
+import org.acegisecurity.ui.savedrequest.SavedRequest;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -78,10 +79,12 @@ import javax.servlet.http.HttpServletResponse;
  * <li>
  * <code>defaultTargetUrl</code> indicates the URL that should be used for
  * redirection if the <code>HttpSession</code> attribute named {@link
- * #ACEGI_SECURITY_TARGET_URL_KEY} does not indicate the target URL once
- * authentication is completed successfully. eg: <code>/</code>. This will be
- * treated as relative to the web-app's context path, and should include the
- * leading <code>/</code>.
+ * #ACEGI_SAVED_REQUEST_KEY} does not indicate the target URL once
+ * authentication is completed successfully. eg: <code>/</code>. The
+ * <code>defaultTargetUrl</code> will be treated as relative to the web-app's
+ * context path, and should include the leading <code>/</code>. Alternatively,
+ * inclusion of a scheme name (eg http:// or https://) as the prefix will
+ * denote a fully-qualified URL and this is also supported.
  * </li>
  * <li>
  * <code>authenticationFailureUrl</code> indicates the URL that should be used
@@ -95,8 +98,8 @@ import javax.servlet.http.HttpServletResponse;
  * <li>
  * <code>alwaysUseDefaultTargetUrl</code> causes successful authentication to
  * always redirect to the <code>defaultTargetUrl</code>, even if the
- * <code>HttpSession</code> attribute named {@link
- * #ACEGI_SECURITY_TARGET_URL_KEY} defines the intended target URL.
+ * <code>HttpSession</code> attribute named {@link #ACEGI_SAVED_REQUEST_KEY}
+ * defines the intended target URL.
  * </li>
  * </ul>
  * 
@@ -132,12 +135,15 @@ import javax.servlet.http.HttpServletResponse;
  * recorded via an <code>AuthenticationManager</code>-specific application
  * event.
  * </p>
+ *
+ * @author Ben Alex
+ * @version $Id$
  */
 public abstract class AbstractProcessingFilter implements Filter,
     InitializingBean, ApplicationEventPublisherAware, MessageSourceAware {
     //~ Static fields/initializers =============================================
 
-    public static final String ACEGI_SECURITY_TARGET_URL_KEY = "ACEGI_SECURITY_TARGET_URL";
+    public static final String ACEGI_SAVED_REQUEST_KEY = "ACEGI_SAVED_REQUEST_KEY";
     public static final String ACEGI_SECURITY_LAST_EXCEPTION_KEY = "ACEGI_SECURITY_LAST_EXCEPTION";
 
     //~ Instance fields ========================================================
@@ -303,6 +309,13 @@ public abstract class AbstractProcessingFilter implements Filter,
         return continueChainBeforeSuccessfulAuthentication;
     }
 
+    public static String obtainFullRequestUrl(HttpServletRequest request) {
+        SavedRequest savedRequest = (SavedRequest) request.getSession()
+                                                          .getAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY);
+
+        return (savedRequest == null) ? null : savedRequest.getFullRequestUrl();
+    }
+
     protected void onPreAuthentication(HttpServletRequest request,
         HttpServletResponse response)
         throws AuthenticationException, IOException {}
@@ -428,9 +441,7 @@ public abstract class AbstractProcessingFilter implements Filter,
                 + authResult + "'");
         }
 
-        String targetUrl = (String) request.getSession()
-                                           .getAttribute(ACEGI_SECURITY_TARGET_URL_KEY);
-        request.getSession().removeAttribute(ACEGI_SECURITY_TARGET_URL_KEY);
+        String targetUrl = obtainFullRequestUrl(request);
 
         if (alwaysUseDefaultTargetUrl == true) {
             targetUrl = null;

+ 8 - 22
core/src/main/java/org/acegisecurity/ui/ExceptionTranslationFilter.java

@@ -24,7 +24,7 @@ import org.acegisecurity.InsufficientAuthenticationException;
 
 import org.acegisecurity.context.SecurityContextHolder;
 
-import org.acegisecurity.intercept.web.FilterInvocation;
+import org.acegisecurity.ui.savedrequest.SavedRequest;
 
 import org.acegisecurity.util.PortResolver;
 import org.acegisecurity.util.PortResolverImpl;
@@ -250,34 +250,20 @@ public class ExceptionTranslationFilter implements Filter, InitializingBean {
         AuthenticationException reason) throws ServletException, IOException {
         HttpServletRequest httpRequest = (HttpServletRequest) request;
 
-        int port = portResolver.getServerPort(httpRequest);
-        boolean includePort = true;
-
-        if ("http".equals(httpRequest.getScheme().toLowerCase())
-            && (port == 80)) {
-            includePort = false;
-        }
-
-        if ("https".equals(httpRequest.getScheme().toLowerCase())
-            && (port == 443)) {
-            includePort = false;
-        }
-
-        String targetUrl = httpRequest.getScheme() + "://"
-            + httpRequest.getServerName() + ((includePort) ? (":" + port) : "")
-            + httpRequest.getContextPath()
-            + new FilterInvocation(request, response, chain).getRequestUrl();
+        SavedRequest savedRequest = new SavedRequest(httpRequest, portResolver);
 
         if (logger.isDebugEnabled()) {
             logger.debug(
-                "Authentication entry point being called; target URL added to Session: "
-                + targetUrl);
+                "Authentication entry point being called; SavedRequest added to Session: "
+                + savedRequest);
         }
 
         if (createSessionAllowed) {
+            // Store the HTTP request itself. Used by AbstractProcessingFilter
+            // for redirection after successful authentication (SEC-29)
             httpRequest.getSession()
-                       .setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY,
-                targetUrl);
+                       .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY,
+                savedRequest);
         }
 
         // SEC-112: Clear the SecurityContextHolder's Authentication, as the

+ 152 - 0
core/src/main/java/org/acegisecurity/ui/savedrequest/Enumerator.java

@@ -0,0 +1,152 @@
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.acegisecurity.ui.savedrequest;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Enumeration;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.NoSuchElementException;
+
+
+/**
+ * <p>
+ * Adapter that wraps an <code>Enumeration</code> around a Java 2 collection
+ * <code>Iterator</code>.
+ * </p>
+ * 
+ * <p>
+ * Constructors are provided to easily create such wrappers.
+ * </p>
+ * 
+ * <p>
+ * This class is based on code in Apache Tomcat.
+ * </p>
+ *
+ * @author Craig McClanahan
+ * @author Andrey Grebnev
+ * @version $Id$
+ */
+public class Enumerator implements Enumeration {
+    //~ Instance fields ========================================================
+
+    /**
+     * The <code>Iterator</code> over which the <code>Enumeration</code>
+     * represented by this class actually operates.
+     */
+    private Iterator iterator = null;
+
+    //~ Constructors ===========================================================
+
+    /**
+     * Return an Enumeration over the values of the specified Collection.
+     *
+     * @param collection Collection whose values should be enumerated
+     */
+    public Enumerator(Collection collection) {
+        this(collection.iterator());
+    }
+
+    /**
+     * Return an Enumeration over the values of the specified Collection.
+     *
+     * @param collection Collection whose values should be enumerated
+     * @param clone true to clone iterator
+     */
+    public Enumerator(Collection collection, boolean clone) {
+        this(collection.iterator(), clone);
+    }
+
+    /**
+     * Return an Enumeration over the values returned by the specified
+     * Iterator.
+     *
+     * @param iterator Iterator to be wrapped
+     */
+    public Enumerator(Iterator iterator) {
+        super();
+        this.iterator = iterator;
+    }
+
+    /**
+     * Return an Enumeration over the values returned by the specified
+     * Iterator.
+     *
+     * @param iterator Iterator to be wrapped
+     * @param clone true to clone iterator
+     */
+    public Enumerator(Iterator iterator, boolean clone) {
+        super();
+
+        if (!clone) {
+            this.iterator = iterator;
+        } else {
+            List list = new ArrayList();
+
+            while (iterator.hasNext()) {
+                list.add(iterator.next());
+            }
+
+            this.iterator = list.iterator();
+        }
+    }
+
+    /**
+     * Return an Enumeration over the values of the specified Map.
+     *
+     * @param map Map whose values should be enumerated
+     */
+    public Enumerator(Map map) {
+        this(map.values().iterator());
+    }
+
+    /**
+     * Return an Enumeration over the values of the specified Map.
+     *
+     * @param map Map whose values should be enumerated
+     * @param clone true to clone iterator
+     */
+    public Enumerator(Map map, boolean clone) {
+        this(map.values().iterator(), clone);
+    }
+
+    //~ Methods ================================================================
+
+    /**
+     * Tests if this enumeration contains more elements.
+     *
+     * @return <code>true</code> if and only if this enumeration object
+     *         contains at least one more element to provide,
+     *         <code>false</code> otherwise
+     */
+    public boolean hasMoreElements() {
+        return (iterator.hasNext());
+    }
+
+    /**
+     * Returns the next element of this enumeration if this enumeration has at
+     * least one more element to provide.
+     *
+     * @return the next element of this enumeration
+     *
+     * @exception NoSuchElementException if no more elements exist
+     */
+    public Object nextElement() throws NoSuchElementException {
+        return (iterator.next());
+    }
+}

+ 234 - 0
core/src/main/java/org/acegisecurity/ui/savedrequest/FastHttpDateFormat.java

@@ -0,0 +1,234 @@
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.acegisecurity.ui.savedrequest;
+
+import java.text.DateFormat;
+import java.text.ParseException;
+import java.text.SimpleDateFormat;
+
+import java.util.Date;
+import java.util.HashMap;
+import java.util.Locale;
+import java.util.TimeZone;
+
+
+/**
+ * <p>
+ * Utility class to generate HTTP dates.
+ * </p>
+ * 
+ * <p>
+ * This class is based on code in Apache Tomcat.
+ * </p>
+ *
+ * @author Remy Maucherat
+ * @author Andrey Grebnev
+ * @version $Id$
+ */
+public class FastHttpDateFormat {
+    //~ Static fields/initializers =============================================
+
+    /** HTTP date format. */
+    protected static final SimpleDateFormat format = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz",
+            Locale.US);
+
+    /**
+     * The set of SimpleDateFormat formats to use in
+     * <code>getDateHeader()</code>.
+     */
+    protected static final SimpleDateFormat[] formats = {new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz",
+                Locale.US), new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz",
+                Locale.US), new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy",
+                Locale.US)};
+
+    /** GMT timezone - all HTTP dates are on GMT */
+    protected final static TimeZone gmtZone = TimeZone.getTimeZone("GMT");
+
+    static {
+        format.setTimeZone(gmtZone);
+
+        formats[0].setTimeZone(gmtZone);
+        formats[1].setTimeZone(gmtZone);
+        formats[2].setTimeZone(gmtZone);
+    }
+
+    /** Instant on which the currentDate object was generated. */
+    protected static long currentDateGenerated = 0L;
+
+    /** Current formatted date. */
+    protected static String currentDate = null;
+
+    /** Formatter cache. */
+    protected static final HashMap formatCache = new HashMap();
+
+    /** Parser cache. */
+    protected static final HashMap parseCache = new HashMap();
+
+    //~ Methods ================================================================
+
+    /**
+     * Formats a specified date to HTTP format. If local format is not
+     * <code>null</code>, it's used instead.
+     *
+     * @param value Date value to format
+     * @param threadLocalformat The format to use (or <code>null</code> -- then
+     *        HTTP format will be used)
+     *
+     * @return Formatted date
+     */
+    public static final String formatDate(long value,
+        DateFormat threadLocalformat) {
+        String cachedDate = null;
+        Long longValue = new Long(value);
+
+        try {
+            cachedDate = (String) formatCache.get(longValue);
+        } catch (Exception e) {}
+
+        if (cachedDate != null) {
+            return cachedDate;
+        }
+
+        String newDate = null;
+        Date dateValue = new Date(value);
+
+        if (threadLocalformat != null) {
+            newDate = threadLocalformat.format(dateValue);
+
+            synchronized (formatCache) {
+                updateCache(formatCache, longValue, newDate);
+            }
+        } else {
+            synchronized (formatCache) {
+                newDate = format.format(dateValue);
+                updateCache(formatCache, longValue, newDate);
+            }
+        }
+
+        return newDate;
+    }
+
+    /**
+     * Gets the current date in HTTP format.
+     *
+     * @return Current date in HTTP format
+     */
+    public static final String getCurrentDate() {
+        long now = System.currentTimeMillis();
+
+        if ((now - currentDateGenerated) > 1000) {
+            synchronized (format) {
+                if ((now - currentDateGenerated) > 1000) {
+                    currentDateGenerated = now;
+                    currentDate = format.format(new Date(now));
+                }
+            }
+        }
+
+        return currentDate;
+    }
+
+    /**
+     * Parses date with given formatters.
+     *
+     * @param value The string to parse
+     * @param formats Array of formats to use
+     *
+     * @return Parsed date (or <code>null</code> if no formatter mached)
+     */
+    private static final Long internalParseDate(String value,
+        DateFormat[] formats) {
+        Date date = null;
+
+        for (int i = 0; (date == null) && (i < formats.length); i++) {
+            try {
+                date = formats[i].parse(value);
+            } catch (ParseException e) {
+                ;
+            }
+        }
+
+        if (date == null) {
+            return null;
+        }
+
+        return new Long(date.getTime());
+    }
+
+    /**
+     * Tries to parse the given date as an HTTP date. If local format list is
+     * not <code>null</code>, it's used instead.
+     *
+     * @param value The string to parse
+     * @param threadLocalformats Array of formats to use for parsing. If
+     *        <code>null</code>, HTTP formats are used.
+     *
+     * @return Parsed date (or -1 if error occured)
+     */
+    public static final long parseDate(String value,
+        DateFormat[] threadLocalformats) {
+        Long cachedDate = null;
+
+        try {
+            cachedDate = (Long) parseCache.get(value);
+        } catch (Exception e) {}
+
+        if (cachedDate != null) {
+            return cachedDate.longValue();
+        }
+
+        Long date = null;
+
+        if (threadLocalformats != null) {
+            date = internalParseDate(value, threadLocalformats);
+
+            synchronized (parseCache) {
+                updateCache(parseCache, value, date);
+            }
+        } else {
+            synchronized (parseCache) {
+                date = internalParseDate(value, formats);
+                updateCache(parseCache, value, date);
+            }
+        }
+
+        if (date == null) {
+            return (-1L);
+        } else {
+            return date.longValue();
+        }
+    }
+
+    /**
+     * Updates cache.
+     *
+     * @param cache Cache to be updated
+     * @param key Key to be updated
+     * @param value New value
+     */
+    private static final void updateCache(HashMap cache, Object key,
+        Object value) {
+        if (value == null) {
+            return;
+        }
+
+        if (cache.size() > 1000) {
+            cache.clear();
+        }
+
+        cache.put(key, value);
+    }
+}

+ 362 - 0
core/src/main/java/org/acegisecurity/ui/savedrequest/SavedRequest.java

@@ -0,0 +1,362 @@
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.acegisecurity.ui.savedrequest;
+
+import org.acegisecurity.util.PortResolver;
+import org.acegisecurity.util.UrlUtils;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import org.springframework.util.Assert;
+
+import java.util.ArrayList;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletRequest;
+
+
+/**
+ * Represents central information from a <code>HttpServletRequest</code>.
+ * 
+ * <p>
+ * This class is used by {@link org.acegisecurity.ui.AbstractProcessingFilter}
+ * and {@link org.acegisecurity.wrapper.SavedRequestAwareWrapper} to reproduce
+ * the request after successful authentication. An instance of this class is
+ * stored at the time of an authentication exception by {@link
+ * org.acegisecurity.ui.ExceptionTranslationFilter}.
+ * </p>
+ * 
+ * <p>
+ * <em>IMPLEMENTATION NOTE</em>: It is assumed that this object is accessed
+ * only from the context of a single thread, so no synchronization around
+ * internal collection classes is performed.
+ * </p>
+ * 
+ * <p>
+ * This class is based on code in Apache Tomcat.
+ * </p>
+ *
+ * @author Craig McClanahan
+ * @author Andrey Grebnev
+ * @author Ben Alex
+ * @version $Id$
+ */
+public class SavedRequest {
+    //~ Static fields/initializers =============================================
+
+    protected static final Log logger = LogFactory.getLog(SavedRequest.class);
+
+    //~ Instance fields ========================================================
+
+    private ArrayList cookies = new ArrayList();
+    private ArrayList locales = new ArrayList();
+    private HashMap headers = new HashMap();
+    private HashMap parameters = new HashMap();
+    private String contextPath;
+    private String method;
+    private String pathInfo;
+    private String queryString;
+    private String requestURI;
+    private String requestURL;
+    private String scheme;
+    private String serverName;
+    private String servletPath;
+    private int serverPort;
+
+    //~ Constructors ===========================================================
+
+    public SavedRequest(HttpServletRequest request, PortResolver portResolver) {
+        Assert.notNull(request, "Request required");
+        Assert.notNull(portResolver, "PortResolver required");
+
+        // Cookies
+        Cookie[] cookies = request.getCookies();
+
+        if (cookies != null) {
+            for (int i = 0; i < cookies.length; i++) {
+                this.addCookie(cookies[i]);
+            }
+        }
+
+        // Headers
+        Enumeration names = request.getHeaderNames();
+
+        while (names.hasMoreElements()) {
+            String name = (String) names.nextElement();
+            Enumeration values = request.getHeaders(name);
+
+            while (values.hasMoreElements()) {
+                String value = (String) values.nextElement();
+                this.addHeader(name, value);
+            }
+        }
+
+        // Locales
+        Enumeration locales = request.getLocales();
+
+        while (locales.hasMoreElements()) {
+            Locale locale = (Locale) locales.nextElement();
+            this.addLocale(locale);
+        }
+
+        // Parameters
+        Map parameters = request.getParameterMap();
+        Iterator paramNames = parameters.keySet().iterator();
+
+        while (paramNames.hasNext()) {
+            String paramName = (String) paramNames.next();
+            String[] paramValues = (String[]) parameters.get(paramName);
+            this.addParameter(paramName, paramValues);
+        }
+
+        // Primitives
+        this.method = request.getMethod();
+        this.pathInfo = request.getPathInfo();
+        this.queryString = request.getQueryString();
+        this.requestURI = request.getRequestURI();
+        this.serverPort = portResolver.getServerPort(request);
+        this.requestURL = request.getRequestURL().toString();
+        this.scheme = request.getScheme();
+        this.serverName = request.getServerName();
+        this.contextPath = request.getContextPath();
+        this.servletPath = request.getServletPath();
+    }
+
+    //~ Methods ================================================================
+
+    private void addCookie(Cookie cookie) {
+        cookies.add(cookie);
+    }
+
+    private void addHeader(String name, String value) {
+        ArrayList values = (ArrayList) headers.get(name);
+
+        if (values == null) {
+            values = new ArrayList();
+            headers.put(name, values);
+        }
+
+        values.add(value);
+    }
+
+    private void addLocale(Locale locale) {
+        locales.add(locale);
+    }
+
+    private void addParameter(String name, String[] values) {
+        parameters.put(name, values);
+    }
+
+    /**
+     * Determines if the current request matches the <code>SavedRequest</code>.
+     * All URL arguments are considered, but <em>not</em> method (POST/GET),
+     * cookies, locales, headers or parameters.
+     *
+     * @param request DOCUMENT ME!
+     * @param portResolver DOCUMENT ME!
+     *
+     * @return DOCUMENT ME!
+     */
+    public boolean doesRequestMatch(HttpServletRequest request,
+        PortResolver portResolver) {
+        Assert.notNull(request, "Request required");
+        Assert.notNull(portResolver, "PortResolver required");
+
+        if (!propertyEquals("pathInfo", this.pathInfo, request.getPathInfo())) {
+            return false;
+        }
+
+        if (!propertyEquals("queryString", this.queryString,
+                request.getQueryString())) {
+            return false;
+        }
+
+        if (!propertyEquals("requestURI", this.requestURI,
+                request.getRequestURI())) {
+            return false;
+        }
+
+        if (!propertyEquals("serverPort", new Integer(this.serverPort),
+                new Integer(portResolver.getServerPort(request)))) {
+            return false;
+        }
+
+        if (!propertyEquals("requestURL", this.requestURL,
+                request.getRequestURL().toString())) {
+            return false;
+        }
+
+        if (!propertyEquals("scheme", this.scheme, request.getScheme())) {
+            return false;
+        }
+
+        if (!propertyEquals("serverName", this.serverName,
+                request.getServerName())) {
+            return false;
+        }
+
+        if (!propertyEquals("contextPath", this.contextPath,
+                request.getContextPath())) {
+            return false;
+        }
+
+        if (!propertyEquals("servletPath", this.servletPath,
+                request.getServletPath())) {
+            return false;
+        }
+
+        return true;
+    }
+
+    public String getContextPath() {
+        return contextPath;
+    }
+
+    public List getCookies() {
+        return cookies;
+    }
+
+    /**
+     * Indicates the URL that the user agent used for this request.
+     *
+     * @return the full URL of this request
+     */
+    public String getFullRequestUrl() {
+        return UrlUtils.getFullRequestUrl(this);
+    }
+
+    public Iterator getHeaderNames() {
+        return (headers.keySet().iterator());
+    }
+
+    public Iterator getHeaderValues(String name) {
+        ArrayList values = (ArrayList) headers.get(name);
+
+        if (values == null) {
+            return ((new ArrayList()).iterator());
+        } else {
+            return (values.iterator());
+        }
+    }
+
+    public Iterator getLocales() {
+        return (locales.iterator());
+    }
+
+    public String getMethod() {
+        return (this.method);
+    }
+
+    public Map getParameterMap() {
+        return parameters;
+    }
+
+    public Iterator getParameterNames() {
+        return (parameters.keySet().iterator());
+    }
+
+    public String[] getParameterValues(String name) {
+        return ((String[]) parameters.get(name));
+    }
+
+    public String getPathInfo() {
+        return pathInfo;
+    }
+
+    public String getQueryString() {
+        return (this.queryString);
+    }
+
+    public String getRequestURI() {
+        return (this.requestURI);
+    }
+
+    public String getRequestURL() {
+        return requestURL;
+    }
+
+    /**
+     * Obtains the web application-specific fragment of the URL.
+     *
+     * @return the URL, excluding any server name, context path or servlet path
+     */
+    public String getRequestUrl() {
+        return UrlUtils.getRequestUrl(this);
+    }
+
+    public String getScheme() {
+        return scheme;
+    }
+
+    public String getServerName() {
+        return serverName;
+    }
+
+    public int getServerPort() {
+        return serverPort;
+    }
+
+    public String getServletPath() {
+        return servletPath;
+    }
+
+    private boolean propertyEquals(String log, Object arg1, Object arg2) {
+        if ((arg1 == null) && (arg2 == null)) {
+            if (logger.isDebugEnabled()) {
+                logger.debug(log + ": both null (property equals)");
+            }
+
+            return true;
+        }
+
+        if (((arg1 == null) && (arg2 != null))
+            || ((arg1 != null) && (arg2 == null))) {
+            if (logger.isDebugEnabled()) {
+                logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2
+                    + " (property not equals)");
+            }
+
+            return false;
+        }
+
+        if (arg1.equals(arg2)) {
+            if (logger.isDebugEnabled()) {
+                logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2
+                    + " (property equals)");
+            }
+
+            return true;
+        } else {
+            if (logger.isDebugEnabled()) {
+                logger.debug(log + ": arg1=" + arg1 + "; arg2=" + arg2
+                    + " (property not equals)");
+            }
+
+            return false;
+        }
+    }
+
+    public String toString() {
+        return "SavedRequest[" + getFullRequestUrl() + "]";
+    }
+}

+ 6 - 0
core/src/main/java/org/acegisecurity/ui/savedrequest/package.html

@@ -0,0 +1,6 @@
+<html>
+<body>
+Stores a <code>HttpServletRequest</code> so that it can subsequently be emulated by the
+<code>SavedRequestAwareWrapper</code>.
+</body>
+</html>

+ 130 - 0
core/src/main/java/org/acegisecurity/util/UrlUtils.java

@@ -0,0 +1,130 @@
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.acegisecurity.util;
+
+import org.acegisecurity.intercept.web.FilterInvocation;
+import org.acegisecurity.ui.savedrequest.SavedRequest;
+
+import javax.servlet.http.HttpServletRequest;
+
+
+/**
+ * Provides static methods for composing URLs.
+ * 
+ * <p>
+ * Placed into a separate class for visibility, so that changes to URL
+ * formatting conventions will affect all users.
+ * </p>
+ *
+ * @author Ben Alex
+ * @version $Id$
+ */
+public class UrlUtils {
+    //~ Methods ================================================================
+
+    /**
+     * Obtains the full URL the client used to make the request.
+     * 
+     * <p>
+     * Note that the server port will not be shown if it is the default server
+     * port for HTTP or HTTPS (ie 80 and 443 respectively).
+     * </p>
+     *
+     * @param scheme DOCUMENT ME!
+     * @param serverName DOCUMENT ME!
+     * @param serverPort DOCUMENT ME!
+     * @param contextPath DOCUMENT ME!
+     * @param requestUrl DOCUMENT ME!
+     * @param servletPath DOCUMENT ME!
+     * @param requestURI DOCUMENT ME!
+     * @param pathInfo DOCUMENT ME!
+     * @param queryString DOCUMENT ME!
+     *
+     * @return the full URL
+     */
+    private static String buildFullRequestUrl(String scheme, String serverName,
+        int serverPort, String contextPath, String requestUrl,
+        String servletPath, String requestURI, String pathInfo,
+        String queryString) {
+        boolean includePort = true;
+
+        if ("http".equals(scheme.toLowerCase()) && (serverPort == 80)) {
+            includePort = false;
+        }
+
+        if ("https".equals(scheme.toLowerCase()) && (serverPort == 443)) {
+            includePort = false;
+        }
+
+        return scheme + "://" + serverName
+        + ((includePort) ? (":" + serverPort) : "") + contextPath
+        + buildRequestUrl(servletPath, requestURI, contextPath, pathInfo,
+            queryString);
+    }
+
+    /**
+     * Obtains the web application-specific fragment of the URL.
+     *
+     * @param servletPath DOCUMENT ME!
+     * @param requestURI DOCUMENT ME!
+     * @param contextPath DOCUMENT ME!
+     * @param pathInfo DOCUMENT ME!
+     * @param queryString DOCUMENT ME!
+     *
+     * @return the URL, excluding any server name, context path or servlet path
+     */
+    private static String buildRequestUrl(String servletPath,
+        String requestURI, String contextPath, String pathInfo,
+        String queryString) {
+        String uri = servletPath;
+
+        if (uri == null) {
+            uri = requestURI;
+            uri = uri.substring(contextPath.length());
+        }
+
+        return uri + ((pathInfo == null) ? "" : pathInfo)
+        + ((queryString == null) ? "" : ("?" + queryString));
+    }
+
+    public static String getFullRequestUrl(FilterInvocation fi) {
+        HttpServletRequest r = fi.getHttpRequest();
+
+        return buildFullRequestUrl(r.getScheme(), r.getServerName(),
+            r.getServerPort(), r.getContextPath(),
+            r.getRequestURL().toString(), r.getServletPath(),
+            r.getRequestURI(), r.getPathInfo(), r.getQueryString());
+    }
+
+    public static String getFullRequestUrl(SavedRequest sr) {
+        return buildFullRequestUrl(sr.getScheme(), sr.getServerName(),
+            sr.getServerPort(), sr.getContextPath(), sr.getRequestURL(),
+            sr.getServletPath(), sr.getRequestURI(), sr.getPathInfo(),
+            sr.getQueryString());
+    }
+
+    public static String getRequestUrl(FilterInvocation fi) {
+        HttpServletRequest r = fi.getHttpRequest();
+
+        return buildRequestUrl(r.getServletPath(), r.getRequestURI(),
+            r.getContextPath(), r.getPathInfo(), r.getQueryString());
+    }
+
+    public static String getRequestUrl(SavedRequest sr) {
+        return buildRequestUrl(sr.getServletPath(), sr.getRequestURI(),
+            sr.getContextPath(), sr.getPathInfo(), sr.getQueryString());
+    }
+}

+ 409 - 0
core/src/main/java/org/acegisecurity/wrapper/SavedRequestAwareWrapper.java

@@ -0,0 +1,409 @@
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.acegisecurity.wrapper;
+
+import org.acegisecurity.ui.AbstractProcessingFilter;
+import org.acegisecurity.ui.savedrequest.Enumerator;
+import org.acegisecurity.ui.savedrequest.FastHttpDateFormat;
+import org.acegisecurity.ui.savedrequest.SavedRequest;
+
+import org.acegisecurity.util.PortResolver;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import java.text.SimpleDateFormat;
+
+import java.util.ArrayList;
+import java.util.Enumeration;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.TimeZone;
+
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpSession;
+
+
+/**
+ * Provides request parameters, headers and cookies from either an original
+ * request or a saved request.
+ * 
+ * <p>
+ * Note that not all request parameters in the original request are emulated by
+ * this wrapper. Nevertheless, the important data from the original request is
+ * emulated and this should prove adequate for most purposes (in particular
+ * standard HTTP GET and POST operations).
+ * </p>
+ * 
+ * <p>
+ * Added into a request by {@link
+ * org.acegisecurity.wrapper.SecurityContextHolderAwareRequestFilter}.
+ * </p>
+ *
+ * @author Andrey Grebnev
+ * @author Ben Alex
+ * @version $Id$
+ */
+public class SavedRequestAwareWrapper
+    extends SecurityContextHolderAwareRequestWrapper {
+    //~ Static fields/initializers =============================================
+
+    protected static final Log logger = LogFactory.getLog(SavedRequestAwareWrapper.class);
+    protected static final TimeZone GMT_ZONE = TimeZone.getTimeZone("GMT");
+
+    /** The default Locale if none are specified. */
+    protected static Locale defaultLocale = Locale.getDefault();
+
+    //~ Instance fields ========================================================
+
+    protected SavedRequest savedRequest = null;
+
+    /**
+     * The set of SimpleDateFormat formats to use in getDateHeader(). Notice
+     * that because SimpleDateFormat is not thread-safe, we can't declare
+     * formats[] as a static variable.
+     */
+    protected SimpleDateFormat[] formats = new SimpleDateFormat[3];
+
+    //~ Constructors ===========================================================
+
+    public SavedRequestAwareWrapper(HttpServletRequest request,
+        PortResolver portResolver) {
+        super(request);
+
+        HttpSession session = request.getSession(false);
+
+        if (session == null) {
+            if (logger.isDebugEnabled()) {
+                logger.debug(
+                    "Wrapper not replaced; no session available for SavedRequest extraction");
+            }
+
+            return;
+        }
+
+        SavedRequest saved = (SavedRequest) session.getAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY);
+
+        if ((saved != null) && saved.doesRequestMatch(request, portResolver)) {
+            if (logger.isDebugEnabled()) {
+                logger.debug("Wrapper replaced; SavedRequest was: " + saved);
+            }
+
+            savedRequest = saved;
+            session.removeAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY);
+
+            formats[0] = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz",
+                    Locale.US);
+            formats[1] = new SimpleDateFormat("EEEEEE, dd-MMM-yy HH:mm:ss zzz",
+                    Locale.US);
+            formats[2] = new SimpleDateFormat("EEE MMMM d HH:mm:ss yyyy",
+                    Locale.US);
+
+            formats[0].setTimeZone(GMT_ZONE);
+            formats[1].setTimeZone(GMT_ZONE);
+            formats[2].setTimeZone(GMT_ZONE);
+        } else {
+            if (logger.isDebugEnabled()) {
+                logger.debug("Wrapper not replaced; SavedRequest was: " + saved);
+            }
+        }
+    }
+
+    //~ Methods ================================================================
+
+    /**
+     * The default behavior of this method is to return getCookies() on the
+     * wrapped request object.
+     *
+     * @return DOCUMENT ME!
+     */
+    public Cookie[] getCookies() {
+        if (savedRequest == null) {
+            return super.getCookies();
+        } else {
+            List cookies = savedRequest.getCookies();
+
+            return (Cookie[]) cookies.toArray(new Cookie[cookies.size()]);
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getDateHeader(String
+     * name) on the wrapped request object.
+     *
+     * @param name DOCUMENT ME!
+     *
+     * @return DOCUMENT ME!
+     *
+     * @throws IllegalArgumentException DOCUMENT ME!
+     */
+    public long getDateHeader(String name) {
+        if (savedRequest == null) {
+            return super.getDateHeader(name);
+        } else {
+            String value = getHeader(name);
+
+            if (value == null) {
+                return (-1L);
+            }
+
+            // Attempt to convert the date header in a variety of formats
+            long result = FastHttpDateFormat.parseDate(value, formats);
+
+            if (result != (-1L)) {
+                return result;
+            }
+
+            throw new IllegalArgumentException(value);
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getHeader(String name)
+     * on the wrapped request object.
+     *
+     * @param name DOCUMENT ME!
+     *
+     * @return DOCUMENT ME!
+     */
+    public String getHeader(String name) {
+        if (savedRequest == null) {
+            return super.getHeader(name);
+        } else {
+            String header = null;
+            Iterator iterator = savedRequest.getHeaderValues(name);
+
+            while (iterator.hasNext()) {
+                header = (String) iterator.next();
+
+                break;
+            }
+
+            return header;
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getHeaderNames() on the
+     * wrapped request object.
+     *
+     * @return DOCUMENT ME!
+     */
+    public Enumeration getHeaderNames() {
+        if (savedRequest == null) {
+            return super.getHeaderNames();
+        } else {
+            return new Enumerator(savedRequest.getHeaderNames());
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getHeaders(String name)
+     * on the wrapped request object.
+     *
+     * @param name DOCUMENT ME!
+     *
+     * @return DOCUMENT ME!
+     */
+    public Enumeration getHeaders(String name) {
+        if (savedRequest == null) {
+            return super.getHeaders(name);
+        } else {
+            return new Enumerator(savedRequest.getHeaderValues(name));
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getIntHeader(String
+     * name) on the wrapped request object.
+     *
+     * @param name DOCUMENT ME!
+     *
+     * @return DOCUMENT ME!
+     */
+    public int getIntHeader(String name) {
+        if (savedRequest == null) {
+            return super.getIntHeader(name);
+        } else {
+            String value = getHeader(name);
+
+            if (value == null) {
+                return (-1);
+            } else {
+                return (Integer.parseInt(value));
+            }
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getLocale() on the
+     * wrapped request object.
+     *
+     * @return DOCUMENT ME!
+     */
+    public Locale getLocale() {
+        if (savedRequest == null) {
+            return super.getLocale();
+        } else {
+            Locale locale = null;
+            Iterator iterator = savedRequest.getLocales();
+
+            while (iterator.hasNext()) {
+                locale = (Locale) iterator.next();
+
+                break;
+            }
+
+            if (locale == null) {
+                return defaultLocale;
+            } else {
+                return locale;
+            }
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getLocales() on the
+     * wrapped request object.
+     *
+     * @return DOCUMENT ME!
+     */
+    public Enumeration getLocales() {
+        if (savedRequest == null) {
+            return super.getLocales();
+        } else {
+            Iterator iterator = savedRequest.getLocales();
+
+            if (iterator.hasNext()) {
+                return new Enumerator(iterator);
+            } else {
+                ArrayList results = new ArrayList();
+                results.add(defaultLocale);
+
+                return new Enumerator(results.iterator());
+            }
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getMethod() on the
+     * wrapped request object.
+     *
+     * @return DOCUMENT ME!
+     */
+    public String getMethod() {
+        if (savedRequest == null) {
+            return super.getMethod();
+        } else {
+            return savedRequest.getMethod();
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getParameter(String
+     * name) on the wrapped request object.
+     *
+     * @param name DOCUMENT ME!
+     *
+     * @return DOCUMENT ME!
+     */
+    public String getParameter(String name) {
+/*
+   if (savedRequest == null) {
+       return super.getParameter(name);
+   } else {
+       String value = null;
+       String[] values = savedRequest.getParameterValues(name);
+       if (values == null)
+           return null;
+       for (int i = 0; i < values.length; i++) {
+           value = values[i];
+           break;
+       }
+       return value;
+   }
+ */
+
+        //we do not get value from super.getParameter because there is a bug in Jetty servlet-container
+        String value = null;
+        String[] values = null;
+
+        if (savedRequest == null) {
+            values = super.getParameterValues(name);
+        } else {
+            values = savedRequest.getParameterValues(name);
+        }
+
+        if (values == null) {
+            return null;
+        }
+
+        for (int i = 0; i < values.length; i++) {
+            value = values[i];
+
+            break;
+        }
+
+        return value;
+    }
+
+    /**
+     * The default behavior of this method is to return getParameterMap() on
+     * the wrapped request object.
+     *
+     * @return DOCUMENT ME!
+     */
+    public Map getParameterMap() {
+        if (savedRequest == null) {
+            return super.getParameterMap();
+        } else {
+            return savedRequest.getParameterMap();
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return getParameterNames() on
+     * the wrapped request object.
+     *
+     * @return DOCUMENT ME!
+     */
+    public Enumeration getParameterNames() {
+        if (savedRequest == null) {
+            return super.getParameterNames();
+        } else {
+            return new Enumerator(savedRequest.getParameterNames());
+        }
+    }
+
+    /**
+     * The default behavior of this method is to return
+     * getParameterValues(String name) on the wrapped request object.
+     *
+     * @param name DOCUMENT ME!
+     *
+     * @return DOCUMENT ME!
+     */
+    public String[] getParameterValues(String name) {
+        if (savedRequest == null) {
+            return super.getParameterValues(name);
+        } else {
+            return savedRequest.getParameterValues(name);
+        }
+    }
+}

+ 62 - 5
core/src/main/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilter.java

@@ -1,4 +1,4 @@
-/* Copyright 2004, 2005 Acegi Technology Pty Limited
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,8 +15,16 @@
 
 package org.acegisecurity.wrapper;
 
+import org.acegisecurity.util.PortResolver;
+import org.acegisecurity.util.PortResolverImpl;
+
+import org.springframework.util.Assert;
+import org.springframework.util.ReflectionUtils;
+
 import java.io.IOException;
 
+import java.lang.reflect.Constructor;
+
 import javax.servlet.Filter;
 import javax.servlet.FilterChain;
 import javax.servlet.FilterConfig;
@@ -27,13 +35,38 @@ import javax.servlet.http.HttpServletRequest;
 
 
 /**
- * A <code>Filter</code> which populates the <code>ServletRequest</code> with
- * an {@link SecurityContextHolderAwareRequestWrapper}.
+ * A <code>Filter</code> which populates the <code>ServletRequest</code> with a
+ * new request wrapper.
+ * 
+ * <p>
+ * Several request wrappers are included with the framework. The simplest
+ * version is {@link SecurityContextHolderAwareRequestWrapper}. A more complex
+ * and powerful request wrapper is {@link
+ * org.acegisecurity.wrapper.SavedRequestAwareWrapper}. The latter is also the
+ * default.
+ * </p>
+ * 
+ * <p>
+ * To modify the wrapper used, call {@link #setWrapperClass(Class)}.
+ * </p>
+ * 
+ * <p>
+ * Any request wrapper configured for instantiation by this class must provide
+ * a public constructor that accepts two arguments, being a
+ * <code>HttpServletRequest</code> and a <code>PortResolver</code>.
+ * </p>
  *
  * @author Orlando Garcia Carmona
+ * @author Ben Alex
  * @version $Id$
  */
 public class SecurityContextHolderAwareRequestFilter implements Filter {
+    //~ Instance fields ========================================================
+
+    private Class wrapperClass = SavedRequestAwareWrapper.class;
+    private Constructor constructor;
+    private PortResolver portResolver = new PortResolverImpl();
+
     //~ Methods ================================================================
 
     public void destroy() {}
@@ -43,12 +76,36 @@ public class SecurityContextHolderAwareRequestFilter implements Filter {
         throws IOException, ServletException {
         HttpServletRequest request = (HttpServletRequest) servletRequest;
 
-        if (!(request instanceof SecurityContextHolderAwareRequestWrapper)) {
-            request = new SecurityContextHolderAwareRequestWrapper(request);
+        if (!wrapperClass.isAssignableFrom(request.getClass())) {
+            if (constructor == null) {
+                try {
+                    constructor = wrapperClass.getConstructor(new Class[] {HttpServletRequest.class, PortResolver.class});
+                } catch (Exception ex) {
+                    ReflectionUtils.handleReflectionException(ex);
+                }
+            }
+
+            try {
+                request = (HttpServletRequest) constructor.newInstance(new Object[] {request, portResolver});
+            } catch (Exception ex) {
+                ReflectionUtils.handleReflectionException(ex);
+            }
         }
 
         filterChain.doFilter(request, servletResponse);
     }
 
     public void init(FilterConfig filterConfig) throws ServletException {}
+
+    public void setPortResolver(PortResolver portResolver) {
+        Assert.notNull(portResolver, "PortResolver required");
+        this.portResolver = portResolver;
+    }
+
+    public void setWrapperClass(Class wrapperClass) {
+        Assert.notNull(wrapperClass, "WrapperClass required");
+        Assert.isTrue(HttpServletRequest.class.isAssignableFrom(wrapperClass),
+            "Wrapper must be a HttpServletRequest");
+        this.wrapperClass = wrapperClass;
+    }
 }

+ 15 - 13
core/src/test/java/org/acegisecurity/intercept/web/FilterInvocationTests.java

@@ -1,4 +1,4 @@
-/* Copyright 2004, 2005 Acegi Technology Pty Limited
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,12 +17,14 @@ package org.acegisecurity.intercept.web;
 
 import org.acegisecurity.MockFilterChain;
 
-import javax.servlet.ServletRequest;
-import javax.servlet.ServletResponse;
+import org.jmock.MockObjectTestCase;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
-import org.jmock.MockObjectTestCase;
+
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+
 
 /**
  * Tests {@link FilterInvocation}.
@@ -44,14 +46,14 @@ public class FilterInvocationTests extends MockObjectTestCase {
 
     //~ Methods ================================================================
 
-    public final void setUp() throws Exception {
-        super.setUp();
-    }
-
     public static void main(String[] args) {
         junit.textui.TestRunner.run(FilterInvocationTests.class);
     }
 
+    public final void setUp() throws Exception {
+        super.setUp();
+    }
+
     public void testGettersAndStringMethods() {
         MockHttpServletRequest request = new MockHttpServletRequest(null, null);
         request.setServletPath("/HelloWorld");
@@ -73,7 +75,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
         assertEquals("/HelloWorld/some/more/segments.html", fi.getRequestUrl());
         assertEquals("FilterInvocation: URL: /HelloWorld/some/more/segments.html",
             fi.toString());
-        assertEquals("http://www.example.com:80/mycontext/HelloWorld/some/more/segments.html",
+        assertEquals("http://www.example.com/mycontext/HelloWorld/some/more/segments.html",
             fi.getFullRequestUrl());
     }
 
@@ -81,7 +83,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
         Class clazz = FilterInvocation.class;
 
         try {
-            clazz.getDeclaredConstructor((Class[])null);
+            clazz.getDeclaredConstructor((Class[]) null);
             fail("Should have thrown NoSuchMethodException");
         } catch (NoSuchMethodException expected) {
             assertTrue(true);
@@ -125,7 +127,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
     }
 
     public void testRejectsServletRequestWhichIsNotHttpServletRequest() {
-        ServletRequest request = (ServletRequest)newDummy(ServletRequest.class);
+        ServletRequest request = (ServletRequest) newDummy(ServletRequest.class);
         MockHttpServletResponse response = new MockHttpServletResponse();
         MockFilterChain chain = new MockFilterChain();
 
@@ -167,7 +169,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
         FilterInvocation fi = new FilterInvocation(request, response, chain);
         assertEquals("/HelloWorld?foo=bar", fi.getRequestUrl());
         assertEquals("FilterInvocation: URL: /HelloWorld?foo=bar", fi.toString());
-        assertEquals("http://www.example.com:80/mycontext/HelloWorld?foo=bar",
+        assertEquals("http://www.example.com/mycontext/HelloWorld?foo=bar",
             fi.getFullRequestUrl());
     }
 
@@ -185,7 +187,7 @@ public class FilterInvocationTests extends MockObjectTestCase {
         FilterInvocation fi = new FilterInvocation(request, response, chain);
         assertEquals("/HelloWorld", fi.getRequestUrl());
         assertEquals("FilterInvocation: URL: /HelloWorld", fi.toString());
-        assertEquals("http://www.example.com:80/mycontext/HelloWorld",
+        assertEquals("http://www.example.com/mycontext/HelloWorld",
             fi.getFullRequestUrl());
     }
 }

+ 19 - 5
core/src/test/java/org/acegisecurity/ui/AbstractProcessingFilterTests.java

@@ -30,6 +30,9 @@ import org.acegisecurity.context.SecurityContextHolder;
 import org.acegisecurity.providers.UsernamePasswordAuthenticationToken;
 
 import org.acegisecurity.ui.rememberme.TokenBasedRememberMeServices;
+import org.acegisecurity.ui.savedrequest.SavedRequest;
+
+import org.acegisecurity.util.PortResolverImpl;
 
 import org.springframework.mock.web.MockFilterConfig;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -91,6 +94,16 @@ public class AbstractProcessingFilterTests extends TestCase {
         junit.textui.TestRunner.run(AbstractProcessingFilterTests.class);
     }
 
+    private SavedRequest makeSavedRequestForUrl() {
+        MockHttpServletRequest request = createMockRequest();
+        request.setServletPath("/some_protected_file.html");
+        request.setScheme("http");
+        request.setServerName("www.example.com");
+        request.setRequestURI("/mycontext/some_protected_file.html");
+
+        return new SavedRequest(request, new PortResolverImpl());
+    }
+
     protected void setUp() throws Exception {
         super.setUp();
         SecurityContextHolder.clearContext();
@@ -399,8 +412,8 @@ public class AbstractProcessingFilterTests extends TestCase {
         // Setup our HTTP request
         MockHttpServletRequest request = createMockRequest();
         request.getSession()
-               .setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY,
-            "/my-destination");
+               .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY,
+            makeSavedRequestForUrl());
 
         // Setup our filter configuration
         MockFilterConfig config = new MockFilterConfig(null);
@@ -429,8 +442,8 @@ public class AbstractProcessingFilterTests extends TestCase {
         // Setup our HTTP request
         MockHttpServletRequest request = createMockRequest();
         request.getSession()
-               .setAttribute(AbstractProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY,
-            "/my-destination");
+               .setAttribute(AbstractProcessingFilter.ACEGI_SAVED_REQUEST_KEY,
+            makeSavedRequestForUrl());
 
         // Setup our filter configuration
         MockFilterConfig config = new MockFilterConfig(null);
@@ -446,7 +459,8 @@ public class AbstractProcessingFilterTests extends TestCase {
         // Test
         executeFilterInContainerSimulator(config, filter, request, response,
             chain);
-        assertEquals("/my-destination", response.getRedirectedUrl());
+        assertEquals(makeSavedRequestForUrl().getFullRequestUrl(),
+            response.getRedirectedUrl());
         assertNotNull(SecurityContextHolder.getContext().getAuthentication());
     }
 

+ 3 - 8
core/src/test/java/org/acegisecurity/ui/ExceptionTranslationFilterTests.java

@@ -28,8 +28,6 @@ import org.acegisecurity.context.SecurityContextHolder;
 
 import org.acegisecurity.providers.anonymous.AnonymousAuthenticationToken;
 
-import org.acegisecurity.ui.webapp.AuthenticationProcessingFilter;
-
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 
@@ -101,8 +99,7 @@ public class ExceptionTranslationFilterTests extends TestCase {
         filter.doFilter(request, response, chain);
         assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
         assertEquals("http://www.example.com/mycontext/secure/page.html",
-            request.getSession()
-                   .getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY));
+            AbstractProcessingFilter.obtainFullRequestUrl(request));
     }
 
     public void testAccessDeniedWhenNonAnonymous() throws Exception {
@@ -192,8 +189,7 @@ public class ExceptionTranslationFilterTests extends TestCase {
         filter.doFilter(request, response, chain);
         assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
         assertEquals("http://www.example.com/mycontext/secure/page.html",
-            request.getSession()
-                   .getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY));
+            AbstractProcessingFilter.obtainFullRequestUrl(request));
     }
 
     public void testRedirectedToLoginFormAndSessionShowsOriginalTargetWithExoticPortWhenAuthenticationException()
@@ -221,8 +217,7 @@ public class ExceptionTranslationFilterTests extends TestCase {
         filter.doFilter(request, response, chain);
         assertEquals("/mycontext/login.jsp", response.getRedirectedUrl());
         assertEquals("http://www.example.com:8080/mycontext/secure/page.html",
-            request.getSession()
-                   .getAttribute(AuthenticationProcessingFilter.ACEGI_SECURITY_TARGET_URL_KEY));
+            AbstractProcessingFilter.obtainFullRequestUrl(request));
     }
 
     public void testStartupDetectsMissingAuthenticationEntryPoint()

+ 9 - 9
core/src/test/java/org/acegisecurity/wrapper/SecurityContextHolderAwareRequestFilterTests.java

@@ -1,4 +1,4 @@
-/* Copyright 2004, 2005 Acegi Technology Pty Limited
+/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,8 @@ import junit.framework.TestCase;
 
 import org.acegisecurity.MockFilterConfig;
 
+import org.springframework.mock.web.MockHttpServletRequest;
+
 import java.io.IOException;
 
 import javax.servlet.FilterChain;
@@ -26,8 +28,6 @@ import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 
-import org.springframework.mock.web.MockHttpServletRequest;
-
 
 /**
  * Tests {@link SecurityContextHolderAwareRequestFilter}.
@@ -48,23 +48,23 @@ public class SecurityContextHolderAwareRequestFilterTests extends TestCase {
 
     //~ Methods ================================================================
 
-    public final void setUp() throws Exception {
-        super.setUp();
-    }
-
     public static void main(String[] args) {
         junit.textui.TestRunner.run(SecurityContextHolderAwareRequestFilterTests.class);
     }
 
+    public final void setUp() throws Exception {
+        super.setUp();
+    }
+
     public void testCorrectOperation() throws Exception {
         SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter();
         filter.init(new MockFilterConfig());
         filter.doFilter(new MockHttpServletRequest(null, null), null,
-            new MockFilterChain(SecurityContextHolderAwareRequestWrapper.class));
+            new MockFilterChain(SavedRequestAwareWrapper.class));
 
         // Now re-execute the filter, ensuring our replacement wrapper is still used
         filter.doFilter(new MockHttpServletRequest(null, null), null,
-            new MockFilterChain(SecurityContextHolderAwareRequestWrapper.class));
+            new MockFilterChain(SavedRequestAwareWrapper.class));
 
         filter.destroy();
     }

+ 3 - 1
samples/contacts/src/main/webapp/filter/WEB-INF/applicationContext-acegi-security.xml

@@ -21,7 +21,7 @@
          <value>
 		    CONVERT_URL_TO_LOWERCASE_BEFORE_COMPARISON
 		    PATTERN_TYPE_APACHE_ANT
-            /**=httpSessionContextIntegrationFilter,logoutFilter,authenticationProcessingFilter,basicProcessingFilter,rememberMeProcessingFilter,anonymousProcessingFilter,switchUserProcessingFilter,exceptionTranslationFilter,filterInvocationInterceptor
+            /**=httpSessionContextIntegrationFilter,logoutFilter,authenticationProcessingFilter,basicProcessingFilter,securityContextHolderAwareRequestFilter,rememberMeProcessingFilter,anonymousProcessingFilter,switchUserProcessingFilter,exceptionTranslationFilter,filterInvocationInterceptor
          </value>
       </property>
     </bean>
@@ -112,6 +112,8 @@
          </list>
       </constructor-arg>
    </bean>
+   
+   <bean id="securityContextHolderAwareRequestFilter" class="org.acegisecurity.wrapper.SecurityContextHolderAwareRequestFilter"/>
 
    <!-- ===================== HTTP CHANNEL REQUIREMENTS ==================== -->