Browse Source

SEC-1167: Extended SavedRequest interface to allow it to be used by wrapper. Removed null checks in wrapper, as the SavedRequest cannot now be null.

Luke Taylor 16 years ago
parent
commit
9c7423599e

+ 19 - 18
web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java

@@ -15,23 +15,24 @@
 
 package org.springframework.security.web.savedrequest;
 
-import org.springframework.security.web.PortResolver;
-import org.springframework.security.web.util.UrlUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.springframework.util.Assert;
-
-import javax.servlet.http.Cookie;
-import javax.servlet.http.HttpServletRequest;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Enumeration;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.TreeMap;
 
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletRequest;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.security.web.PortResolver;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.util.Assert;
+
 
 /**
  * Represents central information from a <code>HttpServletRequest</code>.<p>This class is used by {@link
@@ -237,22 +238,22 @@ public class DefaultSavedRequest implements SavedRequest {
                 pathInfo, queryString);
     }
 
-    public Iterator<String> getHeaderNames() {
-        return (headers.keySet().iterator());
+    public Collection<String> getHeaderNames() {
+        return headers.keySet();
     }
 
-    public Iterator<String> getHeaderValues(String name) {
+    public List<String> getHeaderValues(String name) {
         List<String> values = headers.get(name);
 
         if (values == null) {
-            values = Collections.emptyList();
+            return Collections.emptyList();
         }
 
-        return (values.iterator());
+        return values;
     }
 
-    public Iterator<Locale> getLocales() {
-        return (locales.iterator());
+    public List<Locale> getLocales() {
+        return locales;
     }
 
     public String getMethod() {
@@ -263,8 +264,8 @@ public class DefaultSavedRequest implements SavedRequest {
         return parameters;
     }
 
-    public Iterator<String> getParameterNames() {
-        return (parameters.keySet().iterator());
+    public Collection<String> getParameterNames() {
+        return parameters.keySet();
     }
 
     public String[] getParameterValues(String name) {

+ 24 - 2
web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java

@@ -1,8 +1,16 @@
 package org.springframework.security.web.savedrequest;
 
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import javax.servlet.http.Cookie;
+
 /**
- * Encapsulates the functionality required of a cached request, in order for an authentication mechanism (typically
- * form-based login) to redirect to the original URL.
+ * Encapsulates the functionality required of a cached request for both an authentication mechanism (typically
+ * form-based login) to redirect to the original URL and for a <tt>RequestCache</tt> to build a wrapped request,
+ * reproducing the original request data.
  *
  * @author Luke Taylor
  * @version $Id$
@@ -14,4 +22,18 @@ public interface SavedRequest extends java.io.Serializable {
      * @return the URL for the saved request, allowing a redirect to be performed.
      */
     String getRedirectUrl();
+
+    List<Cookie> getCookies();
+
+    String getMethod();
+
+    List<String> getHeaderValues(String name);
+
+    Collection<String> getHeaderNames();
+
+    List<Locale> getLocales();
+
+    String[] getParameterValues(String name);
+
+    Map<String,String[]> getParameterMap();
 }

+ 27 - 97
web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java

@@ -21,7 +21,6 @@ import java.util.Arrays;
 import java.util.Enumeration;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -45,9 +44,6 @@ import org.apache.commons.logging.LogFactory;
  *
  * <p>
  * Added into a request by {@link org.springframework.security.web.savedrequest.RequestCacheAwareFilter}.
- * </p>
- *
- * TODO: savedRequest cannot now be null, so convert the tests to reflect this and remove the null checks.
  *
  * @author Andrey Grebnev
  * @author Ben Alex
@@ -65,7 +61,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     //~ Instance fields ================================================================================================
 
