瀏覽代碼

SEC-2079: Add Servlet 3 Authentication methods

Add support for HttpServletRequest's login(String,String), logout(),
and authenticate(HttpServletResponse).
Rob Winch 12 年之前
父節點
當前提交
c8d45397fe

+ 7 - 1
config/src/main/java/org/springframework/security/config/http/AuthenticationConfigBuilder.java

@@ -115,6 +115,7 @@ final class AuthenticationConfigBuilder {
     private BeanDefinition jeeFilter;
     private BeanReference jeeProviderRef;
     private RootBeanDefinition preAuthEntryPoint;
+    private BeanMetadataElement mainEntryPoint;
 
     private BeanDefinition logoutFilter;
     @SuppressWarnings("rawtypes")
@@ -499,6 +500,10 @@ final class AuthenticationConfigBuilder {
         return logoutHandlers;
     }
 
+    BeanMetadataElement getEntryPointBean() {
+        return mainEntryPoint;
+    }
+
     void createAnonymousFilter() {
         Element anonymousElt = DomUtils.getChildElementByTagName(httpElt, Elements.ANONYMOUS);
 
@@ -556,7 +561,8 @@ final class AuthenticationConfigBuilder {
         BeanDefinitionBuilder etfBuilder = BeanDefinitionBuilder.rootBeanDefinition(ExceptionTranslationFilter.class);
         etfBuilder.addPropertyValue("accessDeniedHandler", createAccessDeniedHandler(httpElt, pc));
         assert requestCache != null;
-        etfBuilder.addConstructorArgValue(selectEntryPoint());
+        mainEntryPoint = selectEntryPoint();
+        etfBuilder.addConstructorArgValue(mainEntryPoint);
         etfBuilder.addConstructorArgValue(requestCache);
 
         etf = etfBuilder.getBeanDefinition();

+ 17 - 4
config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java

@@ -23,6 +23,7 @@ import java.util.List;
 
 import javax.servlet.ServletRequest;
 
+import org.springframework.beans.BeanMetadataElement;
 import org.springframework.beans.factory.config.BeanDefinition;
 import org.springframework.beans.factory.config.BeanReference;
 import org.springframework.beans.factory.config.RuntimeBeanReference;
@@ -146,7 +147,7 @@ class HttpConfigurationBuilder {
         createSessionManagementFilters();
         createWebAsyncManagerFilter();
         createRequestCacheFilter();
-        createServletApiFilter();
+        createServletApiFilter(authenticationManager);
         createJaasApiFilter();
         createChannelProcessingFilter();
         createFilterSecurityInterceptor(authenticationManager);
@@ -154,8 +155,19 @@ class HttpConfigurationBuilder {
 
     @SuppressWarnings("rawtypes")
     void setLogoutHandlers(ManagedList logoutHandlers) {
-        if(logoutHandlers != null && concurrentSessionFilter != null) {
-            concurrentSessionFilter.getPropertyValues().add("logoutHandlers", logoutHandlers);
+        if(logoutHandlers != null) {
+            if(concurrentSessionFilter != null) {
+                concurrentSessionFilter.getPropertyValues().add("logoutHandlers", logoutHandlers);
+            }
+            if(servApiFilter != null) {
+                servApiFilter.getPropertyValues().add("logoutHandlers", logoutHandlers);
+            }
+        }
+    }
+
+    void setEntryPoint(BeanMetadataElement entryPoint) {
+        if(servApiFilter != null) {
+            servApiFilter.getPropertyValues().add("authenticationEntryPoint", entryPoint);
         }
     }
 
@@ -363,7 +375,7 @@ class HttpConfigurationBuilder {
     }
 
     // Adds the servlet-api integration filter if required
-    private void createServletApiFilter() {
+    private void createServletApiFilter(BeanReference authenticationManager) {
         final String ATT_SERVLET_API_PROVISION = "servlet-api-provision";
         final String DEF_SERVLET_API_PROVISION = "true";
 
@@ -374,6 +386,7 @@ class HttpConfigurationBuilder {
 
         if ("true".equals(provideServletApi)) {
             servApiFilter = new RootBeanDefinition(SecurityContextHolderAwareRequestFilter.class);
+            servApiFilter.getPropertyValues().add("authenticationManager", authenticationManager);
         }
     }
 

+ 1 - 0
config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java

@@ -140,6 +140,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
                 httpBldr.getSessionStrategy(), portMapper, portResolver);
 
         httpBldr.setLogoutHandlers(authBldr.getLogoutHandlers());
+        httpBldr.setEntryPoint(authBldr.getEntryPointBean());
 
         authenticationProviders.addAll(authBldr.getProviders());
 

+ 142 - 0
config/src/test/groovy/org/springframework/security/config/http/SecurityContextHolderAwareRequestConfigTests.groovy

@@ -0,0 +1,142 @@
+/*
+ * Copyright 2002-2012 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.http
+
+import static org.springframework.security.config.ConfigTestUtils.AUTH_PROVIDER_XML
+
+import org.springframework.beans.factory.parsing.BeanDefinitionParsingException
+import org.springframework.security.TestDataSource
+import org.springframework.security.authentication.ProviderManager
+import org.springframework.security.authentication.RememberMeAuthenticationProvider
+import org.springframework.security.config.ldap.ContextSourceSettingPostProcessor;
+import org.springframework.security.core.userdetails.MockUserDetailsService
+import org.springframework.security.util.FieldUtils
+import org.springframework.security.web.access.ExceptionTranslationFilter
+import org.springframework.security.web.authentication.SimpleUrlAuthenticationSuccessHandler
+import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter
+import org.springframework.security.web.authentication.logout.CookieClearingLogoutHandler;
+import org.springframework.security.web.authentication.logout.LogoutFilter
+import org.springframework.security.web.authentication.logout.SecurityContextLogoutHandler
+import org.springframework.security.web.authentication.rememberme.InMemoryTokenRepositoryImpl
+import org.springframework.security.web.authentication.rememberme.JdbcTokenRepositoryImpl
+import org.springframework.security.web.authentication.rememberme.PersistentTokenBasedRememberMeServices
+import org.springframework.security.web.authentication.rememberme.RememberMeAuthenticationFilter
+import org.springframework.security.web.authentication.rememberme.TokenBasedRememberMeServices
+import org.springframework.security.web.authentication.www.BasicAuthenticationEntryPoint;
+import org.springframework.security.web.authentication.www.BasicAuthenticationFilter;
+import org.springframework.security.web.servletapi.SecurityContextHolderAwareRequestFilter
+
+/**
+ *
+ * @author Rob Winch
+ */
+class SecurityContextHolderAwareRequestConfigTests extends AbstractHttpConfigTests {
+
+    def withAutoConfig() {
+        httpAutoConfig () {
+
+        }
+        createAppContext(AUTH_PROVIDER_XML)
+
+        def securityContextAwareFilter = getFilter(SecurityContextHolderAwareRequestFilter)
+
+        expect:
+        securityContextAwareFilter.authenticationEntryPoint.loginFormUrl == getFilter(ExceptionTranslationFilter).authenticationEntryPoint.loginFormUrl
+        securityContextAwareFilter.authenticationManager == getFilter(UsernamePasswordAuthenticationFilter).authenticationManager
+        securityContextAwareFilter.logoutHandlers.size() == 1
+        securityContextAwareFilter.logoutHandlers[0].class == SecurityContextLogoutHandler
+    }
+
+    def explicitEntryPoint() {
+        xml.http() {
+            'http-basic'('entry-point-ref': 'ep')
+        }
+        bean('ep', BasicAuthenticationEntryPoint.class.name, ['realmName':'whocares'],[:])
+        createAppContext(AUTH_PROVIDER_XML)
+
+        def securityContextAwareFilter = getFilter(SecurityContextHolderAwareRequestFilter)
+
+        expect:
+        securityContextAwareFilter.authenticationEntryPoint == getFilter(ExceptionTranslationFilter).authenticationEntryPoint
+        securityContextAwareFilter.authenticationManager == getFilter(BasicAuthenticationFilter).authenticationManager
+        securityContextAwareFilter.logoutHandlers == null
+    }
+
+    def formLogin() {
+        xml.http() {
+            'form-login'()
+        }
+        createAppContext(AUTH_PROVIDER_XML)
+
+        def securityContextAwareFilter = getFilter(SecurityContextHolderAwareRequestFilter)
+
+        expect:
+        securityContextAwareFilter.authenticationEntryPoint.loginFormUrl == getFilter(ExceptionTranslationFilter).authenticationEntryPoint.loginFormUrl
+        securityContextAwareFilter.authenticationManager == getFilter(UsernamePasswordAuthenticationFilter).authenticationManager
+        securityContextAwareFilter.logoutHandlers == null
+    }
+
+    def multiHttp() {
+        xml.http('authentication-manager-ref' : 'authManager', 'pattern' : '/first/**') {
+            'form-login'('login-page' : '/login')
+            'logout'('invalidate-session' : 'true')
+        }
+        xml.http('authentication-manager-ref' : 'authManager2') {
+            'form-login'('login-page' : '/login2')
+            'logout'('invalidate-session' : 'false')
+        }
+
+        String secondAuthManager = AUTH_PROVIDER_XML.replace("alias='authManager'", "id='authManager2'")
+        createAppContext(AUTH_PROVIDER_XML + secondAuthManager)
+
+        def securityContextAwareFilter = getFilters('/first/filters').find { it instanceof SecurityContextHolderAwareRequestFilter }
+        def secondSecurityContextAwareFilter = getFilter(SecurityContextHolderAwareRequestFilter)
+
+        expect:
+        securityContextAwareFilter.authenticationEntryPoint.loginFormUrl == '/login'
+        securityContextAwareFilter.authenticationManager == getFilters('/first/filters').find { it instanceof UsernamePasswordAuthenticationFilter}.authenticationManager
+        securityContextAwareFilter.authenticationManager.parent == appContext.getBean('authManager')
+        securityContextAwareFilter.logoutHandlers.size() == 1
+        securityContextAwareFilter.logoutHandlers[0].class == SecurityContextLogoutHandler
+        securityContextAwareFilter.logoutHandlers[0].invalidateHttpSession == true
+
+        secondSecurityContextAwareFilter.authenticationEntryPoint.loginFormUrl == '/login2'
+        secondSecurityContextAwareFilter.authenticationManager == getFilter(UsernamePasswordAuthenticationFilter).authenticationManager
+        secondSecurityContextAwareFilter.authenticationManager.parent == appContext.getBean('authManager2')
+        securityContextAwareFilter.logoutHandlers.size() == 1
+        secondSecurityContextAwareFilter.logoutHandlers[0].class == SecurityContextLogoutHandler
+        secondSecurityContextAwareFilter.logoutHandlers[0].invalidateHttpSession == false
+    }
+
+    def logoutCustom() {
+        xml.http() {
+            'form-login'('login-page' : '/login')
+            'logout'('invalidate-session' : 'false', 'logout-success-url' : '/login?logout', 'delete-cookies' : 'JSESSIONID')
+        }
+        createAppContext(AUTH_PROVIDER_XML)
+
+        def securityContextAwareFilter = getFilter(SecurityContextHolderAwareRequestFilter)
+
+        expect:
+        securityContextAwareFilter.authenticationEntryPoint.loginFormUrl == getFilter(ExceptionTranslationFilter).authenticationEntryPoint.loginFormUrl
+        securityContextAwareFilter.authenticationManager == getFilter(UsernamePasswordAuthenticationFilter).authenticationManager
+        securityContextAwareFilter.logoutHandlers.size() == 2
+        securityContextAwareFilter.logoutHandlers[0].class == SecurityContextLogoutHandler
+        securityContextAwareFilter.logoutHandlers[0].invalidateHttpSession == false
+        securityContextAwareFilter.logoutHandlers[1].class == CookieClearingLogoutHandler
+        securityContextAwareFilter.logoutHandlers[1].cookiesToClear == ['JSESSIONID']
+    }
+}

+ 20 - 1
web/src/main/java/org/springframework/security/web/servletapi/HttpServlet25RequestFactory.java

@@ -1,7 +1,26 @@
+/*
+ * Copyright 2002-2012 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+ * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
 package org.springframework.security.web.servletapi;
 
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
+/**
+ * Creates a {@link SecurityContextHolderAwareRequestWrapper}
+ *
+ * @author Rob Winch
+ * @see SecurityContextHolderAwareRequestWrapper
+ */
 final class HttpServlet25RequestFactory implements HttpServletRequestFactory {
     private final String rolePrefix;
 
@@ -9,7 +28,7 @@ final class HttpServlet25RequestFactory implements HttpServletRequestFactory {
         this.rolePrefix = rolePrefix;
     }
 
-    public HttpServletRequest create(HttpServletRequest request) {
+    public HttpServletRequest create(HttpServletRequest request, HttpServletResponse response) {
         return new SecurityContextHolderAwareRequestWrapper(request, rolePrefix) ;
     }
 }

+ 150 - 4
web/src/main/java/org/springframework/security/web/servletapi/HttpServlet3RequestFactory.java

@@ -12,6 +12,10 @@
  */
 package org.springframework.security.web.servletapi;
 
+import java.io.IOException;
+import java.security.Principal;
+import java.util.List;
+
 import javax.servlet.AsyncContext;
 import javax.servlet.AsyncListener;
 import javax.servlet.ServletContext;
@@ -19,23 +23,121 @@ import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
+import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.concurrent.DelegatingSecurityContextRunnable;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.authentication.logout.LogoutHandler;
 
+/**
+ * Provides integration with the Servlet 3 APIs in addition to the ones found in {@link HttpServlet25RequestFactory}.
+ * The additional methods that are integrated with can be found below:
+ *
+ * <ul>
+ * <li> {@link HttpServletRequest#authenticate(HttpServletResponse)} - Allows the user to determine if they are
+ * authenticated and if not send the user to the login page. See
+ * {@link #setAuthenticationEntryPoint(AuthenticationEntryPoint)}.</li>
+ * <li> {@link HttpServletRequest#login(String, String)} - Allows the user to authenticate using the
+ * {@link AuthenticationManager}. See {@link #setAuthenticationManager(AuthenticationManager)}.</li>
+ * <li> {@link HttpServletRequest#logout()} - Allows the user to logout using the {@link LogoutHandler}s configured in
+ * Spring Security. See {@link #setLogoutHandlers(List)}.</li>
+ * <li> {@link AsyncContext#start(Runnable)} - Automatically copy the {@link SecurityContext} from the
+ * {@link SecurityContextHolder} found on the Thread that invoked {@link AsyncContext#start(Runnable)} to the Thread
+ * that processes the {@link Runnable}.</li>
+ * </ul>
+ *
+ * @author Rob Winch
+ *
+ * @see SecurityContextHolderAwareRequestFilter
+ * @see HttpServlet25RequestFactory
+ * @see Servlet3SecurityContextHolderAwareRequestWrapper
+ * @see SecurityContextAsyncContext
+ */
 final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
+    private Log logger = LogFactory.getLog(getClass());
+
     private final String rolePrefix;
+    private AuthenticationEntryPoint authenticationEntryPoint;
+    private AuthenticationManager authenticationManager;
+    private List<LogoutHandler> logoutHandlers;
 
     HttpServlet3RequestFactory(String rolePrefix) {
         this.rolePrefix = rolePrefix;
     }
 
-    public HttpServletRequest create(HttpServletRequest request) {
-        return new Servlet3SecurityContextHolderAwareRequestWrapper(request, rolePrefix);
+    /**
+     * <p>
+     * Sets the {@link AuthenticationEntryPoint} used when integrating {@link HttpServletRequest} with Servlet 3 APIs.
+     * Specifically, it will be used when {@link HttpServletRequest#authenticate(HttpServletResponse)} is called and the
+     * user is not authenticated.
+     * </p>
+     * <p>
+     * If the value is null (default), then the default container behavior will be be retained when invoking
+     * {@link HttpServletRequest#authenticate(HttpServletResponse)}.
+     * </p>
+     * @param authenticationEntryPoint the {@link AuthenticationEntryPoint} to use when invoking
+     * {@link HttpServletRequest#authenticate(HttpServletResponse)} if the user is not authenticated.
+     */
+
+    public void setAuthenticationEntryPoint(AuthenticationEntryPoint authenticationEntryPoint) {
+        this.authenticationEntryPoint = authenticationEntryPoint;
+    }
+
+    /**
+     * <p>
+     * Sets the {@link AuthenticationManager} used when integrating {@link HttpServletRequest} with Servlet 3 APIs.
+     * Specifically, it will be used when {@link HttpServletRequest#login(String, String)} is invoked to determine if
+     * the user is authenticated.
+     * </p>
+     * <p>
+     * If the value is null (default), then the default container behavior will be retained when invoking
+     * {@link HttpServletRequest#login(String, String)}.
+     * </p>
+     *
+     * @param authenticationManager the {@link AuthenticationManager} to use when invoking
+     * {@link HttpServletRequest#login(String, String)}
+     */
+    public void setAuthenticationManager(AuthenticationManager authenticationManager) {
+        this.authenticationManager = authenticationManager;
+    }
+
+    /**
+     * <p>
+     * Sets the {@link LogoutHandler}s used when integrating with {@link HttpServletRequest} with Servlet 3 APIs.
+     * Specifically it will be used when {@link HttpServletRequest#logout()} is invoked in order to log the user out. So
+     * long as the {@link LogoutHandler}s do not commit the {@link HttpServletResponse} (expected), then the user is in
+     * charge of handling the response.
+     * </p>
+     * <p>
+     * If the value is null (default), the default container behavior will be retained when invoking
+     * {@link HttpServletRequest#logout()}.
+     * </p>
+     *
+     * @param logoutHandlers the {@link List<LogoutHandler>}s when invoking {@link HttpServletRequest#logout()}.
+     */
+    public void setLogoutHandlers(List<LogoutHandler> logoutHandlers) {
+        this.logoutHandlers = logoutHandlers;
+    }
+
+    public HttpServletRequest create(HttpServletRequest request, HttpServletResponse response) {
+         return new Servlet3SecurityContextHolderAwareRequestWrapper(request, rolePrefix, response);
     }
 
-    private static class Servlet3SecurityContextHolderAwareRequestWrapper extends SecurityContextHolderAwareRequestWrapper {
-        public Servlet3SecurityContextHolderAwareRequestWrapper(HttpServletRequest request, String rolePrefix) {
+    private class Servlet3SecurityContextHolderAwareRequestWrapper extends SecurityContextHolderAwareRequestWrapper {
+        private final HttpServletResponse response;
+
+        public Servlet3SecurityContextHolderAwareRequestWrapper(HttpServletRequest request, String rolePrefix, HttpServletResponse response) {
             super(request, rolePrefix);
+            this.response = response;
         }
 
         public AsyncContext startAsync() {
@@ -48,6 +150,50 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
             AsyncContext startAsync = super.startAsync(servletRequest, servletResponse);
             return new SecurityContextAsyncContext(startAsync);
         }
+
+        public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
+            AuthenticationEntryPoint entryPoint = authenticationEntryPoint;
+            if(entryPoint == null) {
+                logger.debug("authenticationEntryPoint is null, so allowing original HttpServletRequest to handle authenticate");
+                return super.authenticate(response);
+            }
+            Principal userPrincipal = getUserPrincipal();
+            if(userPrincipal != null) {
+                return true;
+            }
+            entryPoint.commence(this, response, new AuthenticationCredentialsNotFoundException("User is not Authenticated"));
+            return false;
+        }
+
+        public void login(String username, String password) throws ServletException {
+            AuthenticationManager authManager = authenticationManager;
+            if(authManager == null) {
+                logger.debug("authenticationManager is null, so allowing original HttpServletRequest to handle login");
+                super.login(username, password);
+                return;
+            }
+            Authentication authentication;
+            try {
+                authentication = authManager.authenticate(new UsernamePasswordAuthenticationToken(username,password));
+            } catch(AuthenticationException loginFailed) {
+                SecurityContextHolder.clearContext();
+                throw new ServletException(loginFailed.getMessage(), loginFailed);
+            }
+            SecurityContextHolder.getContext().setAuthentication(authentication);
+        }
+
+        public void logout() throws ServletException {
+            List<LogoutHandler> handlers = logoutHandlers;
+            if(handlers == null) {
+                logger.debug("logoutHandlers is null, so allowing original HttpServletRequest to handle logout");
+                super.logout();
+                return;
+            }
+            Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+            for(LogoutHandler logoutHandler : handlers) {
+                logoutHandler.logout(this, response, authentication);
+            }
+        }
     }
 
     private static class SecurityContextAsyncContext implements AsyncContext {

+ 3 - 1
web/src/main/java/org/springframework/security/web/servletapi/HttpServletRequestFactory.java

@@ -13,6 +13,7 @@
 package org.springframework.security.web.servletapi;
 
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
 /**
  * Internal interface for creating a {@link HttpServletRequest}. This allows for creating a different implementation for
@@ -29,7 +30,8 @@ interface HttpServletRequestFactory {
      * Given a {@link HttpServletRequest} returns a {@link HttpServletRequest} that in most cases wraps the original
      * {@link HttpServletRequest}.
      * @param request the original {@link HttpServletRequest}. Cannot be null.
+     * @param response the original {@link HttpServletResponse}. Cannot be null.
      * @return a non-null HttpServletRequest
      */
-    public HttpServletRequest create(HttpServletRequest request);
+    public HttpServletRequest create(HttpServletRequest request, HttpServletResponse response);
 }

+ 118 - 10
web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilter.java

@@ -1,4 +1,5 @@
-/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+/* Copyright 2002-2012 the original author or authors.
+ * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,13 +17,21 @@
 package org.springframework.security.web.servletapi;
 
 import java.io.IOException;
+import java.util.List;
 
+import javax.servlet.AsyncContext;
 import javax.servlet.FilterChain;
 import javax.servlet.ServletException;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
 
+import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.authentication.logout.LogoutHandler;
 import org.springframework.util.Assert;
 import org.springframework.util.ClassUtils;
 import org.springframework.web.filter.GenericFilterBean;
@@ -32,35 +41,134 @@ import org.springframework.web.filter.GenericFilterBean;
  * A <code>Filter</code> which populates the <code>ServletRequest</code> with a request wrapper
  * which implements the servlet API security methods.
  * <p>
- * The wrapper class used is {@link SecurityContextHolderAwareRequestWrapper}.
+ * In pre servlet 3 environment the wrapper class used is {@link SecurityContextHolderAwareRequestWrapper}. See its javadoc for the methods that are implemented.
+ * </p>
+ * <p>
+ * In a servlet 3 environment {@link SecurityContextHolderAwareRequestWrapper} is extended to provide the following additional methods:
+ * </p>
+ * <ul>
+ * <li> {@link HttpServletRequest#authenticate(HttpServletResponse)} - Allows the user to determine if they are
+ * authenticated and if not send the user to the login page. See
+ * {@link #setAuthenticationEntryPoint(AuthenticationEntryPoint)}.</li>
+ * <li> {@link HttpServletRequest#login(String, String)} - Allows the user to authenticate using the
+ * {@link AuthenticationManager}. See {@link #setAuthenticationManager(AuthenticationManager)}.</li>
+ * <li> {@link HttpServletRequest#logout()} - Allows the user to logout using the {@link LogoutHandler}s configured in
+ * Spring Security. See {@link #setLogoutHandlers(List)}.</li>
+ * <li> {@link AsyncContext#start(Runnable)} - Automatically copy the {@link SecurityContext} from the
+ * {@link SecurityContextHolder} found on the Thread that invoked {@link AsyncContext#start(Runnable)} to the Thread
+ * that processes the {@link Runnable}.</li>
+ * </ul>
+ *
  *
  * @author Orlando Garcia Carmona
  * @author Ben Alex
  * @author Luke Taylor
+ * @author Rob Winch
  */
 public class SecurityContextHolderAwareRequestFilter extends GenericFilterBean {
     //~ Instance fields ================================================================================================
 
+    private String rolePrefix;
+
     private HttpServletRequestFactory requestFactory;
 
-    public SecurityContextHolderAwareRequestFilter() {
-        setRequestFactory(null);
-    }
+    private AuthenticationEntryPoint authenticationEntryPoint;
+
+    private AuthenticationManager authenticationManager;
+
+    private List<LogoutHandler> logoutHandlers;
 
     //~ Methods ========================================================================================================
 
     public void setRolePrefix(String rolePrefix) {
         Assert.notNull(rolePrefix, "Role prefix must not be null");
-        setRequestFactory(rolePrefix.trim());
+        this.rolePrefix = rolePrefix;
+    }
+
+    /**
+     * <p>
+     * Sets the {@link AuthenticationEntryPoint} used when integrating {@link HttpServletRequest} with Servlet 3 APIs.
+     * Specifically, it will be used when {@link HttpServletRequest#authenticate(HttpServletResponse)} is called and the
+     * user is not authenticated.
+     * </p>
+     * <p>
+     * If the value is null (default), then the default container behavior will be be retained when invoking
+     * {@link HttpServletRequest#authenticate(HttpServletResponse)}.
+     * </p>
+     *
+     * @param authenticationEntryPoint the {@link AuthenticationEntryPoint} to use when invoking
+     * {@link HttpServletRequest#authenticate(HttpServletResponse)} if the user is not authenticated.
+     *
+     * @throws IllegalStateException if the Servlet 3 APIs are not found on the classpath
+     */
+    public void setAuthenticationEntryPoint(AuthenticationEntryPoint authenticationEntryPoint) {
+        this.authenticationEntryPoint = authenticationEntryPoint;
     }
 
-    private void setRequestFactory(String rolePrefix) {
-        boolean isServlet3 = ClassUtils.hasMethod(ServletRequest.class, "startAsync");
-        requestFactory = isServlet3 ? new HttpServlet3RequestFactory(rolePrefix) : new HttpServlet25RequestFactory(rolePrefix);
+    /**
+     * <p>
+     * Sets the {@link AuthenticationManager} used when integrating {@link HttpServletRequest} with Servlet 3 APIs.
+     * Specifically, it will be used when {@link HttpServletRequest#login(String, String)} is invoked to determine if
+     * the user is authenticated.
+     * </p>
+     * <p>
+     * If the value is null (default), then the default container behavior will be retained when invoking
+     * {@link HttpServletRequest#login(String, String)}.
+     * </p>
+     *
+     * @param authenticationManager the {@link AuthenticationManager} to use when invoking
+     * {@link HttpServletRequest#login(String, String)}
+     *
+     * @throws IllegalStateException if the Servlet 3 APIs are not found on the classpath
+     */
+    public void setAuthenticationManager(AuthenticationManager authenticationManager) {
+        this.authenticationManager = authenticationManager;
+    }
+
+    /**
+     * <p>
+     * Sets the {@link LogoutHandler}s used when integrating with {@link HttpServletRequest} with Servlet 3 APIs.
+     * Specifically it will be used when {@link HttpServletRequest#logout()} is invoked in order to log the user out. So
+     * long as the {@link LogoutHandler}s do not commit the {@link HttpServletResponse} (expected), then the user is in
+     * charge of handling the response.
+     * </p>
+     * <p>
+     * If the value is null (default), the default container behavior will be retained when invoking
+     * {@link HttpServletRequest#logout()}.
+     * </p>
+     *
+     * @param logoutHandlers the {@link List<LogoutHandler>}s when invoking {@link HttpServletRequest#logout()}.
+     *
+     * @throws IllegalStateException if the Servlet 3 APIs are not found on the classpath
+     */
+    public void setLogoutHandlers(List<LogoutHandler> logoutHandlers) {
+        this.logoutHandlers = logoutHandlers;
     }
 
     public void doFilter(ServletRequest req, ServletResponse res, FilterChain chain)
             throws IOException, ServletException {
-        chain.doFilter(requestFactory.create((HttpServletRequest)req), res);
+        chain.doFilter(requestFactory.create((HttpServletRequest)req, (HttpServletResponse) res), res);
+    }
+
+    @Override
+    public void afterPropertiesSet() throws ServletException {
+        super.afterPropertiesSet();
+        requestFactory = isServlet3() ? createServlet3Factory(rolePrefix) : new HttpServlet25RequestFactory(rolePrefix);
+    }
+
+    private HttpServlet3RequestFactory createServlet3Factory(String rolePrefix) {
+        HttpServlet3RequestFactory factory = new HttpServlet3RequestFactory(rolePrefix);
+        factory.setAuthenticationEntryPoint(authenticationEntryPoint);
+        factory.setAuthenticationManager(authenticationManager);
+        factory.setLogoutHandlers(logoutHandlers);
+        return factory;
+    }
+
+    /**
+     * Returns true if the Servlet 3 APIs are detected.
+     * @return
+     */
+    private boolean isServlet3() {
+        return ClassUtils.hasMethod(ServletRequest.class, "startAsync");
     }
 }

+ 7 - 2
web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestWrapper.java

@@ -33,8 +33,13 @@ import org.springframework.security.core.userdetails.UserDetails;
 /**
  * A Spring Security-aware <code>HttpServletRequestWrapper</code>, which uses the
  * <code>SecurityContext</code>-defined <code>Authentication</code> object to implement the servlet API security
- * methods {@link SecurityContextHolderAwareRequestWrapper#isUserInRole(String)} and {@link
- * HttpServletRequestWrapper#getRemoteUser()}.
+ * methods:
+ *
+ * <ul>
+ * <li>{@link #getUserPrincipal()}</li>
+ * <li>{@link SecurityContextHolderAwareRequestWrapper#isUserInRole(String)}</li>
+ * <li>{@link HttpServletRequestWrapper#getRemoteUser()}.</li>
+ * </ul>
  *
  * @see SecurityContextHolderAwareRequestFilter
  *

+ 208 - 4
web/src/test/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilterTests.java

@@ -1,4 +1,5 @@
-/* Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+/* Copyright 2002-2012 the original author or authors.
+ * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -15,31 +16,101 @@
 
 package org.springframework.security.web.servletapi;
 
+import static org.fest.assertions.Assertions.assertThat;
 import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.*;
+import static org.mockito.Matchers.anyString;
+import static org.mockito.Matchers.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.powermock.api.mockito.PowerMockito.doThrow;
+import static org.powermock.api.mockito.PowerMockito.mock;
+import static org.powermock.api.mockito.PowerMockito.verifyZeroInteractions;
+import static org.powermock.api.mockito.PowerMockito.when;
 
+import java.util.Arrays;
+import java.util.List;
+
+import javax.servlet.AsyncContext;
 import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import junit.framework.Assert;
+
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.Captor;
+import org.mockito.Mock;
+import org.powermock.core.classloader.annotations.PrepareForTest;
+import org.powermock.modules.junit4.PowerMockRunner;
+import org.powermock.reflect.internal.WhiteboxImpl;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.BadCredentialsException;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
+import org.springframework.security.concurrent.DelegatingSecurityContextRunnable;
+import org.springframework.security.core.AuthenticationException;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.AuthenticationEntryPoint;
+import org.springframework.security.web.authentication.logout.LogoutHandler;
+import org.springframework.util.ClassUtils;
 
 
 /**
  * Tests {@link SecurityContextHolderAwareRequestFilter}.
  *
  * @author Ben Alex
+ * @author Rob Winch
  */
+@RunWith(PowerMockRunner.class)
+@PrepareForTest(ClassUtils.class)
 public class SecurityContextHolderAwareRequestFilterTests {
+    @Captor
+    private ArgumentCaptor<HttpServletRequest> requestCaptor;
+    @Mock
+    private AuthenticationManager authenticationManager;
+    @Mock
+    private AuthenticationEntryPoint authenticationEntryPoint;
+    @Mock
+    private LogoutHandler logoutHandler;
+    @Mock
+    private FilterChain filterChain;
+    @Mock
+    private HttpServletRequest request;
+    @Mock
+    private HttpServletResponse response;
+
+    private List<LogoutHandler> logoutHandlers;
+
+    private SecurityContextHolderAwareRequestFilter filter;
+
+    @Before
+    public void setUp() throws Exception {
+        logoutHandlers = Arrays.asList(logoutHandler);
+        filter = new SecurityContextHolderAwareRequestFilter();
+        filter.setAuthenticationEntryPoint(authenticationEntryPoint);
+        filter.setAuthenticationManager(authenticationManager);
+        filter.setLogoutHandlers(logoutHandlers);
+        filter.afterPropertiesSet();
+    }
+
+    @After
+    public void clearContext() {
+        SecurityContextHolder.clearContext();
+    }
 
     //~ Methods ========================================================================================================
 
     @Test
     public void expectedRequestWrapperClassIsUsed() throws Exception {
-        SecurityContextHolderAwareRequestFilter filter = new SecurityContextHolderAwareRequestFilter();
         filter.setRolePrefix("ROLE_");
-        final FilterChain filterChain = mock(FilterChain.class);
 
         filter.doFilter(new MockHttpServletRequest(), new MockHttpServletResponse(), filterChain);
 
@@ -50,4 +121,137 @@ public class SecurityContextHolderAwareRequestFilterTests {
 
         filter.destroy();
     }
+
+    @Test
+    public void authenticateFalse() throws Exception {
+        assertThat(wrappedRequest().authenticate(response)).isFalse();
+        verify(authenticationEntryPoint).commence(eq(requestCaptor.getValue()), eq(response), any(AuthenticationException.class));
+        verifyZeroInteractions(authenticationManager, logoutHandler);
+        verify(request, times(0)).authenticate(any(HttpServletResponse.class));
+    }
+
+    @Test
+    public void authenticateTrue() throws Exception {
+        SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("test","password","ROLE_USER"));
+
+        assertThat(wrappedRequest().authenticate(response)).isTrue();
+        verifyZeroInteractions(authenticationEntryPoint, authenticationManager, logoutHandler);
+        verify(request, times(0)).authenticate(any(HttpServletResponse.class));
+    }
+
+    @Test
+    public void authenticateNullEntryPointFalse() throws Exception {
+        filter.setAuthenticationEntryPoint(null);
+        filter.afterPropertiesSet();
+
+        assertThat(wrappedRequest().authenticate(response)).isFalse();
+        verify(request).authenticate(response);
+        verifyZeroInteractions(authenticationEntryPoint, authenticationManager, logoutHandler);
+    }
+
+    @Test
+    public void authenticateNullEntryPointTrue() throws Exception {
+        when(request.authenticate(response)).thenReturn(true);
+        filter.setAuthenticationEntryPoint(null);
+        filter.afterPropertiesSet();
+
+        assertThat(wrappedRequest().authenticate(response)).isTrue();
+        verify(request).authenticate(response);
+        verifyZeroInteractions(authenticationEntryPoint, authenticationManager, logoutHandler);
+    }
+
+    @Test
+    public void login() throws Exception {
+        TestingAuthenticationToken expectedAuth = new TestingAuthenticationToken("user", "password","ROLE_USER");
+        when(authenticationManager.authenticate(any(UsernamePasswordAuthenticationToken.class))).thenReturn(expectedAuth);
+
+        wrappedRequest().login(expectedAuth.getName(),String.valueOf(expectedAuth.getCredentials()));
+
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(expectedAuth);
+        verifyZeroInteractions(authenticationEntryPoint, logoutHandler);
+        verify(request, times(0)).login(anyString(),anyString());
+    }
+
+    @Test
+    public void loginFail() throws Exception {
+        AuthenticationException authException = new BadCredentialsException("Invalid");
+        when(authenticationManager.authenticate(any(UsernamePasswordAuthenticationToken.class))).thenThrow(authException);
+        SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("should","be cleared","ROLE_USER"));
+
+        try {
+            wrappedRequest().login("invalid","credentials");
+            Assert.fail("Expected Exception");
+        } catch(ServletException success) {
+            assertThat(success.getCause()).isEqualTo(authException);
+        }
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
+
+        verifyZeroInteractions(authenticationEntryPoint, logoutHandler);
+        verify(request, times(0)).login(anyString(),anyString());
+    }
+
+    @Test
+    public void loginNullAuthenticationManager() throws Exception {
+        filter.setAuthenticationManager(null);
+        filter.afterPropertiesSet();
+
+        String username = "username";
+        String password = "password";
+
+        wrappedRequest().login(username, password);
+
+        verify(request).login(username, password);
+        verifyZeroInteractions(authenticationEntryPoint, authenticationManager, logoutHandler);
+    }
+
+    @Test
+    public void loginNullAuthenticationManagerFail() throws Exception {
+        filter.setAuthenticationManager(null);
+        filter.afterPropertiesSet();
+
+        String username = "username";
+        String password = "password";
+        ServletException authException = new ServletException("Failed Login");
+        doThrow(authException).when(request).login(username, password);
+
+        try {
+            wrappedRequest().login(username, password);
+            Assert.fail("Expected Exception");
+        } catch(ServletException success) {
+            assertThat(success).isEqualTo(authException);
+        }
+
+        verifyZeroInteractions(authenticationEntryPoint, authenticationManager, logoutHandler);
+    }
+
+    @Test
+    public void logout() throws Exception {
+        TestingAuthenticationToken expectedAuth = new TestingAuthenticationToken("user", "password","ROLE_USER");
+        SecurityContextHolder.getContext().setAuthentication(expectedAuth);
+
+        HttpServletRequest wrappedRequest = wrappedRequest();
+        wrappedRequest.logout();
+
+        verify(logoutHandler).logout(wrappedRequest, response, expectedAuth);
+        verifyZeroInteractions(authenticationManager, logoutHandler);
+        verify(request, times(0)).logout();
+    }
+
+    @Test
+    public void logoutNullLogoutHandler() throws Exception {
+        filter.setLogoutHandlers(null);
+        filter.afterPropertiesSet();
+
+        wrappedRequest().logout();
+
+        verify(request).logout();
+        verifyZeroInteractions(authenticationEntryPoint, authenticationManager, logoutHandler);
+    }
+
+    private HttpServletRequest wrappedRequest() throws Exception {
+        filter.doFilter(request, response, filterChain);
+        verify(filterChain).doFilter(requestCaptor.capture(), any(HttpServletResponse.class));
+
+        return requestCaptor.getValue();
+    }
 }