浏览代码

SEC-924: Implement automatic injection of namespace created RememberMeServices into custom AbstractProcessingFilter based beans.
http://jira.springframework.org/browse/SEC-924. Delayed setting of NullRememberMeServices in AbstractProcessingFilter until afterPropertiesSet method is called, allowing the null value to be read by the namespace and the confgiured RememberMeServices bean injected.

Luke Taylor 17 年之前
父节点
当前提交
e303e8b71a

+ 9 - 3
core/src/main/java/org/springframework/security/ui/AbstractProcessingFilter.java

@@ -147,7 +147,11 @@ public abstract class AbstractProcessingFilter extends SpringSecurityFilter impl
 
     private Properties exceptionMappings = new Properties();
 
-    private RememberMeServices rememberMeServices = new NullRememberMeServices();
+    /** 
+     * Delay use of NullRememberMeServices until initialization so that namespace has a chance to inject
+     * the RememberMeServices implementation into custom implementations.
+     */ 
+    private RememberMeServices rememberMeServices = null;
 
     private TargetUrlResolver targetUrlResolver = new TargetUrlResolverImpl();
     
@@ -218,11 +222,13 @@ public abstract class AbstractProcessingFilter extends SpringSecurityFilter impl
         Assert.isTrue(UrlUtils.isValidRedirectUrl(filterProcessesUrl), filterProcessesUrl + " isn't a valid redirect URL");        
         Assert.hasLength(defaultTargetUrl, "defaultTargetUrl must be specified");
         Assert.isTrue(UrlUtils.isValidRedirectUrl(defaultTargetUrl), defaultTargetUrl + " isn't a valid redirect URL");        
-//        Assert.hasLength(authenticationFailureUrl, "authenticationFailureUrl must be specified");
         Assert.isTrue(UrlUtils.isValidRedirectUrl(authenticationFailureUrl), authenticationFailureUrl + " isn't a valid redirect URL");
         Assert.notNull(authenticationManager, "authenticationManager must be specified");
-        Assert.notNull(rememberMeServices, "rememberMeServices cannot be null");
         Assert.notNull(targetUrlResolver, "targetUrlResolver cannot be null");