-    protected DefaultSavedRequest savedRequest = null;
+    protected SavedRequest savedRequest = null;
 
     /**
      * The set of SimpleDateFormat formats to use in getDateHeader(). Notice that because SimpleDateFormat is
@@ -75,7 +71,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     //~ Constructors ===================================================================================================
 
-    public SavedRequestAwareWrapper(DefaultSavedRequest saved, HttpServletRequest request) {
+    public SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) {
         super(request);
         savedRequest = saved;
 
@@ -92,9 +88,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public Cookie[] getCookies() {
-        if (savedRequest == null) {
-            return super.getCookies();
-        }
         List<Cookie> cookies = savedRequest.getCookies();
 
         return cookies.toArray(new Cookie[cookies.size()]);
@@ -102,9 +95,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public long getDateHeader(String name) {
-        if (savedRequest == null) {
-            return super.getDateHeader(name);
-        }
         String value = getHeader(name);
 
         if (value == null) {
@@ -123,128 +113,79 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public String getHeader(String name) {
-        if (savedRequest == null) {
-            return super.getHeader(name);
-        }
-
-        String header = null;
-        Iterator<String> iterator = savedRequest.getHeaderValues(name);
+        List<String> values = savedRequest.getHeaderValues(name);
 
-        while (iterator.hasNext()) {
-            header = iterator.next();
-
-            break;
-        }
-
-        return header;
+        return values.isEmpty() ? null : values.get(0);
     }
 
     @Override
     @SuppressWarnings("unchecked")
     public Enumeration getHeaderNames() {
-        if (savedRequest == null) {
-            return super.getHeaderNames();
-        }
-
         return new Enumerator<String>(savedRequest.getHeaderNames());
     }
 
     @Override
     @SuppressWarnings("unchecked")
     public Enumeration getHeaders(String name) {
-        if (savedRequest == null) {
-            return super.getHeaders(name);
-        } else {
-            return new Enumerator<String>(savedRequest.getHeaderValues(name));
-        }
+        return new Enumerator<String>(savedRequest.getHeaderValues(name));
     }
 
     @Override
     public int getIntHeader(String name) {
-        if (savedRequest == null) {
-            return super.getIntHeader(name);
-        } else {
-            String value = getHeader(name);
+        String value = getHeader(name);
 
-            if (value == null) {
-                return -1;
-            } else {
-                return Integer.parseInt(value);
-            }
+        if (value == null) {
+            return -1;
+        } else {
+            return Integer.parseInt(value);
         }
     }
 
     @Override
     public Locale getLocale() {
-        if (savedRequest == null) {
-            return super.getLocale();
-        } else {
-            Locale locale = null;
-            Iterator<Locale> iterator = savedRequest.getLocales();
+        List<Locale> locales = savedRequest.getLocales();
 
-            while (iterator.hasNext()) {
-                locale = (Locale) iterator.next();
-
-                break;
-            }
-
-            if (locale == null) {
-                return defaultLocale;
-            } else {
-                return locale;
-            }
-        }
+        return locales.isEmpty() ? Locale.getDefault() : locales.get(0);
     }
 
     @Override
     @SuppressWarnings("unchecked")
     public Enumeration getLocales() {
-        if (savedRequest == null) {
-            return super.getLocales();
-        }
-
-        Iterator<Locale> iterator = savedRequest.getLocales();
+        List<Locale> locales = savedRequest.getLocales();
 
-        if (iterator.hasNext()) {
-            return new Enumerator<Locale>(iterator);
+        if (locales.isEmpty()) {
+            // Fall back to default locale
+            locales = new ArrayList<Locale>(1);
+            locales.add(Locale.getDefault());
         }
-        // Fall back to default locale
-        ArrayList<Locale> results = new ArrayList<Locale>(1);
-        results.add(defaultLocale);
 
-        return new Enumerator<Locale>(results.iterator());
+        return new Enumerator<Locale>(locales);
     }
 
     @Override
     public String getMethod() {
-        if (savedRequest == null) {
-            return super.getMethod();
-        } else {
-            return savedRequest.getMethod();
-        }
+        return savedRequest.getMethod();
     }
 
     /**
-     * If the parameter is available from the wrapped request then either
-     * <ol>
-     * <li>There is no saved request (it a normal request)</li>
-     * <li>There is a saved request, but the request has been forwarded/included to a URL with parameters, either
-     * supplementing or overriding the saved request values.</li>
-     * </ol>
-     * In both cases the value from the wrapped request should be used.
+     * If the parameter is available from the wrapped request then the request has been forwarded/included to a URL
+     * with parameters, either supplementing or overriding the saved request values.
+     * <p>
+     * In this case, the value from the wrapped request should be used.
      * <p>
      * If the value from the wrapped request is null, an attempt will be made to retrieve the parameter
-     * from the DefaultSavedRequest, if available..
+     * from the saved request.
      */
     @Override
     public String getParameter(String name) {
         String value = super.getParameter(name);
 
-        if (value != null || savedRequest == null) {
+        if (value != null) {
             return value;
         }
 
         String[] values = savedRequest.getParameterValues(name);
+
         if (values == null || values.length == 0) {
             return null;
         }
@@ -255,10 +196,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
     @Override
     @SuppressWarnings("unchecked")
     public Map getParameterMap() {
-        if (savedRequest == null) {
-            return super.getParameterMap();
-        }
-
         Set<String> names = getCombinedParameterNames();
         Map<String, String[]> parameterMap = new HashMap<String, String[]>(names.size());
 
@@ -273,10 +210,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
     private Set<String> getCombinedParameterNames() {
         Set<String> names = new HashSet<String>();
         names.addAll(super.getParameterMap().keySet());
-
-        if (savedRequest != null) {
-            names.addAll(savedRequest.getParameterMap().keySet());
-        }
+        names.addAll(savedRequest.getParameterMap().keySet());
 
         return names;
     }
@@ -289,10 +223,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public String[] getParameterValues(String name) {
-        if (savedRequest == null) {
-            return super.getParameterValues(name);
-        }
-
         String[] savedRequestParams = savedRequest.getParameterValues(name);
         String[] wrappedRequestParams = super.getParameterValues(name);
 

+ 2 - 2
web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java → web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java

@@ -10,14 +10,14 @@ import org.springframework.mock.web.MockHttpServletRequest;
 /**
  *
  */
-public class SavedRequestTests {
+public class DefaultSavedRequestTests {
 
     @Test
     public void headersAreCaseInsensitive() throws Exception {
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("USER-aGenT", "Mozilla");
         DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443));
-        assertEquals("Mozilla", saved.getHeaderValues("user-agent").next());
+        assertEquals("Mozilla", saved.getHeaderValues("user-agent").get(0));
     }
 
     // TODO: Why are parameters case insensitive. I think this is a mistake

+ 3 - 45
web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java

@@ -19,19 +19,10 @@ import org.springframework.security.web.savedrequest.SavedRequestAwareWrapper;
 public class SavedRequestAwareWrapperTests {
 
     private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) {
-        DefaultSavedRequest saved = requestToSave == null ? null : new DefaultSavedRequest(requestToSave, new PortResolverImpl());
+        DefaultSavedRequest saved = new DefaultSavedRequest(requestToSave, new PortResolverImpl());
         return new SavedRequestAwareWrapper(saved, requestToWrap);
     }
 
-    @Test
-    public void wrappedRequestCookiesAreReturnedIfNoSavedRequestIsSet() throws Exception {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        wrappedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromwrapped")});
-        SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
-        assertEquals(1, wrapper.getCookies().length);
-        assertEquals("fromwrapped", wrapper.getCookies()[0].getValue());
-    }
-
     @Test
     public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception {
         MockHttpServletRequest savedRequest = new MockHttpServletRequest();
@@ -61,27 +52,6 @@ public class SavedRequestAwareWrapperTests {
         assertEquals("header", wrapper.getHeaderNames().nextElement());
     }
 
-    @Test
-    @SuppressWarnings("unchecked")
-    public void wrappedRequestHeaderIsReturnedIfSavedRequestIsNotSet() throws Exception {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        wrappedRequest.addHeader("header", "wrappedheader");
-        SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
-
-        assertNull(wrapper.getHeader("nonexistent"));
-        Enumeration headers = wrapper.getHeaders("nonexistent");
-        assertFalse(headers.hasMoreElements());
-
-        assertEquals("wrappedheader", wrapper.getHeader("header"));
-        headers = wrapper.getHeaders("header");
-        assertTrue(headers.hasMoreElements());
-        assertEquals("wrappedheader", headers.nextElement());
-        assertFalse(headers.hasMoreElements());
-        assertTrue(wrapper.getHeaderNames().hasMoreElements());
-        assertEquals("header", wrapper.getHeaderNames().nextElement());
-    }
-
-
     @Test
     /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request)
      * and then RequestDispatcher.forward() it to /someUrl?action=bar.
@@ -125,8 +95,7 @@ public class SavedRequestAwareWrapperTests {
 
     @Test
     public void getParameterValuesReturnsNullIfParameterIsntSet() {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(null, wrappedRequest);
+        SavedRequestAwareWrapper wrapper = createWrapper(new MockHttpServletRequest(), new MockHttpServletRequest());
         assertNull(wrapper.getParameterValues("action"));
         assertNull(wrapper.getParameterMap().get("action"));
     }
@@ -148,7 +117,7 @@ public class SavedRequestAwareWrapperTests {
     }
 
     @Test
-    public void expecteDateHeaderIsReturnedFromSavedAndWrappedRequests() throws Exception {
+    public void expecteDateHeaderIsReturnedFromSavedRequest() throws Exception {
         SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
         String nowString = FastHttpDateFormat.getCurrentDate();
         Date now = formatter.parse(nowString);
@@ -158,12 +127,6 @@ public class SavedRequestAwareWrapperTests {
         assertEquals(now.getTime(), wrapper.getDateHeader("header"));
 
         assertEquals(-1L, wrapper.getDateHeader("nonexistent"));
-
-        // Now try with no saved request
-        request = new MockHttpServletRequest();
-        request.addHeader("header", now);
-        wrapper = createWrapper(null, request);
-        assertEquals(now.getTime(), wrapper.getDateHeader("header"));
     }
 
     @Test(expected=IllegalArgumentException.class)
@@ -179,8 +142,6 @@ public class SavedRequestAwareWrapperTests {
         MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused");
         SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused"));
         assertEquals("PUT", wrapper.getMethod());
-        wrapper = createWrapper(null, request);
-        assertEquals("PUT", wrapper.getMethod());
     }
 
     @Test
@@ -192,9 +153,6 @@ public class SavedRequestAwareWrapperTests {
 
         assertEquals(999, wrapper.getIntHeader("header"));
         assertEquals(-1, wrapper.getIntHeader("nonexistent"));
-
-        wrapper = createWrapper(null, request);
-        assertEquals(999, wrapper.getIntHeader("header"));
     }
 
 }