Explorar o código

SEC-1471: Allow use of a RequestMatcher with HttpSessionRequestCache to configure which requests should be cached by calls to saveRequest.

Also removed the justUseSavedRequestOnGet property, as this behaviour can be controlled by the RequestMatcher.
Luke Taylor %!s(int64=15) %!d(string=hai) anos
pai
achega
8df356de29

+ 15 - 8
web/src/main/java/org/springframework/security/web/savedrequest/HttpSessionRequestCache.java

@@ -9,9 +9,11 @@ import org.apache.commons.logging.LogFactory;
 import org.springframework.security.web.PortResolver;
 import org.springframework.security.web.PortResolverImpl;
 import org.springframework.security.web.WebAttributes;
+import org.springframework.security.web.util.AnyRequestMatcher;
+import org.springframework.security.web.util.RequestMatcher;
 
 /**
- * <tt>RequestCache</tt> which stores the <tt>SavedRequest</tt> in the HttpSession.
+ * {@code RequestCache} which stores the {@code SavedRequest} in the HttpSession.
  *
  * The {@link DefaultSavedRequest} class is used as the implementation.
  *
@@ -23,13 +25,13 @@ public class HttpSessionRequestCache implements RequestCache {
 
     private PortResolver portResolver = new PortResolverImpl();
     private boolean createSessionAllowed = true;
-    private boolean justUseSavedRequestOnGet;
+    private RequestMatcher requestMatcher = new AnyRequestMatcher();
 
     /**
      * Stores the current request, provided the configuration properties allow it.
      */
     public void saveRequest(HttpServletRequest request, HttpServletResponse response) {
-        if (!justUseSavedRequestOnGet || "GET".equals(request.getMethod())) {
+        if (requestMatcher.matches(request)) {
             DefaultSavedRequest savedRequest = new DefaultSavedRequest(request, portResolver);
 
             if (createSessionAllowed || request.getSession(false) != null) {
@@ -38,8 +40,9 @@ public class HttpSessionRequestCache implements RequestCache {
                 request.getSession().setAttribute(WebAttributes.SAVED_REQUEST, savedRequest);
                 logger.debug("DefaultSavedRequest added to Session: " + savedRequest);
             }
+        } else {
+            logger.debug("Request not saved as configured RequestMatcher did not match");
         }
-
     }
 
     public SavedRequest getRequest(HttpServletRequest currentRequest, HttpServletResponse response) {
@@ -79,11 +82,15 @@ public class HttpSessionRequestCache implements RequestCache {
     }
 
     /**
-     * If <code>true</code>, will only use <code>DefaultSavedRequest</code> to determine the target URL on successful
-     * authentication if the request that caused the authentication request was a GET. Defaults to false.
+     * Allows selective use of saved requests for a subset of requests. By default any request will be cached
+     * by the {@code saveRequest} method.
+     * <p>
+     * If set, only matching requests will be cached.
+     *
+     * @param requestMatcher a request matching strategy which defines which requests should be cached.
      */
-    public void setJustUseSavedRequestOnGet(boolean justUseSavedRequestOnGet) {
-        this.justUseSavedRequestOnGet = justUseSavedRequestOnGet;
+    public void setRequestMatcher(RequestMatcher requestMatcher) {
+        this.requestMatcher = requestMatcher;
     }
 
     /**

+ 0 - 16
web/src/test/java/org/springframework/security/web/access/ExceptionTranslationFilterTests.java

@@ -185,22 +185,6 @@ public class ExceptionTranslationFilterTests {
         assertEquals("http://www.example.com:8080/mycontext/secure/page.html", getSavedRequestUrl(request));
     }
 
-    @Test
-    public void testSavedRequestIsNotStoredForPostIfJustUseSaveRequestOnGetIsSet() throws Exception {
-        ExceptionTranslationFilter filter = new ExceptionTranslationFilter();
-        HttpSessionRequestCache requestCache = new HttpSessionRequestCache();
-        requestCache.setPortResolver(new MockPortResolver(8080, 8443));
-        requestCache.setJustUseSavedRequestOnGet(true);
-        filter.setRequestCache(requestCache);
-        filter.setAuthenticationEntryPoint(mockEntryPoint());
-        MockHttpServletRequest request = new MockHttpServletRequest();
-        FilterChain fc = mock(FilterChain.class);
-        doThrow(new BadCredentialsException("")).when(fc).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
-        request.setMethod("POST");
-        filter.doFilter(request, new MockHttpServletResponse(), fc);
-        assertTrue(request.getSession().getAttribute(WebAttributes.SAVED_REQUEST) == null);
-    }
-
     @Test(expected=IllegalArgumentException.class)
     public void testStartupDetectsMissingAuthenticationEntryPoint() throws Exception {
         ExceptionTranslationFilter filter = new ExceptionTranslationFilter();

+ 21 - 0
web/src/test/java/org/springframework/security/web/savedrequest/HttpSessionRequestCacheTests.java

@@ -2,10 +2,13 @@ package org.springframework.security.web.savedrequest;
 
 import static org.junit.Assert.*;
 
+import javax.servlet.http.HttpServletRequest;
+
 import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.web.WebAttributes;
+import org.springframework.security.web.util.RequestMatcher;
 
 /**
  *
@@ -30,4 +33,22 @@ public class HttpSessionRequestCacheTests {
 
     }
 
+    @Test
+    public void requestMatcherDefinesCorrectSubsetOfCachedRequests() throws Exception {
+        HttpSessionRequestCache cache = new HttpSessionRequestCache();
+        cache.setRequestMatcher(new RequestMatcher() {
+            public boolean matches(HttpServletRequest request) {
+                return request.getMethod().equals("GET");
+            }
+        });
+
+        MockHttpServletRequest request = new MockHttpServletRequest("POST", "/destination");
+        MockHttpServletResponse response = new MockHttpServletResponse();
+        cache.saveRequest(request, response);
+        assertNull(cache.getRequest(request, response));
+        assertNull(cache.getRequest(new MockHttpServletRequest(), new MockHttpServletResponse()));
+        assertNull(cache.getMatchingRequest(request, response));
+    }
+
+
 }