소스 검색

SEC-1167: Extended SavedRequest interface to allow it to be used by wrapper. Removed null checks in wrapper, as the SavedRequest cannot now be null.

Luke Taylor 16 년 전
부모
커밋
9c7423599e

+ 19 - 18
web/src/main/java/org/springframework/security/web/savedrequest/DefaultSavedRequest.java

@@ -15,23 +15,24 @@
 
 package org.springframework.security.web.savedrequest;
 
-import org.springframework.security.web.PortResolver;
-import org.springframework.security.web.util.UrlUtils;
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.springframework.util.Assert;
-
-import javax.servlet.http.Cookie;
-import javax.servlet.http.HttpServletRequest;
 import java.util.ArrayList;
+import java.util.Collection;
 import java.util.Collections;
 import java.util.Enumeration;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.TreeMap;
 
+import javax.servlet.http.Cookie;
+import javax.servlet.http.HttpServletRequest;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.security.web.PortResolver;
+import org.springframework.security.web.util.UrlUtils;
+import org.springframework.util.Assert;
+
 
 /**
  * Represents central information from a <code>HttpServletRequest</code>.<p>This class is used by {@link
@@ -237,22 +238,22 @@ public class DefaultSavedRequest implements SavedRequest {
                 pathInfo, queryString);
     }
 
-    public Iterator<String> getHeaderNames() {
-        return (headers.keySet().iterator());
+    public Collection<String> getHeaderNames() {
+        return headers.keySet();
     }
 
-    public Iterator<String> getHeaderValues(String name) {
+    public List<String> getHeaderValues(String name) {
         List<String> values = headers.get(name);
 
         if (values == null) {
-            values = Collections.emptyList();
+            return Collections.emptyList();
         }
 
-        return (values.iterator());
+        return values;
     }
 
-    public Iterator<Locale> getLocales() {
-        return (locales.iterator());
+    public List<Locale> getLocales() {
+        return locales;
     }
 
     public String getMethod() {
@@ -263,8 +264,8 @@ public class DefaultSavedRequest implements SavedRequest {
         return parameters;
     }
 
-    public Iterator<String> getParameterNames() {
-        return (parameters.keySet().iterator());
+    public Collection<String> getParameterNames() {
+        return parameters.keySet();
     }
 
     public String[] getParameterValues(String name) {

+ 24 - 2
web/src/main/java/org/springframework/security/web/savedrequest/SavedRequest.java

@@ -1,8 +1,16 @@
 package org.springframework.security.web.savedrequest;
 
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import javax.servlet.http.Cookie;
+
 /**
- * Encapsulates the functionality required of a cached request, in order for an authentication mechanism (typically
- * form-based login) to redirect to the original URL.
+ * Encapsulates the functionality required of a cached request for both an authentication mechanism (typically
+ * form-based login) to redirect to the original URL and for a <tt>RequestCache</tt> to build a wrapped request,
+ * reproducing the original request data.
  *
  * @author Luke Taylor
  * @version $Id$
@@ -14,4 +22,18 @@ public interface SavedRequest extends java.io.Serializable {
      * @return the URL for the saved request, allowing a redirect to be performed.
      */
     String getRedirectUrl();
+
+    List<Cookie> getCookies();
+
+    String getMethod();
+
+    List<String> getHeaderValues(String name);
+
+    Collection<String> getHeaderNames();
+
+    List<Locale> getLocales();
+
+    String[] getParameterValues(String name);
+
+    Map<String,String[]> getParameterMap();
 }

+ 27 - 97
web/src/main/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapper.java

@@ -21,7 +21,6 @@ import java.util.Arrays;
 import java.util.Enumeration;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.Iterator;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -45,9 +44,6 @@ import org.apache.commons.logging.LogFactory;
  *
  * <p>
  * Added into a request by {@link org.springframework.security.web.savedrequest.RequestCacheAwareFilter}.
- * </p>
- *
- * TODO: savedRequest cannot now be null, so convert the tests to reflect this and remove the null checks.
  *
  * @author Andrey Grebnev
  * @author Ben Alex
@@ -65,7 +61,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     //~ Instance fields ================================================================================================
 
