Explorar o código

Merge pull request #3759 from rwinch/gh-2953

Cache Control only written if not set
Rob Winch %!s(int64=9) %!d(string=hai) anos
pai
achega
0f2a3b18ce

+ 9 - 12
web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java

@@ -17,10 +17,9 @@ package org.springframework.security.web.context;
 
 import javax.servlet.http.HttpServletResponse;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.util.OnCommittedResponseWrapper;
 
 /**
  * Base class for response wrappers which encapsulate the logic for storing a security
@@ -40,10 +39,8 @@ import org.springframework.security.core.context.SecurityContextHolder;
  * @author Rob Winch
  * @since 3.0
  */
-public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
-		OnCommittedResponseWrapper {
-	private final Log logger = LogFactory.getLog(getClass());
-
+public abstract class SaveContextOnUpdateOrErrorResponseWrapper
+		extends OnCommittedResponseWrapper {
 
 	private boolean contextSaved = false;
 	/* See SEC-1052 */
@@ -86,12 +83,12 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
 	@Override
 	protected void onResponseCommitted() {
 		saveContext(SecurityContextHolder.getContext());
-		contextSaved = true;
+		this.contextSaved = true;
 	}
 
 	@Override
 	public final String encodeRedirectUrl(String url) {
-		if (disableUrlRewriting) {
+		if (this.disableUrlRewriting) {
 			return url;
 		}
 		return super.encodeRedirectUrl(url);
@@ -99,7 +96,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
 
 	@Override
 	public final String encodeRedirectURL(String url) {
-		if (disableUrlRewriting) {
+		if (this.disableUrlRewriting) {
 			return url;
 		}
 		return super.encodeRedirectURL(url);
@@ -107,7 +104,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
 
 	@Override
 	public final String encodeUrl(String url) {
-		if (disableUrlRewriting) {
+		if (this.disableUrlRewriting) {
 			return url;
 		}
 		return super.encodeUrl(url);
@@ -115,7 +112,7 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
 
 	@Override
 	public final String encodeURL(String url) {
-		if (disableUrlRewriting) {
+		if (this.disableUrlRewriting) {
 			return url;
 		}
 		return super.encodeURL(url);
@@ -126,6 +123,6 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
 	 * wrapper.
 	 */
 	public final boolean isContextSaved() {
-		return contextSaved;
+		return this.contextSaved;
 	}
 }

+ 50 - 8
web/src/main/java/org/springframework/security/web/header/HeaderWriterFilter.java

@@ -15,15 +15,17 @@
  */
 package org.springframework.security.web.header;
 
-import org.springframework.util.Assert;
-import org.springframework.web.filter.OncePerRequestFilter;
+import java.io.IOException;
+import java.util.List;
 
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
-import java.io.IOException;
-import java.util.*;
+
+import org.springframework.security.web.util.OnCommittedResponseWrapper;
+import org.springframework.util.Assert;
+import org.springframework.web.filter.OncePerRequestFilter;
 
 /**
  * Filter implementation to add headers to the current request. Can be useful to add
@@ -56,12 +58,52 @@ public class HeaderWriterFilter extends OncePerRequestFilter {
 	@Override
 	protected void doFilterInternal(HttpServletRequest request,
 			HttpServletResponse response, FilterChain filterChain)
-			throws ServletException, IOException {
+					throws ServletException, IOException {
 
-		for (HeaderWriter headerWriter : headerWriters) {
-			headerWriter.writeHeaders(request, response);
+		HeaderWriterResponse headerWriterResponse = new HeaderWriterResponse(request,
+				response, this.headerWriters);
+		try {
+			filterChain.doFilter(request, headerWriterResponse);
+		}
+		finally {
+			headerWriterResponse.writeHeaders();
 		}
-		filterChain.doFilter(request, response);
 	}
 
+	static class HeaderWriterResponse extends OnCommittedResponseWrapper {
+		private final HttpServletRequest request;
+		private final List<HeaderWriter> headerWriters;
+
+		HeaderWriterResponse(HttpServletRequest request, HttpServletResponse response,
+				List<HeaderWriter> headerWriters) {
+			super(response);
+			this.request = request;
+			this.headerWriters = headerWriters;
+		}
+
+		/*
+		 * (non-Javadoc)
+		 *
+		 * @see org.springframework.security.web.util.OnCommittedResponseWrapper#
+		 * onResponseCommitted()
+		 */
+		@Override
+		protected void onResponseCommitted() {
+			writeHeaders();
+			this.disableOnResponseCommitted();
+		}
+
+		protected void writeHeaders() {
+			if (isDisableOnResponseCommitted()) {
+				return;
+			}
+			for (HeaderWriter headerWriter : this.headerWriters) {
+				headerWriter.writeHeaders(this.request, getHttpResponse());
+			}
+		}
+
+		private HttpServletResponse getHttpResponse() {
+			return (HttpServletResponse) getResponse();
+		}
+	}
 }

+ 39 - 7
web/src/main/java/org/springframework/security/web/header/writers/CacheControlHeadersWriter.java

@@ -15,14 +15,20 @@
  */
 package org.springframework.security.web.header.writers;
 
+import java.lang.reflect.Method;
 import java.util.ArrayList;
 import java.util.List;
 
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.springframework.security.web.header.Header;
+import org.springframework.security.web.header.HeaderWriter;
+import org.springframework.util.ReflectionUtils;
 
 /**
- * A {@link StaticHeadersWriter} that inserts headers to prevent caching. Specifically it
- * adds the following headers:
+ * Inserts headers to prevent caching if no cache control headers have been specified.
+ * Specifically it adds the following headers:
  * <ul>
  * <li>Cache-Control: no-cache, no-store, max-age=0, must-revalidate</li>
  * <li>Pragma: no-cache</li>
@@ -32,21 +38,47 @@ import org.springframework.security.web.header.Header;
  * @author Rob Winch
  * @since 3.2
  */
-public final class CacheControlHeadersWriter extends StaticHeadersWriter {
+public final class CacheControlHeadersWriter implements HeaderWriter {
+	private static final String EXPIRES = "Expires";
+	private static final String PRAGMA = "Pragma";
+	private static final String CACHE_CONTROL = "Cache-Control";
+
+	private final Method getHeaderMethod;
+
+	private final HeaderWriter delegate;
 
 	/**
 	 * Creates a new instance
 	 */
 	public CacheControlHeadersWriter() {
-		super(createHeaders());
+		this.delegate = new StaticHeadersWriter(createHeaders());
+		this.getHeaderMethod = ReflectionUtils.findMethod(HttpServletResponse.class,
+				"getHeader", String.class);
+	}
+
+	@Override
+	public void writeHeaders(HttpServletRequest request, HttpServletResponse response) {
+		if (hasHeader(response, CACHE_CONTROL) || hasHeader(response, EXPIRES)
+				|| hasHeader(response, PRAGMA)) {
+			return;
+		}
+		this.delegate.writeHeaders(request, response);
+	}
+
+	private boolean hasHeader(HttpServletResponse response, String headerName) {
+		if (this.getHeaderMethod == null) {
+			return false;
+		}
+		return ReflectionUtils.invokeMethod(this.getHeaderMethod, response,
+				headerName) != null;
 	}
 
 	private static List<Header> createHeaders() {
 		List<Header> headers = new ArrayList<Header>(2);
-		headers.add(new Header("Cache-Control",
+		headers.add(new Header(CACHE_CONTROL,
 				"no-cache, no-store, max-age=0, must-revalidate"));
-		headers.add(new Header("Pragma", "no-cache"));
-		headers.add(new Header("Expires", "0"));
+		headers.add(new Header(PRAGMA, "no-cache"));
+		headers.add(new Header(EXPIRES, "0"));
 		return headers;
 	}
 }

+ 177 - 103
web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java → web/src/main/java/org/springframework/security/web/util/OnCommittedResponseWrapper.java

@@ -13,33 +13,31 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.springframework.security.web.context;
+package org.springframework.security.web.util;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Locale;
 
 import javax.servlet.ServletOutputStream;
 import javax.servlet.http.HttpServletResponse;
 import javax.servlet.http.HttpServletResponseWrapper;
-import java.io.IOException;
-import java.io.PrintWriter;
-import java.util.Locale;
 
 /**
- * Base class for response wrappers which encapsulate the logic for handling an event when the
- * {@link javax.servlet.http.HttpServletResponse} is committed.
+ * Base class for response wrappers which encapsulate the logic for handling an event when
+ * the {@link javax.servlet.http.HttpServletResponse} is committed.
  *
  * @since 4.0.2
  * @author Rob Winch
  */
-abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
-	private final Log logger = LogFactory.getLog(getClass());
+public abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 
 	private boolean disableOnCommitted;
 
 	/**
-	 * The Content-Length response header. If this is greater than 0, then once {@link #contentWritten} is larger than
-	 * or equal the response is considered committed.
+	 * The Content-Length response header. If this is greater than 0, then once
+	 * {@link #contentWritten} is larger than or equal the response is considered
+	 * committed.
 	 */
 	private long contentLength;
 
@@ -57,7 +55,7 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 
 	@Override
 	public void addHeader(String name, String value) {
-		if("Content-Length".equalsIgnoreCase(name)) {
+		if ("Content-Length".equalsIgnoreCase(name)) {
 			setContentLength(Long.parseLong(value));
 		}
 		super.addHeader(name, value);
@@ -75,22 +73,33 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 	}
 
 	/**
-	 * Invoke this method to disable invoking {@link OnCommittedResponseWrapper#onResponseCommitted()} when the {@link javax.servlet.http.HttpServletResponse} is
-	 * committed. This can be useful in the event that Async Web Requests are
-	 * made.
+	 * Invoke this method to disable invoking
+	 * {@link OnCommittedResponseWrapper#onResponseCommitted()} when the
+	 * {@link javax.servlet.http.HttpServletResponse} is committed. This can be useful in
+	 * the event that Async Web Requests are made.
 	 */
-	public void disableOnResponseCommitted() {
+	protected void disableOnResponseCommitted() {
 		this.disableOnCommitted = true;
 	}
 
 	/**
-	 * Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse} being committed
+	 * Returns true if {@link #onResponseCommitted()} will be invoked when the response is
+	 * committed, else false.
+	 * @return if {@link #onResponseCommitted()} is enabled
+	 */
+	protected boolean isDisableOnResponseCommitted() {
+		return this.disableOnCommitted;
+	}
+
+	/**
+	 * Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse}
+	 * being committed
 	 */
 	protected abstract void onResponseCommitted();
 
 	/**
-	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
-	 * superclass <code>sendError()</code>
+	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
+	 * before calling the superclass <code>sendError()</code>
 	 */
 	@Override
 	public final void sendError(int sc) throws IOException {
@@ -99,8 +108,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 	}
 
 	/**
-	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
-	 * superclass <code>sendError()</code>
+	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
+	 * before calling the superclass <code>sendError()</code>
 	 */
 	@Override
 	public final void sendError(int sc, String msg) throws IOException {
@@ -109,8 +118,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 	}
 
 	/**
-	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
-	 * superclass <code>sendRedirect()</code>
+	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
+	 * before calling the superclass <code>sendRedirect()</code>
 	 */
 	@Override
 	public final void sendRedirect(String location) throws IOException {
@@ -119,8 +128,9 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 	}
 
 	/**
-	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the calling
-	 * <code>getOutputStream().close()</code> or <code>getOutputStream().flush()</code>
+	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
+	 * before calling the calling <code>getOutputStream().close()</code> or
+	 * <code>getOutputStream().flush()</code>
 	 */
 	@Override
 	public ServletOutputStream getOutputStream() throws IOException {
@@ -128,8 +138,9 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 	}
 
 	/**
-	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
-	 * <code>getWriter().close()</code> or <code>getWriter().flush()</code>
+	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
+	 * before calling the <code>getWriter().close()</code> or
+	 * <code>getWriter().flush()</code>
 	 */
 	@Override
 	public PrintWriter getWriter() throws IOException {
@@ -137,8 +148,8 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 	}
 
 	/**
-	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
-	 * superclass <code>flushBuffer()</code>
+	 * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked
+	 * before calling the superclass <code>flushBuffer()</code>
 	 */
 	@Override
 	public void flushBuffer() throws IOException {
@@ -187,36 +198,38 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 	}
 
 	/**
-	 * Adds the contentLengthToWrite to the total contentWritten size and checks to see if the response should be
-	 * written.
+	 * Adds the contentLengthToWrite to the total contentWritten size and checks to see if
+	 * the response should be written.
 	 *
 	 * @param contentLengthToWrite the size of the content that is about to be written.
 	 */
 	private void checkContentLength(long contentLengthToWrite) {
-		contentWritten += contentLengthToWrite;
-		boolean isBodyFullyWritten = contentLength > 0  && contentWritten >= contentLength;
+		this.contentWritten += contentLengthToWrite;
+		boolean isBodyFullyWritten = this.contentLength > 0
+				&& this.contentWritten >= this.contentLength;
 		int bufferSize = getBufferSize();
-		boolean requiresFlush = bufferSize > 0 && contentWritten >= bufferSize;
-		if(isBodyFullyWritten || requiresFlush) {
+		boolean requiresFlush = bufferSize > 0 && this.contentWritten >= bufferSize;
+		if (isBodyFullyWritten || requiresFlush) {
 			doOnResponseCommitted();
 		}
 	}
 
 	/**
 	 * Calls <code>onResponseCommmitted()</code> with the current contents as long as
-	 * {@link #disableOnResponseCommitted()()} was not invoked.
+	 * {@link #disableOnResponseCommitted()} was not invoked.
 	 */
 	private void doOnResponseCommitted() {
-		if(!disableOnCommitted) {
+		if (!this.disableOnCommitted) {
 			onResponseCommitted();
 			disableOnResponseCommitted();
 		}
 	}
 
 	/**
-	 * Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the prior to methods that commit the response. We delegate all methods
-	 * to the original {@link java.io.PrintWriter} to ensure that the behavior is as close to the original {@link java.io.PrintWriter}
-	 * as possible. See SEC-2039
+	 * Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before
+	 * calling the prior to methods that commit the response. We delegate all methods to
+	 * the original {@link java.io.PrintWriter} to ensure that the behavior is as close to
+	 * the original {@link java.io.PrintWriter} as possible. See SEC-2039
 	 * @author Rob Winch
 	 */
 	private class SaveContextPrintWriter extends PrintWriter {
@@ -227,197 +240,235 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 			this.delegate = delegate;
 		}
 
+		@Override
 		public void flush() {
 			doOnResponseCommitted();
-			delegate.flush();
+			this.delegate.flush();
 		}
 
+		@Override
 		public void close() {
 			doOnResponseCommitted();
-			delegate.close();
+			this.delegate.close();
 		}
 
+		@Override
 		public int hashCode() {
-			return delegate.hashCode();
+			return this.delegate.hashCode();
 		}
 
+		@Override
 		public boolean equals(Object obj) {
-			return delegate.equals(obj);
+			return this.delegate.equals(obj);
 		}
 
+		@Override
 		public String toString() {
-			return getClass().getName() + "[delegate=" + delegate.toString() + "]";
+			return getClass().getName() + "[delegate=" + this.delegate.toString() + "]";
 		}
 
+		@Override
 		public boolean checkError() {
-			return delegate.checkError();
+			return this.delegate.checkError();
 		}
 
+		@Override
 		public void write(int c) {
 			trackContentLength(c);
-			delegate.write(c);
+			this.delegate.write(c);
 		}
 
+		@Override
 		public void write(char[] buf, int off, int len) {
 			checkContentLength(len);
-			delegate.write(buf, off, len);
+			this.delegate.write(buf, off, len);
 		}
 
+		@Override
 		public void write(char[] buf) {
 			trackContentLength(buf);
-			delegate.write(buf);
+			this.delegate.write(buf);
 		}
 
+		@Override
 		public void write(String s, int off, int len) {
 			checkContentLength(len);
-			delegate.write(s, off, len);
+			this.delegate.write(s, off, len);
 		}
 
+		@Override
 		public void write(String s) {
 			trackContentLength(s);
-			delegate.write(s);
+			this.delegate.write(s);
 		}
 
+		@Override
 		public void print(boolean b) {
 			trackContentLength(b);
-			delegate.print(b);
+			this.delegate.print(b);
 		}
 
+		@Override
 		public void print(char c) {
 			trackContentLength(c);
-			delegate.print(c);
+			this.delegate.print(c);
 		}
 
+		@Override
 		public void print(int i) {
 			trackContentLength(i);
-			delegate.print(i);
+			this.delegate.print(i);
 		}
 
+		@Override
 		public void print(long l) {
 			trackContentLength(l);
-			delegate.print(l);
+			this.delegate.print(l);
 		}
 
+		@Override
 		public void print(float f) {
 			trackContentLength(f);
-			delegate.print(f);
+			this.delegate.print(f);
 		}
 
+		@Override
 		public void print(double d) {
 			trackContentLength(d);
-			delegate.print(d);
+			this.delegate.print(d);
 		}
 
+		@Override
 		public void print(char[] s) {
 			trackContentLength(s);
-			delegate.print(s);
+			this.delegate.print(s);
 		}
 
+		@Override
 		public void print(String s) {
 			trackContentLength(s);
-			delegate.print(s);
+			this.delegate.print(s);
 		}
 
+		@Override
 		public void print(Object obj) {
 			trackContentLength(obj);
-			delegate.print(obj);
+			this.delegate.print(obj);
 		}
 
+		@Override
 		public void println() {
 			trackContentLengthLn();
-			delegate.println();
+			this.delegate.println();
 		}
 
+		@Override
 		public void println(boolean x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(char x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(int x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(long x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(float x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(double x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(char[] x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(String x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public void println(Object x) {
 			trackContentLength(x);
 			trackContentLengthLn();
-			delegate.println(x);
+			this.delegate.println(x);
 		}
 
+		@Override
 		public PrintWriter printf(String format, Object... args) {
-			return delegate.printf(format, args);
+			return this.delegate.printf(format, args);
 		}
 
+		@Override
 		public PrintWriter printf(Locale l, String format, Object... args) {
-			return delegate.printf(l, format, args);
+			return this.delegate.printf(l, format, args);
 		}
 
+		@Override
 		public PrintWriter format(String format, Object... args) {
-			return delegate.format(format, args);
+			return this.delegate.format(format, args);
 		}
 
+		@Override
 		public PrintWriter format(Locale l, String format, Object... args) {
-			return delegate.format(l, format, args);
+			return this.delegate.format(l, format, args);
 		}
 
+		@Override
 		public PrintWriter append(CharSequence csq) {
 			checkContentLength(csq.length());
-			return delegate.append(csq);
+			return this.delegate.append(csq);
 		}
 
+		@Override
 		public PrintWriter append(CharSequence csq, int start, int end) {
 			checkContentLength(end - start);
-			return delegate.append(csq, start, end);
+			return this.delegate.append(csq, start, end);
 		}
 
+		@Override
 		public PrintWriter append(char c) {
 			trackContentLength(c);
-			return delegate.append(c);
+			return this.delegate.append(c);
 		}
 	}
 
 	/**
-	 * Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling methods that commit the response. We delegate all methods
-	 * to the original {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close to the original {@link javax.servlet.ServletOutputStream}
-	 * as possible. See SEC-2039
+	 * Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before
+	 * calling methods that commit the response. We delegate all methods to the original
+	 * {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close
+	 * to the original {@link javax.servlet.ServletOutputStream} as possible. See SEC-2039
 	 *
 	 * @author Rob Winch
 	 */
@@ -428,123 +479,146 @@ abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
 			this.delegate = delegate;
 		}
 
+		@Override
 		public void write(int b) throws IOException {
 			trackContentLength(b);
 			this.delegate.write(b);
 		}
 
+		@Override
 		public void flush() throws IOException {
 			doOnResponseCommitted();
-			delegate.flush();
+			this.delegate.flush();
 		}
 
+		@Override
 		public void close() throws IOException {
 			doOnResponseCommitted();
-			delegate.close();
+			this.delegate.close();
 		}
 
+		@Override
 		public int hashCode() {
-			return delegate.hashCode();
+			return this.delegate.hashCode();
 		}
 
+		@Override
 		public boolean equals(Object obj) {
-			return delegate.equals(obj);
+			return this.delegate.equals(obj);
 		}
 
+		@Override
 		public void print(boolean b) throws IOException {
 			trackContentLength(b);
-			delegate.print(b);
+			this.delegate.print(b);
 		}
 
+		@Override
 		public void print(char c) throws IOException {
 			trackContentLength(c);
-			delegate.print(c);
+			this.delegate.print(c);
 		}
 
+		@Override
 		public void print(double d) throws IOException {
 			trackContentLength(d);
-			delegate.print(d);
+			this.delegate.print(d);
 		}
 
+		@Override
 		public void print(float f) throws IOException {
 			trackContentLength(f);
-			delegate.print(f);
+			this.delegate.print(f);
 		}
 
+		@Override
 		public void print(int i) throws IOException {
 			trackContentLength(i);
-			delegate.print(i);
+			this.delegate.print(i);
 		}
 
+		@Override
 		public void print(long l) throws IOException {
 			trackContentLength(l);
-			delegate.print(l);
+			this.delegate.print(l);
 		}
 
+		@Override
 		public void print(String s) throws IOException {
 			trackContentLength(s);
-			delegate.print(s);
+			this.delegate.print(s);
 		}
 
+		@Override
 		public void println() throws IOException {
 			trackContentLengthLn();
-			delegate.println();
+			this.delegate.println();
 		}
 
+		@Override
 		public void println(boolean b) throws IOException {
 			trackContentLength(b);
 			trackContentLengthLn();
-			delegate.println(b);
+			this.delegate.println(b);
 		}
 
+		@Override
 		public void println(char c) throws IOException {
 			trackContentLength(c);
 			trackContentLengthLn();
-			delegate.println(c);
+			this.delegate.println(c);
 		}
 
+		@Override
 		public void println(double d) throws IOException {
 			trackContentLength(d);
 			trackContentLengthLn();
-			delegate.println(d);
+			this.delegate.println(d);
 		}
 
+		@Override
 		public void println(float f) throws IOException {
 			trackContentLength(f);
 			trackContentLengthLn();
-			delegate.println(f);
+			this.delegate.println(f);
 		}
 
+		@Override
 		public void println(int i) throws IOException {
 			trackContentLength(i);
 			trackContentLengthLn();
-			delegate.println(i);
+			this.delegate.println(i);
 		}
 
+		@Override
 		public void println(long l) throws IOException {
 			trackContentLength(l);
 			trackContentLengthLn();
-			delegate.println(l);
+			this.delegate.println(l);
 		}
 
+		@Override
 		public void println(String s) throws IOException {
 			trackContentLength(s);
 			trackContentLengthLn();
-			delegate.println(s);
+			this.delegate.println(s);
 		}
 
+		@Override
 		public void write(byte[] b) throws IOException {
 			trackContentLength(b);
-			delegate.write(b);
+			this.delegate.write(b);
 		}
 
+		@Override
 		public void write(byte[] b, int off, int len) throws IOException {
 			checkContentLength(len);
-			delegate.write(b, off, len);
+			this.delegate.write(b, off, len);
 		}
 
+		@Override
 		public String toString() {
-			return getClass().getName() + "[delegate=" + delegate.toString() + "]";
+			return getClass().getName() + "[delegate=" + this.delegate.toString() + "]";
 		}
 	}
 }

+ 45 - 9
web/src/test/java/org/springframework/security/web/header/HeaderWriterFilterTests.java

@@ -15,21 +15,32 @@
  */
 package org.springframework.security.web.header;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.Mockito.verify;
-
+import java.io.IOException;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.runners.MockitoJUnitRunner;
+
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
-import org.springframework.security.web.header.HeaderWriter;
-import org.springframework.security.web.header.HeaderWriterFilter;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.mockito.Mockito.verifyZeroInteractions;
 
 /**
  * Tests for the {@code HeadersFilter}
@@ -60,8 +71,8 @@ public class HeaderWriterFilterTests {
 	@Test
 	public void additionalHeadersShouldBeAddedToTheResponse() throws Exception {
 		List<HeaderWriter> headerWriters = new ArrayList<HeaderWriter>();
-		headerWriters.add(writer1);
-		headerWriters.add(writer2);
+		headerWriters.add(this.writer1);
+		headerWriters.add(this.writer2);
 
 		HeaderWriterFilter filter = new HeaderWriterFilter(headerWriters);
 
@@ -71,9 +82,34 @@ public class HeaderWriterFilterTests {
 
 		filter.doFilter(request, response, filterChain);
 
-		verify(writer1).writeHeaders(request, response);
-		verify(writer2).writeHeaders(request, response);
+		verify(this.writer1).writeHeaders(request, response);
+		verify(this.writer2).writeHeaders(request, response);
 		assertThat(filterChain.getRequest()).isEqualTo(request); // verify the filterChain
 																	// continued
 	}
+
+	// gh-2953
+	@Test
+	public void headersDelayed() throws Exception {
+		HeaderWriterFilter filter = new HeaderWriterFilter(
+				Arrays.<HeaderWriter>asList(this.writer1));
+
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		MockHttpServletResponse response = new MockHttpServletResponse();
+
+		filter.doFilter(request, response, new FilterChain() {
+			@Override
+			public void doFilter(ServletRequest request, ServletResponse response)
+					throws IOException, ServletException {
+				verifyZeroInteractions(HeaderWriterFilterTests.this.writer1);
+
+				response.flushBuffer();
+
+				verify(HeaderWriterFilterTests.this.writer1).writeHeaders(
+						any(HttpServletRequest.class), any(HttpServletResponse.class));
+			}
+		});
+
+		verifyNoMoreInteractions(this.writer1);
+	}
 }

+ 83 - 11
web/src/test/java/org/springframework/security/web/header/writers/CacheControlHeadersWriterTests.java

@@ -15,19 +15,32 @@
  */
 package org.springframework.security.web.header.writers;
 
-import static org.assertj.core.api.Assertions.assertThat;
-
 import java.util.Arrays;
 
+import javax.servlet.http.HttpServletResponse;
+
 import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.powermock.core.classloader.annotations.PrepareOnlyThisForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.util.ReflectionUtils;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.when;
+import static org.powermock.api.mockito.PowerMockito.spy;
 
 /**
  * @author Rob Winch
  *
  */
+@RunWith(PowerMockRunner.class)
+@PrepareOnlyThisForTest(ReflectionUtils.class)
 public class CacheControlHeadersWriterTests {
 
 	private MockHttpServletRequest request;
@@ -38,20 +51,79 @@ public class CacheControlHeadersWriterTests {
 
 	@Before
 	public void setup() {
-		request = new MockHttpServletRequest();
-		response = new MockHttpServletResponse();
-		writer = new CacheControlHeadersWriter();
+		this.request = new MockHttpServletRequest();
+		this.response = new MockHttpServletResponse();
+		this.writer = new CacheControlHeadersWriter();
 	}
 
 	@Test
 	public void writeHeaders() {
-		writer.writeHeaders(request, response);
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
+		assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
+				Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
+		assertThat(this.response.getHeaderValues("Pragma"))
+				.isEqualTo(Arrays.asList("no-cache"));
+		assertThat(this.response.getHeaderValues("Expires"))
+				.isEqualTo(Arrays.asList("0"));
+	}
+
+	@Test
+	public void writeHeadersServlet25() {
+		spy(ReflectionUtils.class);
+		when(ReflectionUtils.findMethod(HttpServletResponse.class, "getHeader",
+				String.class)).thenReturn(null);
+		this.response = spy(this.response);
+		doThrow(NoSuchMethodError.class).when(this.response).getHeader(anyString());
+		this.writer = new CacheControlHeadersWriter();
 
-		assertThat(response.getHeaderNames().size()).isEqualTo(3);
-		assertThat(response.getHeaderValues("Cache-Control")).isEqualTo(
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames().size()).isEqualTo(3);
+		assertThat(this.response.getHeaderValues("Cache-Control")).isEqualTo(
 				Arrays.asList("no-cache, no-store, max-age=0, must-revalidate"));
-		assertThat(response.getHeaderValues("Pragma")).isEqualTo(
-				Arrays.asList("no-cache"));
-		assertThat(response.getHeaderValues("Expires")).isEqualTo(Arrays.asList("0"));
+		assertThat(this.response.getHeaderValues("Pragma"))
+				.isEqualTo(Arrays.asList("no-cache"));
+		assertThat(this.response.getHeaderValues("Expires"))
+				.isEqualTo(Arrays.asList("0"));
+	}
+
+	// gh-2953
+	@Test
+	public void writeHeadersDisabledIfCacheControl() {
+		this.response.setHeader("Cache-Control", "max-age: 123");
+
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames()).hasSize(1);
+		assertThat(this.response.getHeaderValues("Cache-Control"))
+				.containsOnly("max-age: 123");
+		assertThat(this.response.getHeaderValue("Pragma")).isNull();
+		assertThat(this.response.getHeaderValue("Expires")).isNull();
+	}
+
+	@Test
+	public void writeHeadersDisabledIfPragma() {
+		this.response.setHeader("Pragma", "mock");
+
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames()).hasSize(1);
+		assertThat(this.response.getHeaderValues("Pragma")).containsOnly("mock");
+		assertThat(this.response.getHeaderValue("Expires")).isNull();
+		assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
+	}
+
+	@Test
+	public void writeHeadersDisabledIfExpires() {
+		this.response.setHeader("Expires", "mock");
+
+		this.writer.writeHeaders(this.request, this.response);
+
+		assertThat(this.response.getHeaderNames()).hasSize(1);
+		assertThat(this.response.getHeaderValues("Expires")).containsOnly("mock");
+		assertThat(this.response.getHeaderValue("Cache-Control")).isNull();
+		assertThat(this.response.getHeaderValue("Pragma")).isNull();
 	}
 }

+ 3 - 1
web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java → web/src/test/java/org/springframework/security/web/util/OnCommittedResponseWrapperTests.java

@@ -13,7 +13,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.springframework.security.web.context;
+package org.springframework.security.web.util;
 
 import java.io.IOException;
 import java.io.PrintWriter;
@@ -25,6 +25,8 @@ import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.runners.MockitoJUnitRunner;
 
+import org.springframework.security.web.util.OnCommittedResponseWrapper;
+
 import javax.servlet.ServletOutputStream;
 import javax.servlet.http.HttpServletResponse;