Browse Source

SEC-1998: SecurityContext integration with AsyncContext.start

Rob Winch 12 years ago
parent
commit
593e512558

+ 15 - 0
web/src/main/java/org/springframework/security/web/servletapi/HttpServlet25RequestFactory.java

@@ -0,0 +1,15 @@
+package org.springframework.security.web.servletapi;
+
+import javax.servlet.http.HttpServletRequest;
+
+final class HttpServlet25RequestFactory implements HttpServletRequestFactory {
+    private final String rolePrefix;
+
+    HttpServlet25RequestFactory(String rolePrefix) {
+        this.rolePrefix = rolePrefix;
+    }
+
+    public HttpServletRequest create(HttpServletRequest request) {
+        return new SecurityContextHolderAwareRequestWrapper(request, rolePrefix) ;
+    }
+}

+ 112 - 0
web/src/main/java/org/springframework/security/web/servletapi/HttpServlet3RequestFactory.java

@@ -0,0 +1,112 @@
+/*
+ * Copyright 2002-2012 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.springframework.security.web.servletapi;
+
+import javax.servlet.AsyncContext;
+import javax.servlet.AsyncListener;
+import javax.servlet.ServletContext;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+
+import org.springframework.security.concurrent.DelegatingSecurityContextRunnable;
+
+final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
+    private final String rolePrefix;
+
+    HttpServlet3RequestFactory(String rolePrefix) {
+        this.rolePrefix = rolePrefix;
+    }
+
+    public HttpServletRequest create(HttpServletRequest request) {
+        return new Servlet3SecurityContextHolderAwareRequestWrapper(request, rolePrefix);
+    }
+
+    private static class Servlet3SecurityContextHolderAwareRequestWrapper extends SecurityContextHolderAwareRequestWrapper {
+        public Servlet3SecurityContextHolderAwareRequestWrapper(HttpServletRequest request, String rolePrefix) {
+            super(request, rolePrefix);
+        }
+
+        public AsyncContext startAsync() {
+            AsyncContext startAsync = super.startAsync();
+            return new SecurityContextAsyncContext(startAsync);
+        }
+
+        public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse)
+                throws IllegalStateException {
+            AsyncContext startAsync = super.startAsync(servletRequest, servletResponse);
+            return new SecurityContextAsyncContext(startAsync);
+        }
+    }
+
+    private static class SecurityContextAsyncContext implements AsyncContext {
+        private final AsyncContext asyncContext;
+
+        public SecurityContextAsyncContext(AsyncContext asyncContext) {
+            this.asyncContext = asyncContext;
+        }
+
+        public ServletRequest getRequest() {
+            return asyncContext.getRequest();
+        }
+
+        public ServletResponse getResponse() {
+            return asyncContext.getResponse();
+        }
+
+        public boolean hasOriginalRequestAndResponse() {
+            return asyncContext.hasOriginalRequestAndResponse();
+        }
+
+        public void dispatch() {
+            asyncContext.dispatch();
+        }
+
+        public void dispatch(String path) {
+            asyncContext.dispatch(path);
+        }
+
+        public void dispatch(ServletContext context, String path) {
+            asyncContext.dispatch(context, path);
+        }
+
+        public void complete() {
+            asyncContext.complete();
+        }
+
+        public void start(Runnable run) {
+            asyncContext.start(new DelegatingSecurityContextRunnable(run));
+        }
+
+        public void addListener(AsyncListener listener) {
+            asyncContext.addListener(listener);
+        }
+
+        public void addListener(AsyncListener listener, ServletRequest request, ServletResponse response) {
+            asyncContext.addListener(listener, request, response);
+        }
+
+        public <T extends AsyncListener> T createListener(Class<T> clazz) throws ServletException {
+            return asyncContext.createListener(clazz);
+        }
+
+        public long getTimeout() {
+            return asyncContext.getTimeout();
+        }
+
+        public void setTimeout(long timeout) {
+            asyncContext.setTimeout(timeout);
+        }
+    }
+}

+ 35 - 0
web/src/main/java/org/springframework/security/web/servletapi/HttpServletRequestFactory.java

@@ -0,0 +1,35 @@
+/*
+ * Copyright 2002-2012 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+package org.springframework.security.web.servletapi;
+
+import javax.servlet.http.HttpServletRequest;
+
+/**
+ * Internal interface for creating a {@link HttpServletRequest}. This allows for creating a different implementation for
+ * Servlet 2.5 and Servlet 3.0 environments.
+ *
+ * @author Rob Winch
+ * @since 3.2
+ * @see HttpServlet3RequestFactory
+ * @see HttpServlet25RequestFactory
+ */
+interface HttpServletRequestFactory {
+
+    /**
+     * Given a {@link HttpServletRequest} returns a {@link HttpServletRequest} that in most cases wraps the original
+     * {@link HttpServletRequest}.
+     * @param request the original {@link HttpServletRequest}. Cannot be null.
+     * @return a non-null HttpServletRequest
+     */
+    public HttpServletRequest create(HttpServletRequest request);
+}

+ 13 - 3
web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilter.java

@@ -24,6 +24,7 @@ import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
 
 import org.springframework.util.Assert;
+import org.springframework.util.ClassUtils;
 import org.springframework.web.filter.GenericFilterBean;
 
 
@@ -40,17 +41,26 @@ import org.springframework.web.filter.GenericFilterBean;
 public class SecurityContextHolderAwareRequestFilter extends GenericFilterBean {
     //~ Instance fields ================================================================================================
 
-    private String rolePrefix;
+    private HttpServletRequestFactory requestFactory;
+
+    public SecurityContextHolderAwareRequestFilter() {
+        setRequestFactory(null);
+    }
 
     //~ Methods ========================================================================================================
 
     public void setRolePrefix(String rolePrefix) {
         Assert.notNull(rolePrefix, "Role prefix must not be null");
-        this.rolePrefix = rolePrefix.trim();
+        setRequestFactory(rolePrefix.trim());
+    }
+
+    private void setRequestFactory(String rolePrefix) {
+        boolean isServlet3 = ClassUtils.hasMethod(ServletRequest.class, "startAsync");
+        requestFactory = isServlet3 ? new HttpServlet3RequestFactory(rolePrefix) : new HttpServlet25RequestFactory(rolePrefix);
     }
 
     public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
             throws IOException, ServletException {
-        chain.doFilter(new SecurityContextHolderAwareRequestWrapper((HttpServletRequest) req, rolePrefix), res);
+        chain.doFilter(requestFactory.create((HttpServletRequest)req), res);
     }
 }