-    protected DefaultSavedRequest savedRequest = null;
+    protected SavedRequest savedRequest = null;
 
     /**
      * The set of SimpleDateFormat formats to use in getDateHeader(). Notice that because SimpleDateFormat is
@@ -75,7 +71,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     //~ Constructors ===================================================================================================
 
-    public SavedRequestAwareWrapper(DefaultSavedRequest saved, HttpServletRequest request) {
+    public SavedRequestAwareWrapper(SavedRequest saved, HttpServletRequest request) {
         super(request);
         savedRequest = saved;
 
@@ -92,9 +88,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public Cookie[] getCookies() {
-        if (savedRequest == null) {
-            return super.getCookies();
-        }
         List<Cookie> cookies = savedRequest.getCookies();
 
         return cookies.toArray(new Cookie[cookies.size()]);
@@ -102,9 +95,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public long getDateHeader(String name) {
-        if (savedRequest == null) {
-            return super.getDateHeader(name);
-        }
         String value = getHeader(name);
 
         if (value == null) {
@@ -123,128 +113,79 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public String getHeader(String name) {
-        if (savedRequest == null) {
-            return super.getHeader(name);
-        }
-
-        String header = null;
-        Iterator<String> iterator = savedRequest.getHeaderValues(name);
+        List<String> values = savedRequest.getHeaderValues(name);
 
-        while (iterator.hasNext()) {
-            header = iterator.next();
-
-            break;
-        }
-
-        return header;
+        return values.isEmpty() ? null : values.get(0);
     }
 
     @Override
     @SuppressWarnings("unchecked")
     public Enumeration getHeaderNames() {
-        if (savedRequest == null) {
-            return super.getHeaderNames();
-        }
-
         return new Enumerator<String>(savedRequest.getHeaderNames());
     }
 
     @Override
     @SuppressWarnings("unchecked")
     public Enumeration getHeaders(String name) {
-        if (savedRequest == null) {
-            return super.getHeaders(name);
-        } else {
-            return new Enumerator<String>(savedRequest.getHeaderValues(name));
-        }
+        return new Enumerator<String>(savedRequest.getHeaderValues(name));
     }
 
     @Override
     public int getIntHeader(String name) {
-        if (savedRequest == null) {
-            return super.getIntHeader(name);
-        } else {
-            String value = getHeader(name);
+        String value = getHeader(name);
 
-            if (value == null) {
-                return -1;
-            } else {
-                return Integer.parseInt(value);
-            }
+        if (value == null) {
+            return -1;
+        } else {
+            return Integer.parseInt(value);
         }
     }
 
     @Override
     public Locale getLocale() {
-        if (savedRequest == null) {
-            return super.getLocale();
-        } else {
-            Locale locale = null;
-            Iterator<Locale> iterator = savedRequest.getLocales();
+        List<Locale> locales = savedRequest.getLocales();
 
-            while (iterator.hasNext()) {
-                locale = (Locale) iterator.next();
-
-                break;
-            }
-
-            if (locale == null) {
-                return defaultLocale;
-            } else {
-                return locale;
-            }
-        }
+        return locales.isEmpty() ? Locale.getDefault() : locales.get(0);
     }
 
     @Override
     @SuppressWarnings("unchecked")
     public Enumeration getLocales() {
-        if (savedRequest == null) {
-            return super.getLocales();
-        }
-
-        Iterator<Locale> iterator = savedRequest.getLocales();
+        List<Locale> locales = savedRequest.getLocales();
 
-        if (iterator.hasNext()) {
-            return new Enumerator<Locale>(iterator);
+        if (locales.isEmpty()) {
+            // Fall back to default locale
+            locales = new ArrayList<Locale>(1);
+            locales.add(Locale.getDefault());
         }
-        // Fall back to default locale
-        ArrayList<Locale> results = new ArrayList<Locale>(1);
-        results.add(defaultLocale);
 
-        return new Enumerator<Locale>(results.iterator());
+        return new Enumerator<Locale>(locales);
     }
 
     @Override
     public String getMethod() {
-        if (savedRequest == null) {
-            return super.getMethod();
-        } else {
-            return savedRequest.getMethod();
-        }
+        return savedRequest.getMethod();
     }
 
     /**
-     * If the parameter is available from the wrapped request then either
-     * <ol>
-     * <li>There is no saved request (it a normal request)</li>
-     * <li>There is a saved request, but the request has been forwarded/included to a URL with parameters, either
-     * supplementing or overriding the saved request values.</li>
-     * </ol>
-     * In both cases the value from the wrapped request should be used.
+     * If the parameter is available from the wrapped request then the request has been forwarded/included to a URL
+     * with parameters, either supplementing or overriding the saved request values.
+     * <p>
+     * In this case, the value from the wrapped request should be used.
      * <p>
      * If the value from the wrapped request is null, an attempt will be made to retrieve the parameter
-     * from the DefaultSavedRequest, if available..
+     * from the saved request.
      */
     @Override
     public String getParameter(String name) {
         String value = super.getParameter(name);
 
-        if (value != null || savedRequest == null) {
+        if (value != null) {
             return value;
         }
 
         String[] values = savedRequest.getParameterValues(name);
+
         if (values == null || values.length == 0) {
             return null;
         }
@@ -255,10 +196,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
     @Override
     @SuppressWarnings("unchecked")
     public Map getParameterMap() {
-        if (savedRequest == null) {
-            return super.getParameterMap();
-        }
-
         Set<String> names = getCombinedParameterNames();
         Map<String, String[]> parameterMap = new HashMap<String, String[]>(names.size());
 
@@ -273,10 +210,7 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
     private Set<String> getCombinedParameterNames() {
         Set<String> names = new HashSet<String>();
         names.addAll(super.getParameterMap().keySet());
-
-        if (savedRequest != null) {
-            names.addAll(savedRequest.getParameterMap().keySet());
-        }
+        names.addAll(savedRequest.getParameterMap().keySet());
 
         return names;
     }
@@ -289,10 +223,6 @@ class SavedRequestAwareWrapper extends HttpServletRequestWrapper {
 
     @Override
     public String[] getParameterValues(String name) {
-        if (savedRequest == null) {
-            return super.getParameterValues(name);
-        }
-
         String[] savedRequestParams = savedRequest.getParameterValues(name);
         String[] wrappedRequestParams = super.getParameterValues(name);
 

+ 2 - 2
web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestTests.java → web/src/test/java/org/springframework/security/web/savedrequest/DefaultSavedRequestTests.java

@@ -10,14 +10,14 @@ import org.springframework.mock.web.MockHttpServletRequest;
 /**
  *
  */
-public class SavedRequestTests {
+public class DefaultSavedRequestTests {
 
     @Test
     public void headersAreCaseInsensitive() throws Exception {
         MockHttpServletRequest request = new MockHttpServletRequest();
         request.addHeader("USER-aGenT", "Mozilla");
         DefaultSavedRequest saved = new DefaultSavedRequest(request, new MockPortResolver(8080, 8443));
-        assertEquals("Mozilla", saved.getHeaderValues("user-agent").next());
+        assertEquals("Mozilla", saved.getHeaderValues("user-agent").get(0));
     }
 
     // TODO: Why are parameters case insensitive. I think this is a mistake

+ 3 - 45
web/src/test/java/org/springframework/security/web/savedrequest/SavedRequestAwareWrapperTests.java

@@ -19,19 +19,10 @@ import org.springframework.security.web.savedrequest.SavedRequestAwareWrapper;
 public class SavedRequestAwareWrapperTests {
 
     private SavedRequestAwareWrapper createWrapper(MockHttpServletRequest requestToSave, MockHttpServletRequest requestToWrap) {
-        DefaultSavedRequest saved = requestToSave == null ? null : new DefaultSavedRequest(requestToSave, new PortResolverImpl());
+        DefaultSavedRequest saved = new DefaultSavedRequest(requestToSave, new PortResolverImpl());
         return new SavedRequestAwareWrapper(saved, requestToWrap);
     }
 
-    @Test
-    public void wrappedRequestCookiesAreReturnedIfNoSavedRequestIsSet() throws Exception {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        wrappedRequest.setCookies(new Cookie[] {new Cookie("cookie", "fromwrapped")});
-        SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
-        assertEquals(1, wrapper.getCookies().length);
-        assertEquals("fromwrapped", wrapper.getCookies()[0].getValue());
-    }
-
     @Test
     public void savedRequestCookiesAreReturnedIfSavedRequestIsSet() throws Exception {
         MockHttpServletRequest savedRequest = new MockHttpServletRequest();
@@ -61,27 +52,6 @@ public class SavedRequestAwareWrapperTests {
         assertEquals("header", wrapper.getHeaderNames().nextElement());
     }
 
-    @Test
-    @SuppressWarnings("unchecked")
-    public void wrappedRequestHeaderIsReturnedIfSavedRequestIsNotSet() throws Exception {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        wrappedRequest.addHeader("header", "wrappedheader");
-        SavedRequestAwareWrapper wrapper = createWrapper(null, wrappedRequest);
-
-        assertNull(wrapper.getHeader("nonexistent"));
-        Enumeration headers = wrapper.getHeaders("nonexistent");
-        assertFalse(headers.hasMoreElements());
-
-        assertEquals("wrappedheader", wrapper.getHeader("header"));
-        headers = wrapper.getHeaders("header");
-        assertTrue(headers.hasMoreElements());
-        assertEquals("wrappedheader", headers.nextElement());
-        assertFalse(headers.hasMoreElements());
-        assertTrue(wrapper.getHeaderNames().hasMoreElements());
-        assertEquals("header", wrapper.getHeaderNames().nextElement());
-    }
-
-
     @Test
     /* SEC-830. Assume we have a request to /someUrl?action=foo (the saved request)
      * and then RequestDispatcher.forward() it to /someUrl?action=bar.
@@ -125,8 +95,7 @@ public class SavedRequestAwareWrapperTests {
 
     @Test
     public void getParameterValuesReturnsNullIfParameterIsntSet() {
-        MockHttpServletRequest wrappedRequest = new MockHttpServletRequest();
-        SavedRequestAwareWrapper wrapper = new SavedRequestAwareWrapper(null, wrappedRequest);
+        SavedRequestAwareWrapper wrapper = createWrapper(new MockHttpServletRequest(), new MockHttpServletRequest());
         assertNull(wrapper.getParameterValues("action"));
         assertNull(wrapper.getParameterMap().get("action"));
     }
@@ -148,7 +117,7 @@ public class SavedRequestAwareWrapperTests {
     }
 
     @Test
-    public void expecteDateHeaderIsReturnedFromSavedAndWrappedRequests() throws Exception {
+    public void expecteDateHeaderIsReturnedFromSavedRequest() throws Exception {
         SimpleDateFormat formatter = new SimpleDateFormat("EEE, dd MMM yyyy HH:mm:ss zzz", Locale.US);
         String nowString = FastHttpDateFormat.getCurrentDate();
         Date now = formatter.parse(nowString);
@@ -158,12 +127,6 @@ public class SavedRequestAwareWrapperTests {
         assertEquals(now.getTime(), wrapper.getDateHeader("header"));
 
         assertEquals(-1L, wrapper.getDateHeader("nonexistent"));
-
-        // Now try with no saved request
-        request = new MockHttpServletRequest();
-        request.addHeader("header", now);
-        wrapper = createWrapper(null, request);
-        assertEquals(now.getTime(), wrapper.getDateHeader("header"));
     }
 
     @Test(expected=IllegalArgumentException.class)
@@ -179,8 +142,6 @@ public class SavedRequestAwareWrapperTests {
         MockHttpServletRequest request = new MockHttpServletRequest("PUT", "/notused");
         SavedRequestAwareWrapper wrapper = createWrapper(request, new MockHttpServletRequest("GET", "/notused"));
         assertEquals("PUT", wrapper.getMethod());
-        wrapper = createWrapper(null, request);
-        assertEquals("PUT", wrapper.getMethod());
     }
 
     @Test
@@ -192,9 +153,6 @@ public class SavedRequestAwareWrapperTests {
 
         assertEquals(999, wrapper.getIntHeader("header"));
         assertEquals(-1, wrapper.getIntHeader("nonexistent"));
-
-        wrapper = createWrapper(null, request);
-        assertEquals(999, wrapper.getIntHeader("header"));
     }
 
 }