2
0
Эх сурвалжийг харах

SEC-1167: Extended SavedRequest interface to allow it to be used by wrapper. Removed null checks in wrapper, as the SavedRequest cannot now be null.

Luke Taylor 16 жил өмнө
parent
commit
9c7423599e

+ 19 - 18
web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java

@@ -15,23 +15,24 @@
 
 
 package org.springframework.security.web.savedrequest;
 package org.springframework.security.web.savedrequest;
 
 
-import org.springframework.security.web.PortResolver;
-import org.springframework.security.web.util.UrlUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.springframework.util.Assert;
-
-import javax.servlet.http.Cookie;
-import javax.servlet.http.HttpServletRequest;
 import java.util.ArrayList;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Collections;
 import java.util.Enumeration;
 import java.util.Enumeration;
-import java.util.Iterator;
 import java.util.List;
 import java.util.List;
 import java.util.Locale;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Map;
 import java.util.TreeMap;
 import java.util.TreeMap;
 
 
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletRequest;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.security.web.PortResolver;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.util.Assert;
+
 
 
 /**
 /**
  * Represents central information from a <code>HttpServletRequest</code>.<p>This class is used by {@link
  * Represents central information from a <code>HttpServletRequest</code>.<p>This class is used by {@link
@@ -237,22 +238,22 @@ public class DefaultSavedRequest implements SavedRequest {
                 pathInfo, queryString);
                 pathInfo, queryString);
     }
     }
 
 
-    public Iterator<String> getHeaderNames() {
-        return (headers.keySet().iterator());
+    public Collection<String> getHeaderNames() {
+        return headers.keySet();
     }
     }
 
 
-    public Iterator<String> getHeaderValues(String name) {
+    public List<String> getHeaderValues(String name) {
         List<String> values = headers.get(name);
         List<String> values = headers.get(name);
 
 
         if (values == null) {
         if (values == null) {
-            values = Collections.emptyList();
+            return Collections.emptyList();
         }
         }
 
 
-        return (values.iterator());
+        return values;
     }
     }
 
 
-    public Iterator<Locale> getLocales() {
-        return (locales.iterator());
+    public List<Locale> getLocales() {
+        return locales;
     }
     }
 
 
     public String getMethod() {
     public String getMethod() {
@@ -263,8 +264,8 @@ public class DefaultSavedRequest implements SavedRequest {
         return parameters;
         return parameters;
     }
     }
 
 
-    public Iterator<String> getParameterNames() {
-        return (parameters.keySet().iterator());
+    public Collection<String> getParameterNames() {
+        return parameters.keySet();
     }
     }
 
 
     public String[] getParameterValues(String name) {
     public String[] getParameterValues(String name) {

+ 24 - 2
web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java

@@ -1,8 +1,16 @@
 package org.springframework.security.web.savedrequest;
 package org.springframework.security.web.savedrequest;
 
 
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import javax.servlet.http.Cookie;
+
 /**
 /**
- * Encapsulates the functionality required of a cached request, in order for an authentication mechanism (typically
- * form-based login) to redirect to the original URL.
+ * Encapsulates the functionality required of a cached request for both an authentication mechanism (typically
+ * form-based login) to redirect to the original URL and for a <tt>RequestCache</tt> to build a wrapped request,
+ * reproducing the original request data.
  *
  *
  * @author Luke Taylor
  * @author Luke Taylor
  * @version $Id$
  * @version $Id$
@@ -14,4 +22,18 @@ public interface SavedRequest extends java.io.Serializable {
      * @return the URL for the saved request, allowing a redirect to be performed.
      * @return the URL for the saved request, allowing a redirect to be performed.
      */
      */
     String getRedirectUrl();
     String getRedirectUrl();
+
+    List<Cookie> getCookies();
+
+    String getMethod();
+
+    List<String> getHeaderValues(String name);
+
+    Collection<String> getHeaderNames();
+
+    List<Locale> getLocales();
+
+    String[] getParameterValues(String name);
+
+    Map<String,String[]> getParameterMap();
 }
 }

