Przeglądaj źródła

SEC-2246: HttpSessionRequestCache.getRequest casts to RequestCache

The method getRequest use to cast to DefaultRequestCache, but this
is not necessary.

Now the cast is to SavedRequest.
Rob Winch 12 lat temu
rodzic
commit
9133c33f1d

+ 1 - 1
web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java

@@ -49,7 +49,7 @@ public class HttpSessionRequestCache implements RequestCache {
         HttpSession session = currentRequest.getSession(false);
 
         if (session != null) {
-            return (DefaultSavedRequest) session.getAttribute(SAVED_REQUEST);
+            return (SavedRequest) session.getAttribute(SAVED_REQUEST);
         }
 
         return null;

+ 71 - 1
web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java

@@ -1,12 +1,22 @@
 package org.springframework.security.web.savedrequest;
 
-import static org.junit.Assert.*;
+import static org.fest.assertions.Assertions.assertThat;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
 
+import java.util.Collection;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+
+import javax.servlet.http.Cookie;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.web.PortResolverImpl;
 import org.springframework.security.web.util.RequestMatcher;
 
 /**
@@ -49,5 +59,65 @@ public class HttpSessionRequestCacheTests {
         assertNull(cache.getMatchingRequest(request, response));
     }
 
+    // SEC-2246
+    @Test
+    public void getRequestCustomNoClassCastException() throws Exception {
+        MockHttpServletRequest request = new MockHttpServletRequest("POST", "/destination");
+        MockHttpServletResponse response = new MockHttpServletResponse();
+        HttpSessionRequestCache cache = new HttpSessionRequestCache() {
+
+            @Override
+            public void saveRequest(HttpServletRequest request,
+                    HttpServletResponse response) {
+                request.getSession().setAttribute(SAVED_REQUEST, new CustomSavedRequest(new DefaultSavedRequest(request, new PortResolverImpl())));
+            }
+
+        };
+        cache.saveRequest(request,response);
+
+        cache.saveRequest(request, response);
+        assertThat(cache.getRequest(request, response)).isInstanceOf(CustomSavedRequest.class);
+    }
+
+    private static final class CustomSavedRequest implements SavedRequest {
+        private final SavedRequest delegate;
 
+        private CustomSavedRequest(SavedRequest delegate) {
+            this.delegate = delegate;
+        }
+
+        public String getRedirectUrl() {
+            return delegate.getRedirectUrl();
+        }
+
+        public List<Cookie> getCookies() {
+            return delegate.getCookies();
+        }
+
+        public String getMethod() {
+            return delegate.getMethod();
+        }
+
+        public List<String> getHeaderValues(String name) {
+            return delegate.getHeaderValues(name);
+        }
+
+        public Collection<String> getHeaderNames() {
+            return delegate.getHeaderNames();
+        }
+
+        public List<Locale> getLocales() {
+            return delegate.getLocales();
+        }
+
+        public String[] getParameterValues(String name) {
+            return delegate.getParameterValues(name);
+        }
+
+        public Map<String, String[]> getParameterMap() {
+            return delegate.getParameterMap();
+        }
+
+        private static final long serialVersionUID = 2426831999233621470L;
+    }
 }