浏览代码

Prevent Duplicate Cache Headers

Fixes gh-4199
Rob Winch 8 年之前
父节点
当前提交
168f4b8f70

+ 40 - 8
web/src/main/java/org/springframework/security/web/header/writers/CacheControlHeadersWriter.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2013 the original author or authors.
+ * Copyright 2002-2017 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,14 +15,20 @@
  */
 package org.springframework.security.web.header.writers;
 
+import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.List;
 
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.springframework.security.web.header.Header;
+import org.springframework.security.web.header.HeaderWriter;
+import org.springframework.util.ReflectionUtils;
 
 /**
- * A {@link StaticHeadersWriter} that inserts headers to prevent caching. Specifically it
- * adds the following headers:
+ * Inserts headers to prevent caching if no cache control headers have been specified.
+ * Specifically it adds the following headers:
  * <ul>
  * <li>Cache-Control: no-cache, no-store, max-age=0, must-revalidate</li>
  * <li>Pragma: no-cache</li>
@@ -32,21 +38,47 @@ import org.springframework.security.web.header.Header;
  * @author Rob Winch
  * @since 3.2
  */
-public final class CacheControlHeadersWriter extends StaticHeadersWriter {
+public final class CacheControlHeadersWriter implements HeaderWriter {
+	private static final String EXPIRES = "Expires";
+	private static final String PRAGMA = "Pragma";
+	private static final String CACHE_CONTROL = "Cache-Control";
+
+	private final Method getHeaderMethod;
+
+	private final HeaderWriter delegate;
 
 	/**
 	 * Creates a new instance
 	 */
 	public CacheControlHeadersWriter() {
-		super(createHeaders());
+		this.delegate = new StaticHeadersWriter(createHeaders());
+		this.getHeaderMethod = ReflectionUtils.findMethod(HttpServletResponse.class,
+				"getHeader", String.class);
+	}
+
+	@Override
+	public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
+		if (hasHeader(response, CACHE_CONTROL) || hasHeader(response, EXPIRES)
+				|| hasHeader(response, PRAGMA)) {
+			return;
+		}
+		this.delegate.writeHeaders(request, response);
+	}
+
+	private boolean hasHeader(HttpServletResponse response, String headerName) {
+		if (this.getHeaderMethod == null) {
+			return false;
+		}
+		return ReflectionUtils.invokeMethod(this.getHeaderMethod, response,
+				headerName) != null;
 	}
 
 	private static List<Header> createHeaders() {
 		List<Header> headers = new ArrayList<Header>(2);
-		headers.add(new Header("Cache-Control",
+		headers.add(new Header(CACHE_CONTROL,
 				"no-cache, no-store, max-age=0, must-revalidate"));
-		headers.add(new Header("Pragma", "no-cache"));
-		headers.add(new Header("Expires", "0"));
+		headers.add(new Header(PRAGMA, "no-cache"));
+		headers.add(new Header(EXPIRES, "0"));
 		return headers;
 	}
 }

+ 83 - 11
web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java

@@ -15,19 +15,32 @@
  */
 package org.springframework.security.web.header.writers;
 
-import static org.assertj.core.api.Assertions.assertThat;
-
 import java.util.Arrays;
 
+import javax.servlet.http.HttpServletResponse;
+
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.util.ReflectionUtils;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.when;
+import static org.powermock.api.mockito.PowerMockito.spy;
 
 /**
  * @author Rob Winch
  *
  */
+@RunWith(PowerMockRunner.class)
+@PrepareOnlyThisForTest(ReflectionUtils.class)
 public class CacheControlHeadersWriterTests {
 
 	private MockHttpServletRequest request;
@@ -38,20 +51,79 @@ public class CacheControlHeadersWriterTests {
 
 	@Before
 	public void setup() {
-		request = new MockHttpServletRequest();
-		response = new MockHttpServletResponse();
-		writer = new CacheControlHeadersWriter();
+		this.request = new MockHttpServletRequest();
+		this.response = new MockHttpServletResponse();
+		this.writer = new CacheControlHeadersWriter();
 	}
 
 	@Test
 	public void writeHeaders() {
-		writer.writeHeaders(request, response);
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
+		assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
+				Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
+		assertThat(this.response.getHeaderValues("Pragma"))
+				.isEqualTo(Arrays.asList("no-cache"));
+		assertThat(this.response.getHeaderValues("Expires"))
+				.isEqualTo(Arrays.asList("0"));
+	}
+
+	@Test
+	public void writeHeadersServlet25() {
+		spy(ReflectionUtils.class);
+		when(ReflectionUtils.findMethod(HttpServletResponse.class, "getHeader",
+				String.class)).thenReturn(null);
+		this.response = spy(this.response);
+		doThrow(NoSuchMethodError.class).when(this.response).getHeader(anyString());
+		this.writer = new CacheControlHeadersWriter();
 
-		assertThat(response.getHeaderNames().size()).isEqualTo(3);
-		assertThat(response.getHeaderValues("Cache-Control")).isEqualTo(
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
+		assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
 				Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
-		assertThat(response.getHeaderValues("Pragma")).isEqualTo(
-				Arrays.asList("no-cache"));
-		assertThat(response.getHeaderValues("Expires")).isEqualTo(Arrays.asList("0"));
+		assertThat(this.response.getHeaderValues("Pragma"))
+				.isEqualTo(Arrays.asList("no-cache"));
+		assertThat(this.response.getHeaderValues("Expires"))
+				.isEqualTo(Arrays.asList("0"));
+	}
+
+	// gh-2953
+	@Test
+	public void writeHeadersDisabledIfCacheControl() {
+		this.response.setHeader("Cache-Control", "max-age: 123");
+
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames()).hasSize(1);
+		assertThat(this.response.getHeaderValues("Cache-Control"))
+				.containsOnly("max-age: 123");
+		assertThat(this.response.getHeaderValue("Pragma")).isNull();
+		assertThat(this.response.getHeaderValue("Expires")).isNull();
+	}
+
+	@Test
+	public void writeHeadersDisabledIfPragma() {
+		this.response.setHeader("Pragma", "mock");
+
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames()).hasSize(1);
+		assertThat(this.response.getHeaderValues("Pragma")).containsOnly("mock");
+		assertThat(this.response.getHeaderValue("Expires")).isNull();
+		assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
+	}
+
+	@Test
+	public void writeHeadersDisabledIfExpires() {
+		this.response.setHeader("Expires", "mock");
+
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames()).hasSize(1);
+		assertThat(this.response.getHeaderValues("Expires")).containsOnly("mock");
+		assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
+		assertThat(this.response.getHeaderValue("Pragma")).isNull();
 	}
 }