+ 27 - 97
web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java

@@ -21,7 +21,6 @@ import java.util.Arrays;
 import java.util.Enumeration;
 import java.util.Enumeration;
 import java.util.HashMap;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.List;
 import java.util.Locale;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Map;
@@ -45,9 +44,6 @@ import org.apache.commons.logging.LogFactory;
  *
  *
  * <p>
  * <p>
  * Added into a request by {@link org.springframework.security.web.savedrequest.RequestCacheAwareFilter}.
  * Added into a request by {@link org.springframework.security.web.savedrequest.RequestCacheAwareFilter}.
- * </p>
- *
- * TODO: savedRequest cannot now be null, so convert the tests to reflect this and remove the null checks.
  *
  *
  * @author Andrey Grebnev
  * @author Andrey Grebnev
  * @author Ben Alex
  * @author Ben Alex
@@ -65,7 +61,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
 
     //~ Instance fields ================================================================================================
     //~ Instance fields ================================================================================================
 
 
-    protected DefaultSavedRequest savedRequest = null;
+    protected SavedRequest savedRequest = null;
 
 
     /**
     /**
      * The set of SimpleDateFormat formats to use in getDateHeader(). Notice that because SimpleDateFormat is
      * The set of SimpleDateFormat formats to use in getDateHeader(). Notice that because SimpleDateFormat is
@@ -75,7 +71,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
 
     //~ Constructors ===================================================================================================
     //~ Constructors ===================================================================================================
 
 
-    public SavedRequestAwareWrapper(DefaultSavedRequest saved, HttpServletRequest request) {
+    public SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) {
         super(request);
         super(request);
         savedRequest = saved;
         savedRequest = saved;
 
 
@@ -92,9 +88,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
 
     @Override
     @Override
     public Cookie[] getCookies() {
     public Cookie[] getCookies() {
-        if (savedRequest == null) {
-            return super.getCookies();
-        }
         List<Cookie> cookies = savedRequest.getCookies();
         List<Cookie> cookies = savedRequest.getCookies();
 
 
         return cookies.toArray(new Cookie[cookies.size()]);
         return cookies.toArray(new Cookie[cookies.size()]);
@@ -102,9 +95,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
 
     @Override
     @Override
     public long getDateHeader(String name) {
     public long getDateHeader(String name) {
-        if (savedRequest == null) {
-            return super.getDateHeader(name);
-        }
         String value = getHeader(name);
         String value = getHeader(name);
 
 
         if (value == null) {
         if (value == null) {
@@ -123,128 +113,79 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
 
     @Override
     @Override
     public String getHeader(String name) {
     public String getHeader(String name) {
-        if (savedRequest == null) {
-            return super.getHeader(name);
-        }
-
-        String header = null;
-        Iterator<String> iterator = savedRequest.getHeaderValues(name);
+        List<String> values = savedRequest.getHeaderValues(name);
 
 
-        while (iterator.hasNext()) {
-            header = iterator.next();
-
-            break;
-        }
-
-        return header;
+        return values.isEmpty() ? null : values.get(0);
     }
     }
 
 
     @Override
     @Override
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public Enumeration getHeaderNames() {
     public Enumeration getHeaderNames() {
-        if (savedRequest == null) {
-            return super.getHeaderNames();
-        }
-
         return new Enumerator<String>(savedRequest.getHeaderNames());
         return new Enumerator<String>(savedRequest.getHeaderNames());
     }
     }
 
 
     @Override
     @Override
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public Enumeration getHeaders(String name) {
     public Enumeration getHeaders(String name) {
-        if (savedRequest == null) {
-            return super.getHeaders(name);
-        } else {
-            return new Enumerator<String>(savedRequest.getHeaderValues(name));
-        }
+        return new Enumerator<String>(savedRequest.getHeaderValues(name));
     }
     }
 
 
     @Override
     @Override
     public int getIntHeader(String name) {
     public int getIntHeader(String name) {
-        if (savedRequest == null) {
-            return super.getIntHeader(name);
-        } else {
-            String value = getHeader(name);
+        String value = getHeader(name);
 
 
-            if (value == null) {
-                return -1;
-            } else {
-                return Integer.parseInt(value);
-            }
+        if (value == null) {
+            return -1;
+        } else {
+            return Integer.parseInt(value);
         }
         }
     }
     }
 
 
     @Override
     @Override
     public Locale getLocale() {
     public Locale getLocale() {
-        if (savedRequest == null) {
-            return super.getLocale();
-        } else {
-            Locale locale = null;
-            Iterator<Locale> iterator = savedRequest.getLocales();
+        List<Locale> locales = savedRequest.getLocales();
 
 
-            while (iterator.hasNext()) {
-                locale = (Locale) iterator.next();
-
-                break;
-            }
-
-            if (locale == null) {
-                return defaultLocale;
-            } else {
-                return locale;
-            }
-        }
+        return locales.isEmpty() ? Locale.getDefault() : locales.get(0);
     }
     }
 
 
     @Override
     @Override
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public Enumeration getLocales() {
     public Enumeration getLocales() {
-        if (savedRequest == null) {
-            return super.getLocales();
-        }
-
-        Iterator<Locale> iterator = savedRequest.getLocales();
+        List<Locale> locales = savedRequest.getLocales();
 
 
-        if (iterator.hasNext()) {
-            return new Enumerator<Locale>(iterator);
+        if (locales.isEmpty()) {
+            // Fall back to default locale
+            locales = new ArrayList<Locale>(1);
+            locales.add(Locale.getDefault());
         }
         }
-        // Fall back to default locale
-        ArrayList<Locale> results = new ArrayList<Locale>(1);
-        results.add(defaultLocale);
 
 
-        return new Enumerator<Locale>(results.iterator());
+        return new Enumerator<Locale>(locales);
     }
     }
 
 
     @Override
     @Override
     public String getMethod() {
     public String getMethod() {
-        if (savedRequest == null) {
-            return super.getMethod();
-        } else {
-            return savedRequest.getMethod();
-        }
+        return savedRequest.getMethod();
     }
     }
 
 
     /**
     /**
-     * If the parameter is available from the wrapped request then either
-     * <ol>
-     * <li>There is no saved request (it a normal request)</li>
-     * <li>There is a saved request, but the request has been forwarded/included to a URL with parameters, either
-     * supplementing or overriding the saved request values.</li>
-     * </ol>
-     * In both cases the value from the wrapped request should be used.
+     * If the parameter is available from the wrapped request then the request has been forwarded/included to a URL
+     * with parameters, either supplementing or overriding the saved request values.
+     * <p>
+     * In this case, the value from the wrapped request should be used.
      * <p>
      * <p>
      * If the value from the wrapped request is null, an attempt will be made to retrieve the parameter
      * If the value from the wrapped request is null, an attempt will be made to retrieve the parameter
-     * from the DefaultSavedRequest, if available..
+     * from the saved request.
      */
      */
     @Override
     @Override
     public String getParameter(String name) {
     public String getParameter(String name) {
         String value = super.getParameter(name);
         String value = super.getParameter(name);
 
 
-        if (value != null || savedRequest == null) {
+        if (value != null) {
             return value;
             return value;
         }
         }
 
 
         String[] values = savedRequest.getParameterValues(name);
         String[] values = savedRequest.getParameterValues(name);
+
         if (values == null || values.length == 0) {
         if (values == null || values.length == 0) {
             return null;
             return null;
         }
         }
@@ -255,10 +196,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
     @Override
     @Override
     @SuppressWarnings("unchecked")
     @SuppressWarnings("unchecked")
     public Map getParameterMap() {
     public Map getParameterMap() {
-        if (savedRequest == null) {
-            return super.getParameterMap();
-        }
-
         Set<String> names = getCombinedParameterNames();
         Set<String> names = getCombinedParameterNames();
         Map<String, String[]> parameterMap = new HashMap<String, String[]>(names.size());
         Map<String, String[]> parameterMap = new HashMap<String, String[]>(names.size());
 
 
@@ -273,10 +210,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
     private Set<String> getCombinedParameterNames() {
     private Set<String> getCombinedParameterNames() {
         Set<String> names = new HashSet<String>();
         Set<String> names = new HashSet<String>();
         names.addAll(super.getParameterMap().keySet());
         names.addAll(super.getParameterMap().keySet());
-
-        if (savedRequest != null) {
-            names.addAll(savedRequest.getParameterMap().keySet());
-        }
+        names.addAll(savedRequest.getParameterMap().keySet());
 
 
         return names;
         return names;
     }
     }
@@ -289,10 +223,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
 
     @Override
     @Override
     public String[] getParameterValues(String name) {
     public String[] getParameterValues(String name) {
-        if (savedRequest == null) {
-            return super.getParameterValues(name);
-        }
-
         String[] savedRequestParams = savedRequest.getParameterValues(name);
         String[] savedRequestParams = savedRequest.getParameterValues(name);
         String[] wrappedRequestParams = super.getParameterValues(name);
         String[] wrappedRequestParams = super.getParameterValues(name);
 
 

+ 2 - 2
web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java → web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java

@@ -10,14 +10,14 @@ import org.springframework.mock.web.MockHttpServletRequest;
 /**
 /**
  *
  *
  */
  */
-public class SavedRequestTests {
+public class DefaultSavedRequestTests {
 
 
     @Test
     @Test
     public void headersAreCaseInsensitive() throws Exception {
     public void headersAreCaseInsensitive() throws Exception {
         MockHttpServletRequest request = new MockHttpServletRequest();
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("USER-aGenT", "Mozilla");
         request.addHeader("USER-aGenT", "Mozilla");
         DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443));
         DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443));
-        assertEquals("Mozilla", saved.getHeaderValues("user-agent").next());
+        assertEquals("Mozilla", saved.getHeaderValues("user-agent").get(0));
     }
     }
 
 
     // TODO: Why are parameters case insensitive. I think this is a mistake
     // TODO: Why are parameters case insensitive. I think this is a mistake

+ 3 - 45
web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java

@@ -19,19 +19,10 @@ import org.springframework.security.web.savedrequest.SavedRequestAwareWrapper;
 public class SavedRequestAwareWrapperTests {
 public class SavedRequestAwareWrapperTests {
 
 
     private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) {
     private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) {
-        DefaultSavedRequest saved = requestToSave == null ? null : new DefaultSavedRequest(requestToSave, new PortResolverImpl());
+        DefaultSavedRequest saved = new DefaultSavedRequest(requestToSave, new PortResolverImpl());
         return new SavedRequestAwareWrapper(saved, requestToWrap);
         return new SavedRequestAwareWrapper(saved, requestToWrap);
     }
     }
 
 
-    @Test
-    public void wrappedRequestCookiesAreReturnedIfNoSavedRequestIsSet() throws Exception {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        wrappedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromwrapped")});
-        SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
-        assertEquals(1, wrapper.getCookies().length);
-        assertEquals("fromwrapped", wrapper.getCookies()[0].getValue());
-    }
-
     @Test
     @Test
     public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception {
     public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception {
         MockHttpServletRequest savedRequest = new MockHttpServletRequest();
         MockHttpServletRequest savedRequest = new MockHttpServletRequest();
@@ -61,27 +52,6 @@ public class SavedRequestAwareWrapperTests {
         assertEquals("header", wrapper.getHeaderNames().nextElement());
         assertEquals("header", wrapper.getHeaderNames().nextElement());
     }
     }
 
 
-    @Test
-    @SuppressWarnings("unchecked")
-    public void wrappedRequestHeaderIsReturnedIfSavedRequestIsNotSet() throws Exception {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        wrappedRequest.addHeader("header", "wrappedheader");
-        SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
-
-        assertNull(wrapper.getHeader("nonexistent"));
-        Enumeration headers = wrapper.getHeaders("nonexistent");
-        assertFalse(headers.hasMoreElements());
-
-        assertEquals("wrappedheader", wrapper.getHeader("header"));
-        headers = wrapper.getHeaders("header");
-        assertTrue(headers.hasMoreElements());
-        assertEquals("wrappedheader", headers.nextElement());
-        assertFalse(headers.hasMoreElements());
-        assertTrue(wrapper.getHeaderNames().hasMoreElements());
-        assertEquals("header", wrapper.getHeaderNames().nextElement());
-    }
-
-
     @Test
     @Test
     /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request)
     /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request)
      * and then RequestDispatcher.forward() it to /someUrl?action=bar.
      * and then RequestDispatcher.forward() it to /someUrl?action=bar.
@@ -125,8 +95,7 @@ public class SavedRequestAwareWrapperTests {
 
 
     @Test
     @Test
     public void getParameterValuesReturnsNullIfParameterIsntSet() {
     public void getParameterValuesReturnsNullIfParameterIsntSet() {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(null, wrappedRequest);
+        SavedRequestAwareWrapper wrapper = createWrapper(new MockHttpServletRequest(), new MockHttpServletRequest());
         assertNull(wrapper.getParameterValues("action"));
         assertNull(wrapper.getParameterValues("action"));
         assertNull(wrapper.getParameterMap().get("action"));
         assertNull(wrapper.getParameterMap().get("action"));
     }
     }
@@ -148,7 +117,7 @@ public class SavedRequestAwareWrapperTests {
     }
     }
 
 
     @Test
     @Test
-    public void expecteDateHeaderIsReturnedFromSavedAndWrappedRequests() throws Exception {
+    public void expecteDateHeaderIsReturnedFromSavedRequest() throws Exception {
         SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
         SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
         String nowString = FastHttpDateFormat.getCurrentDate();
         String nowString = FastHttpDateFormat.getCurrentDate();
         Date now = formatter.parse(nowString);
         Date now = formatter.parse(nowString);
@@ -158,12 +127,6 @@ public class SavedRequestAwareWrapperTests {
         assertEquals(now.getTime(), wrapper.getDateHeader("header"));
         assertEquals(now.getTime(), wrapper.getDateHeader("header"));
 
 
         assertEquals(-1L, wrapper.getDateHeader("nonexistent"));
         assertEquals(-1L, wrapper.getDateHeader("nonexistent"));
-
-        // Now try with no saved request
-        request = new MockHttpServletRequest();
-        request.addHeader("header", now);
-        wrapper = createWrapper(null, request);
-        assertEquals(now.getTime(), wrapper.getDateHeader("header"));
     }
     }
 
 
     @Test(expected=IllegalArgumentException.class)
     @Test(expected=IllegalArgumentException.class)
@@ -179,8 +142,6 @@ public class SavedRequestAwareWrapperTests {
         MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused");
         MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused");
         SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused"));
         SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused"));
         assertEquals("PUT", wrapper.getMethod());
         assertEquals("PUT", wrapper.getMethod());
-        wrapper = createWrapper(null, request);
-        assertEquals("PUT", wrapper.getMethod());
     }
     }
 
 
     @Test
     @Test
@@ -192,9 +153,6 @@ public class SavedRequestAwareWrapperTests {
 
 
         assertEquals(999, wrapper.getIntHeader("header"));
         assertEquals(999, wrapper.getIntHeader("header"));
         assertEquals(-1, wrapper.getIntHeader("nonexistent"));
         assertEquals(-1, wrapper.getIntHeader("nonexistent"));
-
-        wrapper = createWrapper(null, request);
-        assertEquals(999, wrapper.getIntHeader("header"));
     }
     }
 
 
 }
 }