浏览代码

SEC-2272: CsrfRequestDataValueProcessor support Spring 4 and Spring 3

Rob Winch 12 年之前
父节点
当前提交
26166ef6e8

+ 1 - 1
config/src/main/java/org/springframework/security/config/annotation/web/configuration/CsrfWebMvcConfiguration.java

@@ -35,6 +35,6 @@ class CsrfWebMvcConfiguration {
 
     @Bean
     public RequestDataValueProcessor requestDataValueProcessor() {
-        return new CsrfRequestDataValueProcessor();
+        return CsrfRequestDataValueProcessor.create();
     }
 }

+ 1 - 0
config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java

@@ -49,6 +49,7 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
         boolean webmvcPresent = ClassUtils.isPresent(DISPATCHER_SERVLET_CLASS_NAME, getClass().getClassLoader());
         if(webmvcPresent) {
             RootBeanDefinition beanDefinition = new RootBeanDefinition(CsrfRequestDataValueProcessor.class);
+            beanDefinition.setFactoryMethodName("create");
             BeanComponentDefinition componentDefinition =
                     new BeanComponentDefinition(beanDefinition, REQUEST_DATA_VALUE_PROCESSOR);
             pc.registerBeanComponent(componentDefinition);

+ 2 - 1
config/src/test/groovy/org/springframework/security/config/annotation/web/configurers/CsrfConfigurerTests.groovy

@@ -28,6 +28,7 @@ import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor;
 import org.springframework.security.web.util.RequestMatcher;
+import org.springframework.web.servlet.support.RequestDataValueProcessor;
 
 import spock.lang.Unroll;
 
@@ -64,7 +65,7 @@ class CsrfConfigurerTests extends BaseSpringSpec {
         when:
             loadConfig(CsrfAppliedDefaultConfig)
         then:
-            context.getBean(CsrfRequestDataValueProcessor)
+            context.getBean(RequestDataValueProcessor)
     }
 
     @Configuration

+ 2 - 1
config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy

@@ -32,6 +32,7 @@ import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
 import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor
 import org.springframework.security.web.util.RequestMatcher
+import org.springframework.web.servlet.support.RequestDataValueProcessor;
 
 import spock.lang.Unroll
 
@@ -85,7 +86,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             }
             createAppContext()
         then:
