浏览代码

SEC-2246: HttpSessionRequestCache.getRequest casts to RequestCache

The method getRequest use to cast to DefaultRequestCache, but this
is not necessary.

Now the cast is to SavedRequest.
Rob Winch 12 年之前
父节点
当前提交
9133c33f1d

+ 1 - 1
web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java

@@ -49,7 +49,7 @@ public class HttpSessionRequestCache implements RequestCache {
         HttpSession session = currentRequest.getSession(false);
         HttpSession session = currentRequest.getSession(false);
 
 
         if (session != null) {
         if (session != null) {
-            return (DefaultSavedRequest) session.getAttribute(SAVED_REQUEST);
+            return (SavedRequest) session.getAttribute(SAVED_REQUEST);
         }
         }
 
 
         return null;
         return null;

+ 71 - 1
web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java

@@ -1,12 +1,22 @@
 package org.springframework.security.web.savedrequest;
 package org.springframework.security.web.savedrequest;
 
 
-import static org.junit.Assert.*;
+import static org.fest.assertions.Assertions.assertThat;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 
 
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import javax.servlet.http.Cookie;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 
 import org.junit.Test;
 import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.web.PortResolverImpl;
 import org.springframework.security.web.util.RequestMatcher;
 import org.springframework.security.web.util.RequestMatcher;
 
 
 /**
 /**
@@ -49,5 +59,65 @@ public class HttpSessionRequestCacheTests {
         assertNull(cache.getMatchingRequest(request, response));
         assertNull(cache.getMatchingRequest(request, response));
     }
     }
 
 
+    // SEC-2246
+    @Test
+    public void getRequestCustomNoClassCastException() throws Exception {
+        MockHttpServletRequest request = new MockHttpServletRequest("POST", "/destination");
+        MockHttpServletResponse response = new MockHttpServletResponse();
+        HttpSessionRequestCache cache = new HttpSessionRequestCache() {
+
+            @Override
+            public void saveRequest(HttpServletRequest request,
+                    HttpServletResponse response) {
+                request.getSession().setAttribute(SAVED_REQUEST, new CustomSavedRequest(new DefaultSavedRequest(request, new PortResolverImpl())));
+            }
+
+        };
+        cache.saveRequest(request,response);
+
+        cache.saveRequest(request, response);
+        assertThat(cache.getRequest(request, response)).isInstanceOf(CustomSavedRequest.class);
+    }
+
+    private static final class CustomSavedRequest implements SavedRequest {
+        private final SavedRequest delegate;
 
 
+        private CustomSavedRequest(SavedRequest delegate) {
+            this.delegate = delegate;
+        }
+
+        public String getRedirectUrl() {
+            return delegate.getRedirectUrl();
+        }
+
+        public List<Cookie> getCookies() {
+            return delegate.getCookies();
+        }
+
+        public String getMethod() {
+            return delegate.getMethod();
+        }
+
+        public List<String> getHeaderValues(String name) {
+            return delegate.getHeaderValues(name);
+        }
+
+        public Collection<String> getHeaderNames() {
+            return delegate.getHeaderNames();
+        }
+
+        public List<Locale> getLocales() {
+            return delegate.getLocales();
+        }
+
+        public String[] getParameterValues(String name) {
+            return delegate.getParameterValues(name);
+        }
+
+        public Map<String, String[]> getParameterMap() {
+            return delegate.getParameterMap();
+        }
+
+        private static final long serialVersionUID = 2426831999233621470L;
+    }
 }
 }