Browse Source

SEC-2232: HeaderFactory to HeaderWriter

Rob Winch 12 years ago
parent
commit
8acd205486

+ 5 - 5
config/src/main/java/org/springframework/security/config/http/HeadersBeanDefinitionParser.java

@@ -22,7 +22,7 @@ import org.springframework.beans.factory.support.ManagedList;
 import org.springframework.beans.factory.xml.BeanDefinitionParser;
 import org.springframework.beans.factory.xml.ParserContext;
 import org.springframework.security.web.headers.HeadersFilter;
-import org.springframework.security.web.headers.StaticHeaderFactory;
+import org.springframework.security.web.headers.StaticHeadersWriter;
 import org.springframework.security.web.headers.frameoptions.*;
 import org.springframework.util.StringUtils;
 import org.springframework.util.xml.DomUtils;
@@ -85,7 +85,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
             if (StringUtils.hasText(headerFactoryRef)) {
                 headerFactories.add(new RuntimeBeanReference(headerFactoryRef));
             } else {
-                BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class);
+                BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
                 builder.addConstructorArgValue(headerElt.getAttribute(ATT_NAME));
                 builder.addConstructorArgValue(headerElt.getAttribute(ATT_VALUE));
                 headerFactories.add(builder.getBeanDefinition());
@@ -96,7 +96,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
     private void parseContentTypeOptionsElement(Element element) {
         Element contentTypeElt = DomUtils.getChildElementByTagName(element, CONTENT_TYPE_ELEMENT);
         if (contentTypeElt != null) {
-            BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class);
+            BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
             builder.addConstructorArgValue(CONTENT_TYPE_OPTIONS_HEADER);
             builder.addConstructorArgValue("nosniff");
             headerFactories.add(builder.getBeanDefinition());
@@ -104,7 +104,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
     }
 
     private void parseFrameOptionsElement(Element element, ParserContext parserContext) {
-        BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(FrameOptionsHeaderFactory.class);
+        BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(FrameOptionsHeaderWriter.class);
 
         Element frameElt = DomUtils.getChildElementByTagName(element, FRAME_OPTIONS_ELEMENT);
         if (frameElt != null) {
@@ -170,7 +170,7 @@ public class HeadersBeanDefinitionParser implements BeanDefinitionParser {
             } else if (!enabled && block) {
                 parserContext.getReaderContext().error("<xss-protection enabled=\"false\"/> does not allow block=\"true\".", xssElt);
             }
-            BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeaderFactory.class);
+            BeanDefinitionBuilder builder = BeanDefinitionBuilder.genericBeanDefinition(StaticHeadersWriter.class);
             builder.addConstructorArgValue(XSS_PROTECTION_HEADER);
             builder.addConstructorArgValue(value);
             headerFactories.add(builder.getBeanDefinition());

+ 1 - 1
docs/manual/src/docbook/appendix-namespace.xml

@@ -418,7 +418,7 @@
                 </section>
                 <section xml:id="nsa-header-ref">
                     <title><literal>header-ref</literal></title>
-                    <para>Reference to a custom implementation of the <classname>HeaderFactory</classname> interface.</para>
+                    <para>Reference to a custom implementation of the <classname>HeaderWriter</classname> interface.</para>
                 </section>
             </section>
             <section xml:id="nsa-header-parents">

+ 0 - 23
web/src/main/java/org/springframework/security/web/headers/HeaderFactory.java

@@ -1,23 +0,0 @@
-package org.springframework.security.web.headers;
-
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
-/**
- * Contract for a factory that creates {@code Header} instances.
- *
- * @author Marten Deinum
- * @since 3.2
- * @see HeadersFilter
- */
-public interface HeaderFactory {
-
-    /**
-     * Create a {@code Header} instance.
-     *
-     * @param request the request
-     * @param response the response
-     * @return the created Header or <code>null</code>
-     */
-    Header create(HttpServletRequest request, HttpServletResponse response);
-}

+ 39 - 0
web/src/main/java/org/springframework/security/web/headers/HeaderWriter.java

@@ -0,0 +1,39 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.headers;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+/**
+ * Contract for a factory that creates {@code Header} instances.
+ *
+ * @see HeadersFilter
+ *
+ * @author Marten Deinum
+ * @author Rob Winch
+ * @since 3.2
+ */
+public interface HeaderWriter {
+
+    /**
+     * Create a {@code Header} instance.
+     *
+     * @param request the request
+     * @param response the response
+     */
+    void writeHeaders(HttpServletRequest request, HttpServletResponse response);
+}

+ 5 - 25
web/src/main/java/org/springframework/security/web/headers/HeadersFilter.java

@@ -34,10 +34,10 @@ import java.util.*;
  */
 public class HeadersFilter extends OncePerRequestFilter {
 
-    /** Collection of HeaderFactory instances to produce Headers. */
-    private final List<HeaderFactory> factories;
+    /** Collection of {@link HeaderWriter} instances to  write out the headers to the response . */
+    private final List<HeaderWriter> factories;
 
-    public HeadersFilter(List<HeaderFactory> factories) {
+    public HeadersFilter(List<HeaderWriter> factories) {
         this.factories = factories;
     }
 
@@ -45,28 +45,8 @@ public class HeadersFilter extends OncePerRequestFilter {
     @Override
     protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException {
 
-        for (HeaderFactory factory : factories) {
-            Header header = factory.create(request, response);
-            if (header != null) {
-                String name = header.getName();
-                String[] values = header.getValues();
-                boolean first = true;
-                for (String value : values) {
-                    if (logger.isDebugEnabled()) {
-                        logger.debug("Adding header '" + name + "' with value '"+value +"'");
-                    }
-                    if (first) {
-                        response.setHeader(name, value);
-                        first = false;
-                    } else {
-                        response.addHeader(name, value);
-                    }
-                }
-            } else {
-                if (logger.isDebugEnabled()) {
-                    logger.debug("Factory produced no header.");
-                }
-            }
+        for (HeaderWriter factory : factories) {
+            factory.writeHeaders(request, response);
         }
         filterChain.doFilter(request, response);
     }

+ 7 - 5
web/src/main/java/org/springframework/security/web/headers/StaticHeaderFactory.java → web/src/main/java/org/springframework/security/web/headers/StaticHeadersWriter.java

@@ -6,23 +6,25 @@ import javax.servlet.http.HttpServletResponse;
 import org.springframework.util.Assert;
 
 /**
- * {@code HeaderFactory} implementation which returns the same {@code Header} instance.
+ * {@code HeaderWriter} implementation which writes the same {@code Header} instance.
  *
  * @author Marten Deinum
  * @since 3.2
  */
-public class StaticHeaderFactory implements HeaderFactory {
+public class StaticHeadersWriter implements HeaderWriter {
 
     private final Header header;
 
-    public StaticHeaderFactory(String name, String... values) {
+    public StaticHeadersWriter(String name, String... values) {
         Assert.hasText(name, "Header name is required");
         Assert.notEmpty(values, "Header values cannot be null or empty");
         Assert.noNullElements(values, "Header values cannot contain null values");
         header = new Header(name, values);
     }
 
-    public Header create(HttpServletRequest request, HttpServletResponse response) {
-        return header;
+    public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
+        for(String value : header.getValues()) {
+            response.addHeader(header.getName(), value);
+        }
     }
 }

+ 1 - 1
web/src/main/java/org/springframework/security/web/headers/frameoptions/AllowFromStrategy.java

@@ -3,7 +3,7 @@ package org.springframework.security.web.headers.frameoptions;
 import javax.servlet.http.HttpServletRequest;
 
 /**
- * Strategy interfaces used by the {@code FrameOptionsHeaderFactory} to determine the actual value to use for the
+ * Strategy interfaces used by the {@code FrameOptionsHeaderWriter} to determine the actual value to use for the
  * X-Frame-Options header when using the ALLOW-FROM directive.
  *
  * @author Marten Deinum

+ 8 - 9
web/src/main/java/org/springframework/security/web/headers/frameoptions/FrameOptionsHeaderFactory.java → web/src/main/java/org/springframework/security/web/headers/frameoptions/FrameOptionsHeaderWriter.java

@@ -1,13 +1,12 @@
 package org.springframework.security.web.headers.frameoptions;
 
-import org.springframework.security.web.headers.Header;
-import org.springframework.security.web.headers.HeaderFactory;
+import org.springframework.security.web.headers.HeaderWriter;
 
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
 /**
- * {@code HeaderFactory} implementation for the X-Frame-Options headers. When using the ALLOW-FROM directive the actual
+ * {@code HeaderWriter} implementation for the X-Frame-Options headers. When using the ALLOW-FROM directive the actual
  * value is determined by a {@code AllowFromStrategy}.
  *
  * @author Marten Deinum
@@ -15,7 +14,7 @@ import javax.servlet.http.HttpServletResponse;
  *
  * @see AllowFromStrategy
  */
-public class FrameOptionsHeaderFactory implements HeaderFactory {
+public class FrameOptionsHeaderWriter implements HeaderWriter {
 
     public static final String FRAME_OPTIONS_HEADER = "X-Frame-Options";
 
@@ -24,21 +23,21 @@ public class FrameOptionsHeaderFactory implements HeaderFactory {
     private final AllowFromStrategy allowFromStrategy;
     private final String mode;
 
-    public FrameOptionsHeaderFactory(String mode) {
+    public FrameOptionsHeaderWriter(String mode) {
         this(mode, new NullAllowFromStrategy());
     }
 
-    public FrameOptionsHeaderFactory(String mode, AllowFromStrategy allowFromStrategy) {
+    public FrameOptionsHeaderWriter(String mode, AllowFromStrategy allowFromStrategy) {
         this.mode=mode;
         this.allowFromStrategy=allowFromStrategy;
     }
 
-    public Header create(HttpServletRequest request, HttpServletResponse response) {
+    public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
         if (ALLOW_FROM.equals(mode)) {
             String value = allowFromStrategy.apply(request);
-            return new Header(FRAME_OPTIONS_HEADER, ALLOW_FROM + " " + value);
+            response.addHeader(FRAME_OPTIONS_HEADER, ALLOW_FROM + " " + value);
         } else {
-            return new Header(FRAME_OPTIONS_HEADER, mode);
+            response.addHeader(FRAME_OPTIONS_HEADER, mode);
         }
     }
 

+ 23 - 47
web/src/test/java/org/springframework/security/web/headers/HeadersFilterTest.java

@@ -15,32 +15,38 @@
  */
 package org.springframework.security.web.headers;
 
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.verify;
+
+import java.util.ArrayList;
+import java.util.List;
+
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-import java.util.*;
-
-import static org.hamcrest.CoreMatchers.is;
-import static org.junit.Assert.assertThat;
-import static org.junit.Assert.assertTrue;
-import static org.junit.matchers.JUnitMatchers.hasItems;
-
 /**
  * Tests for the {@code HeadersFilter}
  *
  * @author Marten Deinum
  * @since 3.2
  */
+@RunWith(MockitoJUnitRunner.class)
 public class HeadersFilterTest {
+    @Mock
+    private HeaderWriter writer1;
+
+    @Mock
+    private HeaderWriter writer2;
 
     @Test
     public void noHeadersConfigured() throws Exception {
-        List<HeaderFactory> factories = new ArrayList();
-        HeadersFilter filter = new HeadersFilter(factories);
+        List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
+        HeadersFilter filter = new HeadersFilter(headerWriters);
         MockHttpServletRequest request = new MockHttpServletRequest();
         MockHttpServletResponse response = new MockHttpServletResponse();
         MockFilterChain filterChain = new MockFilterChain();
@@ -52,18 +58,11 @@ public class HeadersFilterTest {
 
     @Test
     public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {
-        List<HeaderFactory> factories = new ArrayList();
-        MockHeaderFactory factory1 = new MockHeaderFactory();
-        factory1.setName("X-Header1");
-        factory1.setValue("foo");
-        MockHeaderFactory factory2 = new MockHeaderFactory();
-        factory2.setName("X-Header2");
-        factory2.setValue("bar");
-
-        factories.add(factory1);
-        factories.add(factory2);
+        List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
+        headerWriters.add(writer1);
+        headerWriters.add(writer2);
 
-        HeadersFilter filter = new HeadersFilter(factories);
+        HeadersFilter filter = new HeadersFilter(headerWriters);
 
         MockHttpServletRequest request = new MockHttpServletRequest();
         MockHttpServletResponse response = new MockHttpServletResponse();
@@ -71,30 +70,7 @@ public class HeadersFilterTest {
 
         filter.doFilter(request, response, filterChain);
 
-        Collection<String> headerNames = response.getHeaderNames();
-        assertThat(headerNames.size(), is(2));
-        assertThat(headerNames, hasItems("X-Header1", "X-Header2"));
-        assertThat(response.getHeader("X-Header1"), is("foo"));
-        assertThat(response.getHeader("X-Header2"), is("bar"));
-
-    }
-
-    private static final class MockHeaderFactory implements HeaderFactory {
-
-        private String name;
-        private String value;
-
-        public Header create(HttpServletRequest request, HttpServletResponse response) {
-            return new Header(name, value);
-        }
-
-        public void setName(String name) {
-            this.name=name;
-        }
-
-        public void setValue(String value) {
-            this.value=value;
-        }
-
+        verify(writer1).writeHeaders(request, response);
+        verify(writer2).writeHeaders(request, response);
     }
 }

+ 0 - 26
web/src/test/java/org/springframework/security/web/headers/StaticHeaderFactoryTest.java

@@ -1,26 +0,0 @@
-package org.springframework.security.web.headers;
-
-import org.junit.Test;
-
-import static org.hamcrest.CoreMatchers.is;
-import static org.junit.Assert.assertSame;
-import static org.springframework.test.util.MatcherAssertionErrors.assertThat;
-
-/**
- * Test for the {@code StaticHeaderFactory}
- *
- * @author Marten Deinum
- * @since 3.2
- */
-public class StaticHeaderFactoryTest {
-
-    @Test
-    public void sameHeaderShouldBeReturned() {
-        StaticHeaderFactory factory = new StaticHeaderFactory("X-header", "foo");
-        Header header = factory.create(null, null);
-        assertThat(header.getName(), is("X-header"));
-        assertThat(header.getValues()[0], is("foo"));
-
-        assertSame(header, factory.create(null, null));
-    }
-}

+ 37 - 0
web/src/test/java/org/springframework/security/web/headers/StaticHeaderWriterTests.java

@@ -0,0 +1,37 @@
+package org.springframework.security.web.headers;
+
+import static org.fest.assertions.Assertions.assertThat;
+
+import java.util.Arrays;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+
+/**
+ * Test for the {@code StaticHeadersWriter}
+ *
+ * @author Marten Deinum
+ * @since 3.2
+ */
+public class StaticHeaderWriterTests {
+    private MockHttpServletRequest request;
+    private MockHttpServletResponse response;
+
+    @Before
+    public void setup() {
+        request = new MockHttpServletRequest();
+        response = new MockHttpServletResponse();
+    }
+
+    @Test
+    public void sameHeaderShouldBeReturned() {
+        String headerName = "X-header";
+        String headerValue = "foo";
+        StaticHeadersWriter factory = new StaticHeadersWriter(headerName, headerValue);
+
+        factory.writeHeaders(request, response);
+        assertThat(response.getHeaderValues(headerName)).isEqualTo(Arrays.asList(headerValue));
+    }
+}