-            appContext.getBean("requestDataValueProcessor",CsrfRequestDataValueProcessor)
+            appContext.getBean("requestDataValueProcessor",RequestDataValueProcessor)
     }
 
     def 'csrf custom AccessDeniedHandler'() {

+ 73 - 28
web/src/main/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessor.java

@@ -15,13 +15,18 @@
  */
 package org.springframework.security.web.servlet.support.csrf;
 
+import java.lang.reflect.InvocationHandler;
+import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.regex.Pattern;
 
 import javax.servlet.http.HttpServletRequest;
 
 import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.util.ReflectionUtils;
 import org.springframework.web.servlet.support.RequestDataValueProcessor;
 
 /**
@@ -31,38 +36,35 @@ import org.springframework.web.servlet.support.RequestDataValueProcessor;
  * @author Rob Winch
  * @since 3.2
  */
-public final class CsrfRequestDataValueProcessor implements
-        RequestDataValueProcessor {
+public final class CsrfRequestDataValueProcessor {
+    private Pattern DISABLE_CSRF_TOKEN_PATTERN = Pattern.compile("(?i)^(GET|HEAD|TRACE|OPTIONS)$");
+
+    private String DISABLE_CSRF_TOKEN_ATTR = "DISABLE_CSRF_TOKEN_ATTR";
 
-    /*
-     * (non-Javadoc)
-     *
-     * @see org.springframework.web.servlet.support.RequestDataValueProcessor#
-     * processAction(javax.servlet.http.HttpServletRequest, java.lang.String)
-     */
     public String processAction(HttpServletRequest request, String action) {
         return action;
     }
 
-    /*
-     * (non-Javadoc)
-     *
-     * @see org.springframework.web.servlet.support.RequestDataValueProcessor#
-     * processFormFieldValue(javax.servlet.http.HttpServletRequest,
-     * java.lang.String, java.lang.String, java.lang.String)
-     */
+    public String processAction(HttpServletRequest request, String action, String method) {
+        if(method != null && DISABLE_CSRF_TOKEN_PATTERN.matcher(method).matches()) {
+            request.setAttribute(DISABLE_CSRF_TOKEN_ATTR, Boolean.TRUE);
+        } else {
+            request.removeAttribute(DISABLE_CSRF_TOKEN_ATTR);
+        }
+        return action;
+    }
+
     public String processFormFieldValue(HttpServletRequest request,
             String name, String value, String type) {
         return value;
     }
 
-    /*
-     * (non-Javadoc)
-     *
-     * @see org.springframework.web.servlet.support.RequestDataValueProcessor#
-     * getExtraHiddenFields(javax.servlet.http.HttpServletRequest)
-     */
     public Map<String, String> getExtraHiddenFields(HttpServletRequest request) {
+        if(Boolean.TRUE.equals(request.getAttribute(DISABLE_CSRF_TOKEN_ATTR))) {
+            request.removeAttribute(DISABLE_CSRF_TOKEN_ATTR);
+            return Collections.emptyMap();
+        }
+
         CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class
                 .getName());
         if (token == null) {
@@ -73,14 +75,57 @@ public final class CsrfRequestDataValueProcessor implements
         return hiddenFields;
     }
 
-    /*
-     * (non-Javadoc)
-     *
-     * @see
-     * org.springframework.web.servlet.support.RequestDataValueProcessor#processUrl
-     * (javax.servlet.http.HttpServletRequest, java.lang.String)
-     */
     public String processUrl(HttpServletRequest request, String url) {
         return url;
     }
+
+    CsrfRequestDataValueProcessor() {}
+
+    /**
+     * Creates an instance of {@link CsrfRequestDataValueProcessor} that
+     * implements {@link RequestDataValueProcessor}. This is necessary to ensure
+     * compatibility between Spring 3 and Spring 4.
+     *
+     * @return an instance of {@link CsrfRequestDataValueProcessor} that
+     * implements {@link RequestDataValueProcessor}
+     */
+    public static RequestDataValueProcessor create() {
+        CsrfRequestDataValueProcessor target= new CsrfRequestDataValueProcessor();
+        ClassLoader classLoader = CsrfRequestDataValueProcessor.class.getClassLoader();
+        Class<?>[] interfaces = new Class[] { RequestDataValueProcessor.class};
+        TypeConversionInterceptor interceptor = new TypeConversionInterceptor(target);
+        return (RequestDataValueProcessor) Proxy.newProxyInstance(classLoader, interfaces, interceptor);
+    }
+
+    /**
+     * An {@link InvocationHandler} that assumes the target has all the method
+     * defined on it, but the target does not implement the interface. This is
+     * necessary to deal with the fact that Spring 3 and Spring 4 have different
+     * definitions for the {@link RequestDataValueProcessor} interface.
+     *
+     * @author Rob Winch
+     */
+    private static class TypeConversionInterceptor implements InvocationHandler {
+
+        private final Object target;
+
+        public TypeConversionInterceptor(Object target) {
+            this.target = target;
+        }
+
+        /* (non-Javadoc)
+         * @see java.lang.reflect.InvocationHandler#invoke(java.lang.Object, java.lang.reflect.Method, java.lang.Object[])
+         */
+        public Object invoke(Object proxy, Method method, Object[] args)
+                throws Throwable {
+            Method methodToInvoke = ReflectionUtils.findMethod(target.getClass(), method.getName(), method.getParameterTypes());
+            return methodToInvoke.invoke(target, args);
+        }
+
+        @Override
+        public String toString() {
+            return "RequestDataValueProcessorInterceptor [target=" + target
+                    + "]";
+        }
+    }
 }

+ 61 - 9
web/src/test/java/org/springframework/security/web/servlet/support/csrf/CsrfRequestDataValueProcessorTests.java

@@ -17,15 +17,17 @@ package org.springframework.security.web.servlet.support.csrf;
 
 import static org.fest.assertions.Assertions.assertThat;
 
+import java.lang.reflect.Method;
 import java.util.HashMap;
 import java.util.Map;
 
 import org.junit.Before;
 import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
-import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.util.ReflectionUtils;
+import org.springframework.web.servlet.support.RequestDataValueProcessor;
 
 /**
  * @author Rob Winch
@@ -34,29 +36,62 @@ import org.springframework.security.web.csrf.DefaultCsrfToken;
 public class CsrfRequestDataValueProcessorTests {
     private MockHttpServletRequest request;
 
-    private MockHttpServletResponse response;
-
     private CsrfRequestDataValueProcessor processor;
 
+    private CsrfToken token;
+    private Map<String,String> expected = new HashMap<String,String>();
+
     @Before
     public void setup() {
         request = new MockHttpServletRequest();
-        response = new MockHttpServletResponse();
         processor = new CsrfRequestDataValueProcessor();
+
+        token = new DefaultCsrfToken("1", "a", "b");
+        request.setAttribute(CsrfToken.class.getName(), token);
+
+        expected.put(token.getParameterName(),token.getToken());
+    }
+
+    @Test
+    public void assertAllMethodsDeclared() {
+        Method[] expectedMethods = ReflectionUtils.getAllDeclaredMethods(RequestDataValueProcessor.class);
+        for(Method expected : expectedMethods) {
+            assertThat(ReflectionUtils.findMethod(CsrfRequestDataValueProcessor.class, expected.getName(), expected.getParameterTypes())).as("Expected to find "+ expected+ " defined on "+CsrfRequestDataValueProcessor.class).isNotNull();
+        }
     }
 
     @Test
     public void getExtraHiddenFieldsNoCsrfToken() {
+        request = new MockHttpServletRequest();
         assertThat(processor.getExtraHiddenFields(request)).isEmpty();
     }
 
     @Test
-    public void getExtraHiddenFieldsHasCsrfToken() {
-        CsrfToken token = new DefaultCsrfToken("1", "a", "b");
-        request.setAttribute(CsrfToken.class.getName(), token);
-        Map<String,String> expected = new HashMap<String,String>();
-        expected.put(token.getParameterName(),token.getToken());
+    public void getExtraHiddenFieldsHasCsrfTokenNoMethodSet() {
+        assertThat(processor.getExtraHiddenFields(request)).isEqualTo(expected);
+    }
+
+    @Test
+    public void getExtraHiddenFieldsHasCsrfToken_GET() {
+        processor.processAction(request, "action", "GET");
+        assertThat(processor.getExtraHiddenFields(request)).isEmpty();
+    }
+
+    @Test
+    public void getExtraHiddenFieldsHasCsrfToken_get() {
+        processor.processAction(request, "action", "get");
+        assertThat(processor.getExtraHiddenFields(request)).isEmpty();
+    }
 
+    @Test
+    public void getExtraHiddenFieldsHasCsrfToken_POST() {
+        processor.processAction(request, "action", "POST");
+        assertThat(processor.getExtraHiddenFields(request)).isEqualTo(expected);
+    }
+
+    @Test
+    public void getExtraHiddenFieldsHasCsrfToken_post() {
+        processor.processAction(request, "action", "post");
         assertThat(processor.getExtraHiddenFields(request)).isEqualTo(expected);
     }
 
@@ -66,6 +101,12 @@ public class CsrfRequestDataValueProcessorTests {
         assertThat(processor.processAction(request, action)).isEqualTo(action);
     }
 
+    @Test
+    public void processActionWithMethodArg() {
+        String action = "action";
+        assertThat(processor.processAction(request, action, null)).isEqualTo(action);
+    }
+
     @Test
     public void processFormFieldValue() {
         String value = "action";
@@ -77,4 +118,15 @@ public class CsrfRequestDataValueProcessorTests {
         String url = "url";
         assertThat(processor.processUrl(request, url)).isEqualTo(url);
     }
+
+    @Test
+    public void createGetExtraHiddenFieldsHasCsrfToken() {
+        CsrfToken token = new DefaultCsrfToken("1", "a", "b");
+        request.setAttribute(CsrfToken.class.getName(), token);
+        Map<String,String> expected = new HashMap<String,String>();
+        expected.put(token.getParameterName(),token.getToken());
+
+        RequestDataValueProcessor processor = CsrfRequestDataValueProcessor.create();
+        assertThat(processor.getExtraHiddenFields(request)).isEqualTo(expected);
+    }
 }