瀏覽代碼

SEC-2973: Add OnCommittedResponseWrapper

This ensures that Spring Session & Security's logic for performing
a save on the response being committed can easily be kept in synch.
Further this ensures that the SecurityContext is now persisted when
the response body meets the content length.
Rob Winch 10 年之前
父節點
當前提交
fcc9a34356

+ 549 - 0
web/src/main/java/org/springframework/security/web/context/OnCommittedResponseWrapper.java

@@ -0,0 +1,549 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.springframework.security.web.context;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import javax.servlet.ServletOutputStream;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpServletResponseWrapper;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Locale;
+
+/**
+ * Base class for response wrappers which encapsulate the logic for handling an event when the
+ * {@link javax.servlet.http.HttpServletResponse} is committed.
+ *
+ * @since 4.0.2
+ * @author Rob Winch
+ */
+abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper {
+    private final Log logger = LogFactory.getLog(getClass());
+
+    private boolean disableOnCommitted;
+
+    /**
+     * The Content-Length response header. If this is greater than 0, then once {@link #contentWritten} is larger than
+     * or equal the response is considered committed.
+     */
+    private long contentLength;
+
+    /**
+     * The size of data written to the response body.
+     */
+    private long contentWritten;
+
+    /**
+     * @param response the response to be wrapped
+     */
+    public OnCommittedResponseWrapper(HttpServletResponse response) {
+        super(response);
+    }
+
+    @Override
+    public void addHeader(String name, String value) {
+        if("Content-Length".equalsIgnoreCase(name)) {
+            setContentLength(Long.parseLong(value));
+        }
+        super.addHeader(name, value);
+    }
+
+    @Override
+    public void setContentLength(int len) {
+        setContentLength((long) len);
+        super.setContentLength(len);
+    }
+
+    private void setContentLength(long len) {
+        this.contentLength = len;
+        checkContentLength(0);
+    }
+
+    /**
+     * Invoke this method to disable invoking {@link OnCommittedResponseWrapper#onResponseCommitted()} when the {@link javax.servlet.http.HttpServletResponse} is
+     * committed. This can be useful in the event that Async Web Requests are
+     * made.
+     */
+    public void disableOnResponseCommitted() {
+        this.disableOnCommitted = true;
+    }
+
+    /**
+     * Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse} being committed
+     */
+    protected abstract void onResponseCommitted();
+
+    /**
+     * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
+     * superclass <code>sendError()</code>
+     */
+    @Override
+    public final void sendError(int sc) throws IOException {
+        doOnResponseCommitted();
+        super.sendError(sc);
+    }
+
+    /**
+     * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
+     * superclass <code>sendError()</code>
+     */
+    @Override
+    public final void sendError(int sc, String msg) throws IOException {
+        doOnResponseCommitted();
+        super.sendError(sc, msg);
+    }
+
+    /**
+     * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
+     * superclass <code>sendRedirect()</code>
+     */
+    @Override
+    public final void sendRedirect(String location) throws IOException {
+        doOnResponseCommitted();
+        super.sendRedirect(location);
+    }
+
+    /**
+     * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the calling
+     * <code>getOutputStream().close()</code> or <code>getOutputStream().flush()</code>
+     */
+    @Override
+    public ServletOutputStream getOutputStream() throws IOException {
+        return new SaveContextServletOutputStream(super.getOutputStream());
+    }
+
+    /**
+     * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
+     * <code>getWriter().close()</code> or <code>getWriter().flush()</code>
+     */
+    @Override
+    public PrintWriter getWriter() throws IOException {
+        return new SaveContextPrintWriter(super.getWriter());
+    }
+
+    /**
+     * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the
+     * superclass <code>flushBuffer()</code>
+     */
+    @Override
+    public void flushBuffer() throws IOException {
+        doOnResponseCommitted();
+        super.flushBuffer();
+    }
+
+    private void trackContentLength(boolean content) {
+        checkContentLength(content ? 4 : 5); // TODO Localization
+    }
+
+    private void trackContentLength(char content) {
+        checkContentLength(1);
+    }
+
+    private void trackContentLength(Object content) {
+        trackContentLength(String.valueOf(content));
+    }
+
+    private void trackContentLength(byte[] content) {
+        checkContentLength(content == null ? 0 : content.length);
+    }
+
+    private void trackContentLength(char[] content) {
+        checkContentLength(content == null ? 0 : content.length);
+    }
+
+    private void trackContentLength(int content) {
+        trackContentLength(String.valueOf(content));
+    }
+
+    private void trackContentLength(float content) {
+        trackContentLength(String.valueOf(content));
+    }
+
+    private void trackContentLength(double content) {
+        trackContentLength(String.valueOf(content));
+    }
+
+    private void trackContentLengthLn() {
+        trackContentLength("\r\n");
+    }
+
+    private void trackContentLength(String content) {
+        checkContentLength(content.length());
+    }
+
+    /**
+     * Adds the contentLengthToWrite to the total contentWritten size and checks to see if the response should be
+     * written.
+     *
+     * @param contentLengthToWrite the size of the content that is about to be written.
+     */
+    private void checkContentLength(long contentLengthToWrite) {
+        contentWritten += contentLengthToWrite;
+        boolean isBodyFullyWritten = contentLength > 0  && contentWritten >= contentLength;
+        int bufferSize = getBufferSize();
+        boolean requiresFlush = bufferSize > 0 && contentWritten >= bufferSize;
+        if(isBodyFullyWritten || requiresFlush) {
+            doOnResponseCommitted();
+        }
+    }
+
+    /**
+     * Calls <code>onResponseCommmitted()</code> with the current contents as long as
+     * {@link #disableOnResponseCommitted()()} was not invoked.
+     */
+    private void doOnResponseCommitted() {
+        if(!disableOnCommitted) {
+            onResponseCommitted();
+            disableOnResponseCommitted();
+        } else if(logger.isDebugEnabled()){
+            logger.debug("Skip invoking on");
+        }
+    }
+
+    /**
+     * Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the prior to methods that commit the response. We delegate all methods
+     * to the original {@link java.io.PrintWriter} to ensure that the behavior is as close to the original {@link java.io.PrintWriter}
+     * as possible. See SEC-2039
+     * @author Rob Winch
+     */
+    private class SaveContextPrintWriter extends PrintWriter {
+        private final PrintWriter delegate;
+
+        public SaveContextPrintWriter(PrintWriter delegate) {
+            super(delegate);
+            this.delegate = delegate;
+        }
+
+        public void flush() {
+            doOnResponseCommitted();
+            delegate.flush();
+        }
+
+        public void close() {
+            doOnResponseCommitted();
+            delegate.close();
+        }
+
+        public int hashCode() {
+            return delegate.hashCode();
+        }
+
+        public boolean equals(Object obj) {
+            return delegate.equals(obj);
+        }
+
+        public String toString() {
+            return getClass().getName() + "[delegate=" + delegate.toString() + "]";
+        }
+
+        public boolean checkError() {
+            return delegate.checkError();
+        }
+
+        public void write(int c) {
+            trackContentLength(c);
+            delegate.write(c);
+        }
+
+        public void write(char[] buf, int off, int len) {
+            checkContentLength(len);
+            delegate.write(buf, off, len);
+        }
+
+        public void write(char[] buf) {
+            trackContentLength(buf);
+            delegate.write(buf);
+        }
+
+        public void write(String s, int off, int len) {
+            checkContentLength(len);
+            delegate.write(s, off, len);
+        }
+
+        public void write(String s) {
+            trackContentLength(s);
+            delegate.write(s);
+        }
+
+        public void print(boolean b) {
+            trackContentLength(b);
+            delegate.print(b);
+        }
+
+        public void print(char c) {
+            trackContentLength(c);
+            delegate.print(c);
+        }
+
+        public void print(int i) {
+            trackContentLength(i);
+            delegate.print(i);
+        }
+
+        public void print(long l) {
+            trackContentLength(l);
+            delegate.print(l);
+        }
+
+        public void print(float f) {
+            trackContentLength(f);
+            delegate.print(f);
+        }
+
+        public void print(double d) {
+            trackContentLength(d);
+            delegate.print(d);
+        }
+
+        public void print(char[] s) {
+            trackContentLength(s);
+            delegate.print(s);
+        }
+
+        public void print(String s) {
+            trackContentLength(s);
+            delegate.print(s);
+        }
+
+        public void print(Object obj) {
+            trackContentLength(obj);
+            delegate.print(obj);
+        }
+
+        public void println() {
+            trackContentLengthLn();
+            delegate.println();
+        }
+
+        public void println(boolean x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(char x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(int x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(long x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(float x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(double x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(char[] x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(String x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public void println(Object x) {
+            trackContentLength(x);
+            trackContentLengthLn();
+            delegate.println(x);
+        }
+
+        public PrintWriter printf(String format, Object... args) {
+            return delegate.printf(format, args);
+        }
+
+        public PrintWriter printf(Locale l, String format, Object... args) {
+            return delegate.printf(l, format, args);
+        }
+
+        public PrintWriter format(String format, Object... args) {
+            return delegate.format(format, args);
+        }
+
+        public PrintWriter format(Locale l, String format, Object... args) {
+            return delegate.format(l, format, args);
+        }
+
+        public PrintWriter append(CharSequence csq) {
+            checkContentLength(csq.length());
+            return delegate.append(csq);
+        }
+
+        public PrintWriter append(CharSequence csq, int start, int end) {
+            checkContentLength(end - start);
+            return delegate.append(csq, start, end);
+        }
+
+        public PrintWriter append(char c) {
+            trackContentLength(c);
+            return delegate.append(c);
+        }
+    }
+
+    /**
+     * Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling methods that commit the response. We delegate all methods
+     * to the original {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close to the original {@link javax.servlet.ServletOutputStream}
+     * as possible. See SEC-2039
+     *
+     * @author Rob Winch
+     */
+    private class SaveContextServletOutputStream extends ServletOutputStream {
+        private final ServletOutputStream delegate;
+
+        public SaveContextServletOutputStream(ServletOutputStream delegate) {
+            this.delegate = delegate;
+        }
+
+        public void write(int b) throws IOException {
+            trackContentLength(b);
+            this.delegate.write(b);
+        }
+
+        public void flush() throws IOException {
+            doOnResponseCommitted();
+            delegate.flush();
+        }
+
+        public void close() throws IOException {
+            doOnResponseCommitted();
+            delegate.close();
+        }
+
+        public int hashCode() {
+            return delegate.hashCode();
+        }
+
+        public boolean equals(Object obj) {
+            return delegate.equals(obj);
+        }
+
+        public void print(boolean b) throws IOException {
+            trackContentLength(b);
+            delegate.print(b);
+        }
+
+        public void print(char c) throws IOException {
+            trackContentLength(c);
+            delegate.print(c);
+        }
+
+        public void print(double d) throws IOException {
+            trackContentLength(d);
+            delegate.print(d);
+        }
+
+        public void print(float f) throws IOException {
+            trackContentLength(f);
+            delegate.print(f);
+        }
+
+        public void print(int i) throws IOException {
+            trackContentLength(i);
+            delegate.print(i);
+        }
+
+        public void print(long l) throws IOException {
+            trackContentLength(l);
+            delegate.print(l);
+        }
+
+        public void print(String s) throws IOException {
+            trackContentLength(s);
+            delegate.print(s);
+        }
+
+        public void println() throws IOException {
+            trackContentLengthLn();
+            delegate.println();
+        }
+
+        public void println(boolean b) throws IOException {
+            trackContentLength(b);
+            trackContentLengthLn();
+            delegate.println(b);
+        }
+
+        public void println(char c) throws IOException {
+            trackContentLength(c);
+            trackContentLengthLn();
+            delegate.println(c);
+        }
+
+        public void println(double d) throws IOException {
+            trackContentLength(d);
+            trackContentLengthLn();
+            delegate.println(d);
+        }
+
+        public void println(float f) throws IOException {
+            trackContentLength(f);
+            trackContentLengthLn();
+            delegate.println(f);
+        }
+
+        public void println(int i) throws IOException {
+            trackContentLength(i);
+            trackContentLengthLn();
+            delegate.println(i);
+        }
+
+        public void println(long l) throws IOException {
+            trackContentLength(l);
+            trackContentLengthLn();
+            delegate.println(l);
+        }
+
+        public void println(String s) throws IOException {
+            trackContentLength(s);
+            trackContentLengthLn();
+            delegate.println(s);
+        }
+
+        public void write(byte[] b) throws IOException {
+            trackContentLength(b);
+            delegate.write(b);
+        }
+
+        public void write(byte[] b, int off, int len) throws IOException {
+            checkContentLength(len);
+            delegate.write(b, off, len);
+        }
+
+        public String toString() {
+            return getClass().getName() + "[delegate=" + delegate.toString() + "]";
+        }
+    }
+}

+ 28 - 364
web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java

@@ -12,13 +12,7 @@
  */
 package org.springframework.security.web.context;
 
-import java.io.IOException;
-import java.io.PrintWriter;
-import java.util.Locale;
-
-import javax.servlet.ServletOutputStream;
 import javax.servlet.http.HttpServletResponse;
-import javax.servlet.http.HttpServletResponseWrapper;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -26,11 +20,13 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 
 /**
- * Base class for response wrappers which encapsulate the logic for storing a security context and which store the
- * <code>SecurityContext</code> when a <code>sendError()</code>, <code>sendRedirect</code>,
- * <code>getOutputStream().close()</code>, <code>getOutputStream().flush()</code>, <code>getWriter().close()</code>, or
- * <code>getWriter().flush()</code> happens on the same thread that this
- * {@link SaveContextOnUpdateOrErrorResponseWrapper} was created. See issue SEC-398 and SEC-2005.
+ * Base class for response wrappers which encapsulate the logic for storing a security
+ * context and which store the <code>SecurityContext</code> when a
+ * <code>sendError()</code>, <code>sendRedirect</code>,
+ * <code>getOutputStream().close()</code>, <code>getOutputStream().flush()</code>,
+ * <code>getWriter().close()</code>, or <code>getWriter().flush()</code> happens on the
+ * same thread that this {@link SaveContextOnUpdateOrErrorResponseWrapper} was created.
+ * See issue SEC-398 and SEC-2005.
  * <p>
  * Sub-classes should implement the {@link #saveContext(SecurityContext context)} method.
  * <p>
@@ -41,33 +37,35 @@ import org.springframework.security.core.context.SecurityContextHolder;
  * @author Rob Winch
  * @since 3.0
  */
-public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServletResponseWrapper {
+public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends
+        OnCommittedResponseWrapper {
     private final Log logger = LogFactory.getLog(getClass());
 
-    private boolean disableSaveOnResponseCommitted;
 
     private boolean contextSaved = false;
     /* See SEC-1052 */
     private final boolean disableUrlRewriting;
 
     /**
-     * @param response              the response to be wrapped
-     * @param disableUrlRewriting   turns the URL encoding methods into null operations, preventing the use
-     *                              of URL rewriting to add the session identifier as a URL parameter.
+     * @param response the response to be wrapped
+     * @param disableUrlRewriting turns the URL encoding methods into null operations,
+     * preventing the use of URL rewriting to add the session identifier as a URL
+     * parameter.
      */
-    public SaveContextOnUpdateOrErrorResponseWrapper(HttpServletResponse response, boolean disableUrlRewriting) {
+    public SaveContextOnUpdateOrErrorResponseWrapper(HttpServletResponse response,
+            boolean disableUrlRewriting) {
         super(response);
         this.disableUrlRewriting = disableUrlRewriting;
     }
 
     /**
-     * Invoke this method to disable automatic saving of the
-     * {@link SecurityContext} when the {@link HttpServletResponse} is
-     * committed. This can be useful in the event that Async Web Requests are
-     * made which may no longer contain the {@link SecurityContext} on it.
+     * Invoke this method to disable automatic saving of the {@link SecurityContext} when
+     * the {@link HttpServletResponse} is committed. This can be useful in the event that
+     * Async Web Requests are made which may no longer contain the {@link SecurityContext}
+     * on it.
      */
     public void disableSaveOnResponseCommitted() {
-        this.disableSaveOnResponseCommitted = true;
+        disableOnResponseCommitted();
     }
 
     /**
@@ -77,76 +75,15 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServ
      */
     protected abstract void saveContext(SecurityContext context);
 
-    /**
-     * Makes sure the session is updated before calling the
-     * superclass <code>sendError()</code>
-     */
-    @Override
-    public final void sendError(int sc) throws IOException {
-        doSaveContext();
-        super.sendError(sc);
-    }
-
-    /**
-     * Makes sure the session is updated before calling the
-     * superclass <code>sendError()</code>
-     */
-    @Override
-    public final void sendError(int sc, String msg) throws IOException {
-        doSaveContext();
-        super.sendError(sc, msg);
-    }
-
-    /**
-     * Makes sure the context is stored before calling the
-     * superclass <code>sendRedirect()</code>
-     */
-    @Override
-    public final void sendRedirect(String location) throws IOException {
-        doSaveContext();
-        super.sendRedirect(location);
-    }
-
-    /**
-     * Makes sure the context is stored before calling <code>getOutputStream().close()</code> or
-     * <code>getOutputStream().flush()</code>
-     */
-    @Override
-    public ServletOutputStream getOutputStream() throws IOException {
-        return new SaveContextServletOutputStream(super.getOutputStream());
-    }
-
-    /**
-     * Makes sure the context is stored before calling <code>getWriter().close()</code> or
-     * <code>getWriter().flush()</code>
-     */
-    @Override
-    public PrintWriter getWriter() throws IOException {
-        return new SaveContextPrintWriter(super.getWriter());
-    }
-
-    /**
-     * Makes sure the context is stored before calling the
-     * superclass <code>flushBuffer()</code>
-     */
-    @Override
-    public void flushBuffer() throws IOException {
-        doSaveContext();
-        super.flushBuffer();
-    }
-
     /**
      * Calls <code>saveContext()</code> with the current contents of the
-     * <tt>SecurityContextHolder</tt> as long as
-     * {@link #disableSaveOnResponseCommitted()()} was not invoked.
+     * <tt>SecurityContextHolder</tt> as long as {@link #disableSaveOnResponseCommitted()
+     * ()} was not invoked.
      */
-    private void doSaveContext() {
-        if(!disableSaveOnResponseCommitted) {
-            saveContext(SecurityContextHolder.getContext());
-            contextSaved = true;
-        } else if(logger.isDebugEnabled()){
-            logger.debug("Skip saving SecurityContext since saving on response commited is disabled");
-        }
+    @Override
+    protected void onResponseCommitted() {
+        saveContext(SecurityContextHolder.getContext());
+        contextSaved = true;
     }
 
     @Override
@@ -182,283 +119,10 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServ
     }
 
     /**
-     * Tells if the response wrapper has called <code>saveContext()</code> because of this wrapper.
+     * Tells if the response wrapper has called <code>saveContext()</code> because of this
+     * wrapper.
      */
     public final boolean isContextSaved() {
         return contextSaved;
     }
-
-    /**
-     * Ensures the {@link SecurityContext} is updated prior to methods that commit the response. We delegate all methods
-     * to the original {@link PrintWriter} to ensure that the behavior is as close to the original {@link PrintWriter}
-     * as possible. See SEC-2039
-     * @author Rob Winch
-     */
-    private class SaveContextPrintWriter extends PrintWriter {
-        private final PrintWriter delegate;
-
-        public SaveContextPrintWriter(PrintWriter delegate) {
-            super(delegate);
-            this.delegate = delegate;
-        }
-
-        public void flush() {
-            doSaveContext();
-            delegate.flush();
-        }
-
-        public void close() {
-            doSaveContext();
-            delegate.close();
-        }
-
-        public int hashCode() {
-            return delegate.hashCode();
-        }
-
-        public boolean equals(Object obj) {
-            return delegate.equals(obj);
-        }
-
-        public String toString() {
-            return getClass().getName() + "[delegate=" + delegate.toString() + "]";
-        }
-
-        public boolean checkError() {
-            return delegate.checkError();
-        }
-
-        public void write(int c) {
-            delegate.write(c);
-        }
-
-        public void write(char[] buf, int off, int len) {
-            delegate.write(buf, off, len);
-        }
-
-        public void write(char[] buf) {
-            delegate.write(buf);
-        }
-
-        public void write(String s, int off, int len) {
-            delegate.write(s, off, len);
-        }
-
-        public void write(String s) {
-            delegate.write(s);
-        }
-
-        public void print(boolean b) {
-            delegate.print(b);
-        }
-
-        public void print(char c) {
-            delegate.print(c);
-        }
-
-        public void print(int i) {
-            delegate.print(i);
-        }
-
-        public void print(long l) {
-            delegate.print(l);
-        }
-
-        public void print(float f) {
-            delegate.print(f);
-        }
-
-        public void print(double d) {
-            delegate.print(d);
-        }
-
-        public void print(char[] s) {
-            delegate.print(s);
-        }
-
-        public void print(String s) {
-            delegate.print(s);
-        }
-
-        public void print(Object obj) {
-            delegate.print(obj);
-        }
-
-        public void println() {
-            delegate.println();
-        }
-
-        public void println(boolean x) {
-            delegate.println(x);
-        }
-
-        public void println(char x) {
-            delegate.println(x);
-        }
-
-        public void println(int x) {
-            delegate.println(x);
-        }
-
-        public void println(long x) {
-            delegate.println(x);
-        }
-
-        public void println(float x) {
-            delegate.println(x);
-        }
-
-        public void println(double x) {
-            delegate.println(x);
-        }
-
-        public void println(char[] x) {
-            delegate.println(x);
-        }
-
-        public void println(String x) {
-            delegate.println(x);
-        }
-
-        public void println(Object x) {
-            delegate.println(x);
-        }
-
-        public PrintWriter printf(String format, Object... args) {
-            return delegate.printf(format, args);
-        }
-
-        public PrintWriter printf(Locale l, String format, Object... args) {
-            return delegate.printf(l, format, args);
-        }
-
-        public PrintWriter format(String format, Object... args) {
-            return delegate.format(format, args);
-        }
-
-        public PrintWriter format(Locale l, String format, Object... args) {
-            return delegate.format(l, format, args);
-        }
-
-        public PrintWriter append(CharSequence csq) {
-            return delegate.append(csq);
-        }
-
-        public PrintWriter append(CharSequence csq, int start, int end) {
-            return delegate.append(csq, start, end);
-        }
-
-        public PrintWriter append(char c) {
-            return delegate.append(c);
-        }
-    }
-
-    /**
-     * Ensures the {@link SecurityContext} is updated prior to methods that commit the response. We delegate all methods
-     * to the original {@link ServletOutputStream} to ensure that the behavior is as close to the original {@link ServletOutputStream}
-     * as possible. See SEC-2039
-     *
-     * @author Rob Winch
-     */
-    private class SaveContextServletOutputStream extends ServletOutputStream {
-        private final ServletOutputStream delegate;
-
-        public SaveContextServletOutputStream(ServletOutputStream delegate) {
-            this.delegate = delegate;
-        }
-
-        public void write(int b) throws IOException {
-            this.delegate.write(b);
-        }
-
-        public void flush() throws IOException {
-            doSaveContext();
-            delegate.flush();
-        }
-
-        public void close() throws IOException {
-            doSaveContext();
-            delegate.close();
-        }
-
-        public int hashCode() {
-            return delegate.hashCode();
-        }
-
-        public boolean equals(Object obj) {
-            return delegate.equals(obj);
-        }
-
-        public void print(boolean b) throws IOException {
-            delegate.print(b);
-        }
-
-        public void print(char c) throws IOException {
-            delegate.print(c);
-        }
-
-        public void print(double d) throws IOException {
-            delegate.print(d);
-        }
-
-        public void print(float f) throws IOException {
-            delegate.print(f);
-        }
-
-        public void print(int i) throws IOException {
-            delegate.print(i);
-        }
-
-        public void print(long l) throws IOException {
-            delegate.print(l);
-        }
-
-        public void print(String arg0) throws IOException {
-            delegate.print(arg0);
-        }
-
-        public void println() throws IOException {
-            delegate.println();
-        }
-
-        public void println(boolean b) throws IOException {
-            delegate.println(b);
-        }
-
-        public void println(char c) throws IOException {
-            delegate.println(c);
-        }
-
-        public void println(double d) throws IOException {
-            delegate.println(d);
-        }
-
-        public void println(float f) throws IOException {
-            delegate.println(f);
-        }
-
-        public void println(int i) throws IOException {
-            delegate.println(i);
-        }
-
-        public void println(long l) throws IOException {
-            delegate.println(l);
-        }
-
-        public void println(String s) throws IOException {
-            delegate.println(s);
-        }
-
-        public void write(byte[] b) throws IOException {
-            delegate.write(b);
-        }
-
-        public void write(byte[] b, int off, int len) throws IOException {
-            delegate.write(b, off, len);
-        }
-
-        public String toString() {
-            return getClass().getName() + "[delegate=" + delegate.toString() + "]";
-        }
-    }
 }

+ 1122 - 0
web/src/test/java/org/springframework/security/web/context/OnCommittedResponseWrapperTests.java

@@ -0,0 +1,1122 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.context;
+
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.util.Locale;
+
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+
+import javax.servlet.ServletOutputStream;
+import javax.servlet.http.HttpServletResponse;
+
+import static org.fest.assertions.Assertions.assertThat;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
+
+@RunWith(MockitoJUnitRunner.class)
+public class OnCommittedResponseWrapperTests {
+    private static final String NL = "\r\n";
+
+    @Mock
+    HttpServletResponse delegate;
+    @Mock
+    PrintWriter writer;
+    @Mock
+    ServletOutputStream out;
+
+    OnCommittedResponseWrapper response;
+
+    boolean committed;
+
+    @Before
+    public void setup() throws Exception {
+        response = new OnCommittedResponseWrapper(delegate) {
+            @Override
+            protected void onResponseCommitted() {
+                committed = true;
+            }
+        };
+        when(delegate.getWriter()).thenReturn(writer);
+        when(delegate.getOutputStream()).thenReturn(out);
+    }
+
+
+    // --- printwriter
+
+    @Test
+    public void printWriterHashCode() throws Exception {
+        int expected = writer.hashCode();
+
+        assertThat(response.getWriter().hashCode()).isEqualTo(expected);
+    }
+
+    @Test
+    public void printWriterCheckError() throws Exception {
+        boolean expected = true;
+        when(writer.checkError()).thenReturn(expected);
+
+        assertThat(response.getWriter().checkError()).isEqualTo(expected);
+    }
+
+    @Test
+    public void printWriterWriteInt() throws Exception {
+        int expected = 1;
+
+        response.getWriter().write(expected);
+
+        verify(writer).write(expected);
+    }
+
+    @Test
+    public void printWriterWriteCharIntInt() throws Exception {
+        char[] buff = new char[0];
+        int off = 2;
+        int len = 3;
+
+        response.getWriter().write(buff,off,len);
+
+        verify(writer).write(buff,off,len);
+    }
+
+    @Test
+    public void printWriterWriteChar() throws Exception {
+        char[] buff = new char[0];
+
+        response.getWriter().write(buff);
+
+        verify(writer).write(buff);
+    }
+
+    @Test
+    public void printWriterWriteStringIntInt() throws Exception {
+        String s = "";
+        int off = 2;
+        int len = 3;
+
+        response.getWriter().write(s,off,len);
+
+        verify(writer).write(s,off,len);
+    }
+
+    @Test
+    public void printWriterWriteString() throws Exception {
+        String s = "";
+
+        response.getWriter().write(s);
+
+        verify(writer).write(s);
+    }
+
+    @Test
+    public void printWriterPrintBoolean() throws Exception {
+        boolean b = true;
+
+        response.getWriter().print(b);
+
+        verify(writer).print(b);
+    }
+
+    @Test
+    public void printWriterPrintChar() throws Exception {
+        char c = 1;
+
+        response.getWriter().print(c);
+
+        verify(writer).print(c);
+    }
+
+    @Test
+    public void printWriterPrintInt() throws Exception {
+        int i = 1;
+
+        response.getWriter().print(i);
+
+        verify(writer).print(i);
+    }
+
+    @Test
+    public void printWriterPrintLong() throws Exception {
+        long l = 1;
+
+        response.getWriter().print(l);
+
+        verify(writer).print(l);
+    }
+
+    @Test
+    public void printWriterPrintFloat() throws Exception {
+        float f = 1;
+
+        response.getWriter().print(f);
+
+        verify(writer).print(f);
+    }
+
+    @Test
+    public void printWriterPrintDouble() throws Exception {
+        double x = 1;
+
+        response.getWriter().print(x);
+
+        verify(writer).print(x);
+    }
+
+    @Test
+    public void printWriterPrintCharArray() throws Exception {
+        char[] x = new char[0];
+
+        response.getWriter().print(x);
+
+        verify(writer).print(x);
+    }
+
+    @Test
+    public void printWriterPrintString() throws Exception {
+        String x = "1";
+
+        response.getWriter().print(x);
+
+        verify(writer).print(x);
+    }
+
+    @Test
+    public void printWriterPrintObject() throws Exception {
+        Object x = "1";
+
+        response.getWriter().print(x);
+
+        verify(writer).print(x);
+    }
+
+    @Test
+    public void printWriterPrintln() throws Exception {
+        response.getWriter().println();
+
+        verify(writer).println();
+    }
+
+    @Test
+    public void printWriterPrintlnBoolean() throws Exception {
+        boolean b = true;
+
+        response.getWriter().println(b);
+
+        verify(writer).println(b);
+    }
+
+    @Test
+    public void printWriterPrintlnChar() throws Exception {
+        char c = 1;
+
+        response.getWriter().println(c);
+
+        verify(writer).println(c);
+    }
+
+    @Test
+    public void printWriterPrintlnInt() throws Exception {
+        int i = 1;
+
+        response.getWriter().println(i);
+
+        verify(writer).println(i);
+    }
+
+    @Test
+    public void printWriterPrintlnLong() throws Exception {
+        long l = 1;
+
+        response.getWriter().println(l);
+
+        verify(writer).println(l);
+    }
+
+    @Test
+    public void printWriterPrintlnFloat() throws Exception {
+        float f = 1;
+
+        response.getWriter().println(f);
+
+        verify(writer).println(f);
+    }
+
+    @Test
+    public void printWriterPrintlnDouble() throws Exception {
+        double x = 1;
+
+        response.getWriter().println(x);
+
+        verify(writer).println(x);
+    }
+
+    @Test
+    public void printWriterPrintlnCharArray() throws Exception {
+        char[] x = new char[0];
+
+        response.getWriter().println(x);
+
+        verify(writer).println(x);
+    }
+
+    @Test
+    public void printWriterPrintlnString() throws Exception {
+        String x = "1";
+
+        response.getWriter().println(x);
+
+        verify(writer).println(x);
+    }
+
+    @Test
+    public void printWriterPrintlnObject() throws Exception {
+        Object x = "1";
+
+        response.getWriter().println(x);
+
+        verify(writer).println(x);
+    }
+
+    @Test
+    public void printWriterPrintfStringObjectVargs() throws Exception {
+        String format = "format";
+        Object[] args = new Object[] { "1" };
+
+        response.getWriter().printf(format, args);
+
+        verify(writer).printf(format, args);
+    }
+
+    @Test
+    public void printWriterPrintfLocaleStringObjectVargs() throws Exception {
+        Locale l = Locale.US;
+        String format = "format";
+        Object[] args = new Object[] { "1" };
+
+        response.getWriter().printf(l, format, args);
+
+        verify(writer).printf(l, format, args);
+    }
+
+    @Test
+    public void printWriterFormatStringObjectVargs() throws Exception {
+        String format = "format";
+        Object[] args = new Object[] { "1" };
+
+        response.getWriter().format(format, args);
+
+        verify(writer).format(format, args);
+    }
+
+    @Test
+    public void printWriterFormatLocaleStringObjectVargs() throws Exception {
+        Locale l = Locale.US;
+        String format = "format";
+        Object[] args = new Object[] { "1" };
+
+        response.getWriter().format(l, format, args);
+
+        verify(writer).format(l, format, args);
+    }
+
+
+    @Test
+    public void printWriterAppendCharSequence() throws Exception {
+        String x = "a";
+
+        response.getWriter().append(x);
+
+        verify(writer).append(x);
+    }
+
+    @Test
+    public void printWriterAppendCharSequenceIntInt() throws Exception {
+        String x = "abcdef";
+        int start = 1;
+        int end = 3;
+
+        response.getWriter().append(x, start, end);
+
+        verify(writer).append(x, start, end);
+    }
+
+
+    @Test
+    public void printWriterAppendChar() throws Exception {
+        char x = 1;
+
+        response.getWriter().append(x);
+
+        verify(writer).append(x);
+    }
+
+    // servletoutputstream
+
+
+    @Test
+    public void outputStreamHashCode() throws Exception {
+        int expected = out.hashCode();
+
+        assertThat(response.getOutputStream().hashCode()).isEqualTo(expected);
+    }
+
+    @Test
+    public void outputStreamWriteInt() throws Exception {
+        int expected = 1;
+
+        response.getOutputStream().write(expected);
+
+        verify(out).write(expected);
+    }
+
+    @Test
+    public void outputStreamWriteByte() throws Exception {
+        byte[] expected = new byte[0];
+
+        response.getOutputStream().write(expected);
+
+        verify(out).write(expected);
+    }
+
+    @Test
+    public void outputStreamWriteByteIntInt() throws Exception {
+        int start = 1;
+        int end = 2;
+        byte[] expected = new byte[0];
+
+        response.getOutputStream().write(expected, start, end);
+
+        verify(out).write(expected, start, end);
+    }
+
+    @Test
+    public void outputStreamPrintBoolean() throws Exception {
+        boolean b = true;
+
+        response.getOutputStream().print(b);
+
+        verify(out).print(b);
+    }
+
+    @Test
+    public void outputStreamPrintChar() throws Exception {
+        char c = 1;
+
+        response.getOutputStream().print(c);
+
+        verify(out).print(c);
+    }
+
+    @Test
+    public void outputStreamPrintInt() throws Exception {
+        int i = 1;
+
+        response.getOutputStream().print(i);
+
+        verify(out).print(i);
+    }
+
+    @Test
+    public void outputStreamPrintLong() throws Exception {
+        long l = 1;
+
+        response.getOutputStream().print(l);
+
+        verify(out).print(l);
+    }
+
+    @Test
+    public void outputStreamPrintFloat() throws Exception {
+        float f = 1;
+
+        response.getOutputStream().print(f);
+
+        verify(out).print(f);
+    }
+
+    @Test
+    public void outputStreamPrintDouble() throws Exception {
+        double x = 1;
+
+        response.getOutputStream().print(x);
+
+        verify(out).print(x);
+    }
+
+    @Test
+    public void outputStreamPrintString() throws Exception {
+        String x = "1";
+
+        response.getOutputStream().print(x);
+
+        verify(out).print(x);
+    }
+
+    @Test
+    public void outputStreamPrintln() throws Exception {
+        response.getOutputStream().println();
+
+        verify(out).println();
+    }
+
+    @Test
+    public void outputStreamPrintlnBoolean() throws Exception {
+        boolean b = true;
+
+        response.getOutputStream().println(b);
+
+        verify(out).println(b);
+    }
+
+    @Test
+    public void outputStreamPrintlnChar() throws Exception {
+        char c = 1;
+
+        response.getOutputStream().println(c);
+
+        verify(out).println(c);
+    }
+
+    @Test
+    public void outputStreamPrintlnInt() throws Exception {
+        int i = 1;
+
+        response.getOutputStream().println(i);
+
+        verify(out).println(i);
+    }
+
+    @Test
+    public void outputStreamPrintlnLong() throws Exception {
+        long l = 1;
+
+        response.getOutputStream().println(l);
+
+        verify(out).println(l);
+    }
+
+    @Test
+    public void outputStreamPrintlnFloat() throws Exception {
+        float f = 1;
+
+        response.getOutputStream().println(f);
+
+        verify(out).println(f);
+    }
+
+    @Test
+    public void outputStreamPrintlnDouble() throws Exception {
+        double x = 1;
+
+        response.getOutputStream().println(x);
+
+        verify(out).println(x);
+    }
+
+    @Test
+    public void outputStreamPrintlnString() throws Exception {
+        String x = "1";
+
+        response.getOutputStream().println(x);
+
+        verify(out).println(x);
+    }
+
+    // The amount of content specified in the setContentLength method of the response
+    // has been greater than zero and has been written to the response.
+
+    @Test
+    public void contentLengthPrintWriterWriteIntCommits() throws Exception {
+        int expected = 1;
+        response.setContentLength(String.valueOf(expected).length());
+
+        response.getWriter().write(expected);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterWriteIntMultiDigitCommits() throws Exception {
+        int expected = 10000;
+        response.setContentLength(String.valueOf(expected).length());
+
+        response.getWriter().write(expected);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPlus1PrintWriterWriteIntMultiDigitCommits() throws Exception {
+        int expected = 10000;
+        response.setContentLength(String.valueOf(expected).length() + 1);
+
+        response.getWriter().write(expected);
+
+        assertThat(committed).isFalse();
+
+        response.getWriter().write(1);
+
+        assertThat(committed).isTrue();
+    }
+
+
+    @Test
+    public void contentLengthPrintWriterWriteCharIntIntCommits() throws Exception {
+        char[] buff = new char[0];
+        int off = 2;
+        int len = 3;
+        response.setContentLength(3);
+
+        response.getWriter().write(buff,off,len);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterWriteCharCommits() throws Exception {
+        char[] buff = new char[4];
+        response.setContentLength(buff.length);
+
+        response.getWriter().write(buff);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterWriteStringIntIntCommits() throws Exception {
+        String s = "";
+        int off = 2;
+        int len = 3;
+        response.setContentLength(3);
+
+        response.getWriter().write(s,off,len);
+
+        assertThat(committed).isTrue();
+    }
+
+
+    @Test
+    public void contentLengthPrintWriterWriteStringCommits() throws IOException {
+        String body = "something";
+        response.setContentLength(body.length());
+
+        response.getWriter().write(body);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void printWriterWriteStringContentLengthCommits() throws IOException {
+        String body = "something";
+        response.getWriter().write(body);
+
+        response.setContentLength(body.length());
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void printWriterWriteStringDoesNotCommit() throws IOException {
+        String body = "something";
+
+        response.getWriter().write(body);
+
+        assertThat(committed).isFalse();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintBooleanCommits() throws Exception {
+        boolean b = true;
+        response.setContentLength(1);
+
+        response.getWriter().print(b);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintCharCommits() throws Exception {
+        char c = 1;
+        response.setContentLength(1);
+
+        response.getWriter().print(c);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintIntCommits() throws Exception {
+        int i = 1234;
+        response.setContentLength(String.valueOf(i).length());
+
+        response.getWriter().print(i);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintLongCommits() throws Exception {
+        long l = 12345;
+        response.setContentLength(String.valueOf(l).length());
+
+        response.getWriter().print(l);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintFloatCommits() throws Exception {
+        float f = 12345;
+        response.setContentLength(String.valueOf(f).length());
+
+        response.getWriter().print(f);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintDoubleCommits() throws Exception {
+        double x = 1.2345;
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getWriter().print(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintCharArrayCommits() throws Exception {
+        char[] x = new char[10];
+        response.setContentLength(x.length);
+
+        response.getWriter().print(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintStringCommits() throws Exception {
+        String x = "12345";
+        response.setContentLength(x.length());
+
+        response.getWriter().print(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintObjectCommits() throws Exception {
+        Object x = "12345";
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getWriter().print(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnCommits() throws Exception {
+        response.setContentLength(NL.length());
+
+        response.getWriter().println();
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnBooleanCommits() throws Exception {
+        boolean b = true;
+        response.setContentLength(1);
+
+        response.getWriter().println(b);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnCharCommits() throws Exception {
+        char c = 1;
+        response.setContentLength(1);
+
+        response.getWriter().println(c);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnIntCommits() throws Exception {
+        int i = 12345;
+        response.setContentLength(String.valueOf(i).length());
+
+        response.getWriter().println(i);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnLongCommits() throws Exception {
+        long l = 12345678;
+        response.setContentLength(String.valueOf(l).length());
+
+        response.getWriter().println(l);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnFloatCommits() throws Exception {
+        float f = 1234;
+        response.setContentLength(String.valueOf(f).length());
+
+        response.getWriter().println(f);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnDoubleCommits() throws Exception {
+        double x = 1;
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getWriter().println(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnCharArrayCommits() throws Exception {
+        char[] x = new char[20];
+        response.setContentLength(x.length);
+
+        response.getWriter().println(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnStringCommits() throws Exception {
+        String x = "1";
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getWriter().println(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterPrintlnObjectCommits() throws Exception {
+        Object x = "1";
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getWriter().println(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterAppendCharSequenceCommits() throws Exception {
+        String x = "a";
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getWriter().append(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterAppendCharSequenceIntIntCommits() throws Exception {
+        String x = "abcdef";
+        int start = 1;
+        int end = 3;
+        response.setContentLength(end - start);
+
+        response.getWriter().append(x, start, end);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPrintWriterAppendCharCommits() throws Exception {
+        char x = 1;
+        response.setContentLength(1);
+
+        response.getWriter().append(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamWriteIntCommits() throws Exception {
+        int expected = 1;
+        response.setContentLength(String.valueOf(expected).length());
+
+        response.getOutputStream().write(expected);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamWriteIntMultiDigitCommits() throws Exception {
+        int expected = 10000;
+        response.setContentLength(String.valueOf(expected).length());
+
+        response.getOutputStream().write(expected);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthPlus1OutputStreamWriteIntMultiDigitCommits() throws Exception {
+        int expected = 10000;
+        response.setContentLength(String.valueOf(expected).length() + 1);
+
+        response.getOutputStream().write(expected);
+
+        assertThat(committed).isFalse();
+
+        response.getOutputStream().write(1);
+
+        assertThat(committed).isTrue();
+    }
+
+    // gh-171
+    @Test
+    public void contentLengthPlus1OutputStreamWriteByteArrayMultiDigitCommits() throws Exception {
+        String expected = "{\n" +
+                "  \"parameterName\" : \"_csrf\",\n" +
+                "  \"token\" : \"06300b65-c4aa-4c8f-8cda-39ee17f545a0\",\n" +
+                "  \"headerName\" : \"X-CSRF-TOKEN\"\n" +
+                "}";
+        response.setContentLength(expected.length() + 1);
+
+        response.getOutputStream().write(expected.getBytes());
+
+        assertThat(committed).isFalse();
+
+        response.getOutputStream().write("1".getBytes("UTF-8"));
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintBooleanCommits() throws Exception {
+        boolean b = true;
+        response.setContentLength(1);
+
+        response.getOutputStream().print(b);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintCharCommits() throws Exception {
+        char c = 1;
+        response.setContentLength(1);
+
+        response.getOutputStream().print(c);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintIntCommits() throws Exception {
+        int i = 1234;
+        response.setContentLength(String.valueOf(i).length());
+
+        response.getOutputStream().print(i);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintLongCommits() throws Exception {
+        long l = 12345;
+        response.setContentLength(String.valueOf(l).length());
+
+        response.getOutputStream().print(l);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintFloatCommits() throws Exception {
+        float f = 12345;
+        response.setContentLength(String.valueOf(f).length());
+
+        response.getOutputStream().print(f);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintDoubleCommits() throws Exception {
+        double x = 1.2345;
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getOutputStream().print(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintStringCommits() throws Exception {
+        String x = "12345";
+        response.setContentLength(x.length());
+
+        response.getOutputStream().print(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnCommits() throws Exception {
+        response.setContentLength(NL.length());
+
+        response.getOutputStream().println();
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnBooleanCommits() throws Exception {
+        boolean b = true;
+        response.setContentLength(1);
+
+        response.getOutputStream().println(b);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnCharCommits() throws Exception {
+        char c = 1;
+        response.setContentLength(1);
+
+        response.getOutputStream().println(c);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnIntCommits() throws Exception {
+        int i = 12345;
+        response.setContentLength(String.valueOf(i).length());
+
+        response.getOutputStream().println(i);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnLongCommits() throws Exception {
+        long l = 12345678;
+        response.setContentLength(String.valueOf(l).length());
+
+        response.getOutputStream().println(l);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnFloatCommits() throws Exception {
+        float f = 1234;
+        response.setContentLength(String.valueOf(f).length());
+
+        response.getOutputStream().println(f);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnDoubleCommits() throws Exception {
+        double x = 1;
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getOutputStream().println(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthOutputStreamPrintlnStringCommits() throws Exception {
+        String x = "1";
+        response.setContentLength(String.valueOf(x).length());
+
+        response.getOutputStream().println(x);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void contentLengthDoesNotCommit() throws IOException {
+        String body = "something";
+
+        response.setContentLength(body.length());
+
+        assertThat(committed).isFalse();
+    }
+
+    @Test
+    public void contentLengthOutputStreamWriteStringCommits() throws IOException {
+        String body = "something";
+        response.setContentLength(body.length());
+
+        response.getOutputStream().print(body);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void addHeaderContentLengthPrintWriterWriteStringCommits() throws Exception {
+        int expected = 1234;
+        response.addHeader("Content-Length",String.valueOf(String.valueOf(expected).length()));
+
+        response.getWriter().write(expected);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void bufferSizePrintWriterWriteCommits() throws Exception {
+        String expected = "1234567890";
+        when(response.getBufferSize()).thenReturn(expected.length());
+
+        response.getWriter().write(expected);
+
+        assertThat(committed).isTrue();
+    }
+
+    @Test
+    public void bufferSizeCommitsOnce() throws Exception {
+        String expected = "1234567890";
+        when(response.getBufferSize()).thenReturn(expected.length());
+
+        response.getWriter().write(expected);
+
+        assertThat(committed).isTrue();
+
+        committed = false;
+
+        response.getWriter().write(expected);
+
+        assertThat(committed).isFalse();
+    }
+}