2
0
Эх сурвалжийг харах

Cache headers only if no cache headers set

Fixes: gh-5004
Rob Winch 7 жил өмнө
parent
commit
ea3dd336aa

+ 2 - 2
config/src/test/java/org/springframework/security/config/annotation/web/HttpSecurityHeadersTests.java

@@ -71,8 +71,8 @@ public class HttpSecurityHeadersTests {
 		mockMvc.perform(get("/resources/file.js"))
 			.andExpect(status().isOk())
 			.andExpect(header().string(HttpHeaders.CACHE_CONTROL, "max-age=12345"))
-			.andExpect(header().string(HttpHeaders.PRAGMA, ""))
-			.andExpect(header().string(HttpHeaders.EXPIRES, ""));
+			.andExpect(header().doesNotExist(HttpHeaders.PRAGMA))
+			.andExpect(header().doesNotExist(HttpHeaders.EXPIRES));
 	}
 
 	@Test

+ 50 - 8
web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java

@@ -15,15 +15,17 @@
  */
 package org.springframework.security.web.header;
 
-import org.springframework.util.Assert;
-import org.springframework.web.filter.OncePerRequestFilter;
+import java.io.IOException;
+import java.util.List;
 
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
-import java.io.IOException;
-import java.util.*;
+
+import org.springframework.security.web.util.OnCommittedResponseWrapper;
+import org.springframework.util.Assert;
+import org.springframework.web.filter.OncePerRequestFilter;
 
 /**
  * Filter implementation to add headers to the current response. Can be useful to add
@@ -56,12 +58,52 @@ public class HeaderWriterFilter extends OncePerRequestFilter {
 	@Override
 	protected void doFilterInternal(HttpServletRequest request,
 			HttpServletResponse response, FilterChain filterChain)
-			throws ServletException, IOException {
+					throws ServletException, IOException {
 
-		for (HeaderWriter headerWriter : headerWriters) {
-			headerWriter.writeHeaders(request, response);
+		HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request,
+				response, this.headerWriters);
+		try {
+			filterChain.doFilter(request, headerWriterResponse);
+		}
+		finally {
+			headerWriterResponse.writeHeaders();
 		}
-		filterChain.doFilter(request, response);
 	}
 
+	static class HeaderWriterResponse extends OnCommittedResponseWrapper {
+		private final HttpServletRequest request;
+		private final List<HeaderWriter> headerWriters;
+
+		HeaderWriterResponse(HttpServletRequest request, HttpServletResponse response,
+				List<HeaderWriter> headerWriters) {
+			super(response);
+			this.request = request;
+			this.headerWriters = headerWriters;
+		}
+
+		/*
+		 * (non-Javadoc)
+		 *
+		 * @see org.springframework.security.web.util.OnCommittedResponseWrapper#
+		 * onResponseCommitted()
+		 */
+		@Override
+		protected void onResponseCommitted() {
+			writeHeaders();
+			this.disableOnResponseCommitted();
+		}
+
+		protected void writeHeaders() {
+			if (isDisableOnResponseCommitted()) {
+				return;
+			}
+			for (HeaderWriter headerWriter : this.headerWriters) {
+				headerWriter.writeHeaders(this.request, getHttpResponse());
+			}
+		}
+
+		private HttpServletResponse getHttpResponse() {
+			return (HttpServletResponse) getResponse();
+		}
+	}
 }

+ 46 - 10
web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java

@@ -15,21 +15,32 @@
  */
 package org.springframework.security.web.header;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.Mockito.verify;
-
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
-import org.mockito.junit.MockitoJUnitRunner;
+import org.mockito.runners.MockitoJUnitRunner;
+
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
-import org.springframework.security.web.header.HeaderWriter;
-import org.springframework.security.web.header.HeaderWriterFilter;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
 
 /**
  * Tests for the {@code HeadersFilter}
@@ -60,8 +71,8 @@ public class HeaderWriterFilterTests {
 	@Test
 	public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {
 		List<HeaderWriter> headerWriters = new ArrayList<>();
-		headerWriters.add(writer1);
-		headerWriters.add(writer2);
+		headerWriters.add(this.writer1);
+		headerWriters.add(this.writer2);
 
 		HeaderWriterFilter filter = new HeaderWriterFilter(headerWriters);
 
@@ -71,9 +82,34 @@ public class HeaderWriterFilterTests {
 
 		filter.doFilter(request, response, filterChain);
 
-		verify(writer1).writeHeaders(request, response);
-		verify(writer2).writeHeaders(request, response);
+		verify(this.writer1).writeHeaders(request, response);
+		verify(this.writer2).writeHeaders(request, response);
 		assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain
 																	// continued
 	}
+
+	// gh-2953
+	@Test
+	public void headersDelayed() throws Exception {
+		HeaderWriterFilter filter = new HeaderWriterFilter(
+				Arrays.<HeaderWriter>asList(this.writer1));
+
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
+		filter.doFilter(request, response, new FilterChain() {
+			@Override
+			public void doFilter(ServletRequest request, ServletResponse response)
+					throws IOException, ServletException {
+				verifyZeroInteractions(HeaderWriterFilterTests.this.writer1);
+
+				response.flushBuffer();
+
+				verify(HeaderWriterFilterTests.this.writer1).writeHeaders(
+						any(HttpServletRequest.class), any(HttpServletResponse.class));
+			}
+		});
+
+		verifyNoMoreInteractions(this.writer1);
+	}
 }

+ 13 - 9
web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java

@@ -58,11 +58,13 @@ public class CacheControlHeadersWriterTests {
 	public void writeHeaders() {
 		this.writer.writeHeaders(this.request, this.response);
 
-		assertThat(this.response.getHeaderNames()).hasSize(3);
-		assertThat(this.response.getHeaderValues("Cache-Control")).containsExactly(
+		assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
+		assertThat(this.response.getHeaderValues("Cache-Control")).containsOnly(
 				"no-cache, no-store, max-age=0, must-revalidate");
-		assertThat(this.response.getHeaderValues("Pragma")).containsOnly("no-cache");
-		assertThat(this.response.getHeaderValues("Expires")).containsOnly("0");
+		assertThat(this.response.getHeaderValues("Pragma"))
+				.containsOnly("no-cache");
+		assertThat(this.response.getHeaderValues("Expires"))
+				.containsOnly("0");
 	}
 
 	@Test
@@ -76,11 +78,13 @@ public class CacheControlHeadersWriterTests {
 
 		this.writer.writeHeaders(this.request, this.response);
 
-		assertThat(this.response.getHeaderNames()).hasSize(3);
-		assertThat(this.response.getHeaderValues("Cache-Control")).containsExactly(
-				"no-cache, no-store, max-age=0, must-revalidate");
-		assertThat(this.response.getHeaderValues("Pragma")).containsOnly("no-cache");
-		assertThat(this.response.getHeaderValues("Expires")).containsOnly("0");
+		assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
+		assertThat(this.response.getHeaderValues("Cache-Control"))
+				.containsOnly("no-cache, no-store, max-age=0, must-revalidate");
+		assertThat(this.response.getHeaderValues("Pragma"))
+				.containsOnly("no-cache");
+		assertThat(this.response.getHeaderValues("Expires"))
+				.containsOnly("0");
 	}
 
 	// gh-2953