+        
+        if (rememberMeServices == null) {
+        	rememberMeServices = new NullRememberMeServices();
+        }
     }
 
     /**

+ 14 - 44
core/src/test/java/org/springframework/security/ui/AbstractProcessingFilterTests.java

@@ -25,6 +25,7 @@ import org.springframework.security.GrantedAuthorityImpl;
 import org.springframework.security.MockAuthenticationManager;
 import org.springframework.security.context.SecurityContextHolder;
 import org.springframework.security.providers.UsernamePasswordAuthenticationToken;
+import org.springframework.security.ui.rememberme.NullRememberMeServices;
 import org.springframework.security.ui.rememberme.TokenBasedRememberMeServices;
 import org.springframework.security.ui.savedrequest.SavedRequest;
 import org.springframework.security.util.PortResolverImpl;
@@ -76,8 +77,7 @@ public class AbstractProcessingFilterTests extends TestCase {
     }
 
     private void executeFilterInContainerSimulator(FilterConfig filterConfig, Filter filter, ServletRequest request,
-        ServletResponse response, FilterChain filterChain)
-        throws ServletException, IOException {
+        ServletResponse response, FilterChain filterChain) throws ServletException, IOException {
         filter.init(filterConfig);
         filter.doFilter(request, response, filterChain);
         filter.destroy();
@@ -115,7 +115,7 @@ public class AbstractProcessingFilterTests extends TestCase {
         SecurityContextHolder.clearContext();
     }
 
-    public void testDefaultProcessesFilterUrlWithPathParameter() {
+    public void testDefaultProcessesFilterUrlMatchesWithPathParameter() {
         MockHttpServletRequest request = createMockRequest();
         MockHttpServletResponse response = new MockHttpServletResponse();
         MockAbstractProcessingFilter filter = new MockAbstractProcessingFilter();
@@ -125,28 +125,6 @@ public class AbstractProcessingFilterTests extends TestCase {
         assertTrue(filter.requiresAuthentication(request, response));
     }
 
-    public void testDoFilterWithNonHttpServletRequestDetected() throws Exception {
-        AbstractProcessingFilter filter = new MockAbstractProcessingFilter();
-
-        try {
-            filter.doFilter(null, new MockHttpServletResponse(), new MockFilterChain());
-            fail("Should have thrown ServletException");
-        } catch (ServletException expected) {
-            assertEquals("Can only process HttpServletRequest", expected.getMessage());
-        }
-    }
-
-    public void testDoFilterWithNonHttpServletResponseDetected() throws Exception {
-        AbstractProcessingFilter filter = new MockAbstractProcessingFilter();
-
-        try {
-            filter.doFilter(new MockHttpServletRequest(null, null), null, new MockFilterChain());
-            fail("Should have thrown ServletException");
-        } catch (ServletException expected) {
-            assertEquals("Can only process HttpServletResponse", expected.getMessage());
-        }
-    }
-
     public void testFailedAuthenticationRedirectsAppropriately() throws Exception {
         // Setup our HTTP request
         MockHttpServletRequest request = createMockRequest();
@@ -209,25 +187,20 @@ public class AbstractProcessingFilterTests extends TestCase {
         assertEquals("test", SecurityContextHolder.getContext().getAuthentication().getPrincipal().toString());
     }
 
-    public void testGettersSetters() {
+    public void testGettersSetters() throws Exception {
         AbstractProcessingFilter filter = new MockAbstractProcessingFilter();
+        filter.setAuthenticationManager(new MockAuthenticationManager());
+        filter.setDefaultTargetUrl("/default");
+        filter.setFilterProcessesUrl("/p");
+        filter.setAuthenticationFailureUrl("/fail");
+        filter.afterPropertiesSet();
+
         assertNotNull(filter.getRememberMeServices());
         filter.setRememberMeServices(new TokenBasedRememberMeServices());
         assertEquals(TokenBasedRememberMeServices.class, filter.getRememberMeServices().getClass());
-
-        filter.setAuthenticationFailureUrl("/x");
-        assertEquals("/x", filter.getAuthenticationFailureUrl());
-
-        filter.setAuthenticationManager(new MockAuthenticationManager());
         assertTrue(filter.getAuthenticationManager() != null);
-
-        filter.setDefaultTargetUrl("/default");
         assertEquals("/default", filter.getDefaultTargetUrl());
-
-        filter.setFilterProcessesUrl("/p");
         assertEquals("/p", filter.getFilterProcessesUrl());
-
-        filter.setAuthenticationFailureUrl("/fail");
         assertEquals("/fail", filter.getAuthenticationFailureUrl());
     }
 
@@ -602,11 +575,13 @@ public class AbstractProcessingFilterTests extends TestCase {
         private boolean grantAccess;
 
         public MockAbstractProcessingFilter(boolean grantAccess) {
+        	setRememberMeServices(new NullRememberMeServices());
             this.grantAccess = grantAccess;
             this.exceptionToThrow = new BadCredentialsException("Mock requested to do so");
         }
 
         public MockAbstractProcessingFilter(AuthenticationException exceptionToThrow) {
+        	setRememberMeServices(new NullRememberMeServices());        	
             this.grantAccess = false;
             this.exceptionToThrow = exceptionToThrow;
         }
@@ -614,8 +589,7 @@ public class AbstractProcessingFilterTests extends TestCase {
         private MockAbstractProcessingFilter() {
         }
 
-        public Authentication attemptAuthentication(HttpServletRequest request)
-            throws AuthenticationException {
+        public Authentication attemptAuthentication(HttpServletRequest request) throws AuthenticationException {
             if (grantAccess) {
                 return new UsernamePasswordAuthenticationToken("test", "test",
                     new GrantedAuthority[] {new GrantedAuthorityImpl("TEST")});
@@ -644,11 +618,7 @@ public class AbstractProcessingFilterTests extends TestCase {
             this.expectToProceed = expectToProceed;
         }
 
-        private MockFilterChain() {
-        }
-
-        public void doFilter(ServletRequest request, ServletResponse response)
-            throws IOException, ServletException {
+        public void doFilter(ServletRequest request, ServletResponse response) throws IOException, ServletException {
             if (expectToProceed) {
                 assertTrue(true);
             } else {