Browse Source

SEC-2329: Allow injecting of AuthenticationTrustResolver

Rob Winch 12 years ago
parent
commit
788ba9a1fa

+ 15 - 1
core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java

@@ -20,6 +20,7 @@ import org.springframework.security.access.expression.ExpressionUtils;
 import org.springframework.security.authentication.AuthenticationTrustResolver;
 import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
 import org.springframework.security.core.Authentication;
+import org.springframework.util.Assert;
 
 /**
  * The standard implementation of {@code MethodSecurityExpressionHandler}.
@@ -33,7 +34,7 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
 
     protected final Log logger = LogFactory.getLog(getClass());
 
-    private final AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
+    private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
     private ParameterNameDiscoverer parameterNameDiscoverer = new LocalVariableTableParameterNameDiscoverer();
     private PermissionCacheOptimizer permissionCacheOptimizer = null;
 
@@ -143,6 +144,19 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
         throw new IllegalArgumentException("Filter target must be a collection or array type, but was " + filterTarget);
     }
 
+    /**
+     * Sets the {@link AuthenticationTrustResolver} to be used. The default is
+     * {@link AuthenticationTrustResolverImpl}.
+     *
+     * @param trustResolver
+     *            the {@link AuthenticationTrustResolver} to use. Cannot be
+     *            null.
+     */
+    public void setTrustResolver(AuthenticationTrustResolver trustResolver) {
+        Assert.notNull(trustResolver, "trustResolver cannot be null");
+        this.trustResolver = trustResolver;
+    }
+
     public void setParameterNameDiscoverer(ParameterNameDiscoverer parameterNameDiscoverer) {
         this.parameterNameDiscoverer = parameterNameDiscoverer;
     }

+ 54 - 0
core/src/test/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerTests.java

@@ -0,0 +1,54 @@
+package org.springframework.security.access.expression.method;
+
+import static org.mockito.Mockito.verify;
+
+import org.aopalliance.intercept.MethodInvocation;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
+import org.springframework.expression.EvaluationContext;
+import org.springframework.expression.Expression;
+import org.springframework.security.authentication.AuthenticationTrustResolver;
+import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
+
+@RunWith(MockitoJUnitRunner.class)
+public class DefaultMethodSecurityExpressionHandlerTests {
+    private DefaultMethodSecurityExpressionHandler handler;
+
+    @Mock
+    private Authentication authentication;
+    @Mock
+    private MethodInvocation methodInvocation;
+    @Mock
+    private AuthenticationTrustResolver trustResolver;
+
+    @Before
+    public void setup() {
+        handler = new DefaultMethodSecurityExpressionHandler();
+    }
+
+    @After
+    public void cleanup() {
+        SecurityContextHolder.clearContext();
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void setTrustResolverNull() {
+        handler.setTrustResolver(null);
+    }
+
+    @Test
+    public void createEvaluationContextCustomTrustResolver() {
+        handler.setTrustResolver(trustResolver);
+
+        Expression expression = handler.getExpressionParser().parseExpression("anonymous");
+        EvaluationContext context = handler.createEvaluationContext(authentication, methodInvocation);
+        expression.getValue(context, Boolean.class);
+
+        verify(trustResolver).isAnonymous(authentication);
+    }
+}

+ 15 - 1
web/src/main/java/org/springframework/security/web/access/expression/DefaultWebSecurityExpressionHandler.java

@@ -6,6 +6,7 @@ import org.springframework.security.authentication.AuthenticationTrustResolver;
 import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.web.FilterInvocation;
+import org.springframework.util.Assert;
 
 /**
  *
@@ -15,7 +16,7 @@ import org.springframework.security.web.FilterInvocation;
 @SuppressWarnings("deprecation")
 public class DefaultWebSecurityExpressionHandler extends AbstractSecurityExpressionHandler<FilterInvocation> implements WebSecurityExpressionHandler {
 
-    private final AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
+    private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
 
     @Override
     protected SecurityExpressionOperations createSecurityExpressionRoot(Authentication authentication, FilterInvocation fi) {
@@ -25,4 +26,17 @@ public class DefaultWebSecurityExpressionHandler extends AbstractSecurityExpress
         root.setRoleHierarchy(getRoleHierarchy());
         return root;
     }
+
+    /**
+     * Sets the {@link AuthenticationTrustResolver} to be used. The default is
+     * {@link AuthenticationTrustResolverImpl}.
+     *
+     * @param trustResolver
+     *            the {@link AuthenticationTrustResolver} to use. Cannot be
+     *            null.
+     */
+    public void setTrustResolver(AuthenticationTrustResolver trustResolver) {
+        Assert.notNull(trustResolver, "trustResolver cannot be null");
+        this.trustResolver = trustResolver;
+    }
 }

+ 15 - 2
web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java

@@ -69,7 +69,7 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
     private boolean isServlet3 = ClassUtils.hasMethod(ServletRequest.class, "startAsync");
     private String springSecurityContextKey = SPRING_SECURITY_CONTEXT_KEY;
 
-    private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl();
+    private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
 
     /**
      * Gets the security context for the current request (if available) and returns it.
@@ -295,7 +295,7 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
             HttpSession httpSession = request.getSession(false);
 
             // See SEC-776
-            if (authentication == null || authenticationTrustResolver.isAnonymous(authentication)) {
+            if (authentication == null || trustResolver.isAnonymous(authentication)) {
                 if (logger.isDebugEnabled()) {
                     logger.debug("SecurityContext is empty or contents are anonymous - context will not be stored in HttpSession.");
                 }
@@ -378,4 +378,17 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo
             return null;
         }
     }
+
+    /**
+     * Sets the {@link AuthenticationTrustResolver} to be used. The default is
+     * {@link AuthenticationTrustResolverImpl}.
+     *
+     * @param trustResolver
+     *            the {@link AuthenticationTrustResolver} to use. Cannot be
+     *            null.
+     */
+    public void setTrustResolver(AuthenticationTrustResolver trustResolver) {
+        Assert.notNull(trustResolver, "trustResolver cannot be null");
+        this.trustResolver = trustResolver;
+    }
 }

+ 7 - 2
web/src/main/java/org/springframework/security/web/servletapi/HttpServlet25RequestFactory.java

@@ -15,6 +15,9 @@ package org.springframework.security.web.servletapi;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
+import org.springframework.security.authentication.AuthenticationTrustResolver;
+import org.springframework.util.Assert;
+
 /**
  * Creates a {@link SecurityContextHolderAwareRequestWrapper}
  *
@@ -23,12 +26,14 @@ import javax.servlet.http.HttpServletResponse;
  */
 final class HttpServlet25RequestFactory implements HttpServletRequestFactory {
     private final String rolePrefix;
+    private final AuthenticationTrustResolver trustResolver;
 
-    HttpServlet25RequestFactory(String rolePrefix) {
+    HttpServlet25RequestFactory(AuthenticationTrustResolver trustResolver, String rolePrefix) {
+        this.trustResolver = trustResolver;
         this.rolePrefix = rolePrefix;
     }
 
     public HttpServletRequest create(HttpServletRequest request, HttpServletResponse response) {
-        return new SecurityContextHolderAwareRequestWrapper(request, rolePrefix) ;
+        return new SecurityContextHolderAwareRequestWrapper(request, trustResolver, rolePrefix) ;
     }
 }

+ 19 - 1
web/src/main/java/org/springframework/security/web/servletapi/HttpServlet3RequestFactory.java

@@ -29,6 +29,8 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.springframework.security.authentication.AuthenticationCredentialsNotFoundException;
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.AuthenticationTrustResolver;
+import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.concurrent.DelegatingSecurityContextRunnable;
 import org.springframework.security.core.Authentication;
@@ -37,6 +39,7 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.AuthenticationEntryPoint;
 import org.springframework.security.web.authentication.logout.LogoutHandler;
+import org.springframework.util.Assert;
 
 /**
  * Provides integration with the Servlet 3 APIs in addition to the ones found in {@link HttpServlet25RequestFactory}.
@@ -66,6 +69,7 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
     private Log logger = LogFactory.getLog(getClass());
 
     private final String rolePrefix;
+    private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
     private AuthenticationEntryPoint authenticationEntryPoint;
     private AuthenticationManager authenticationManager;
     private List<LogoutHandler> logoutHandlers;
@@ -128,6 +132,20 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
         this.logoutHandlers = logoutHandlers;
     }
 
+    /**
+     * Sets the {@link AuthenticationTrustResolver} to be used. The default is
+     * {@link AuthenticationTrustResolverImpl}.
+     *
+     * @param trustResolver
+     *            the {@link AuthenticationTrustResolver} to use. Cannot be
+     *            null.
+     */
+    public void setTrustResolver(AuthenticationTrustResolver trustResolver) {
+        Assert.notNull(trustResolver, "trustResolver cannot be null");
+        this.trustResolver = trustResolver;
+    }
+
+
     public HttpServletRequest create(HttpServletRequest request, HttpServletResponse response) {
          return new Servlet3SecurityContextHolderAwareRequestWrapper(request, rolePrefix, response);
     }
@@ -136,7 +154,7 @@ final class HttpServlet3RequestFactory implements HttpServletRequestFactory {
         private final HttpServletResponse response;
 
         public Servlet3SecurityContextHolderAwareRequestWrapper(HttpServletRequest request, String rolePrefix, HttpServletResponse response) {
-            super(request, rolePrefix);
+            super(request, trustResolver, rolePrefix);
             this.response = response;
         }
 

+ 21 - 2
web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestFilter.java

@@ -28,6 +28,8 @@ import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
 
 import org.springframework.security.authentication.AuthenticationManager;
+import org.springframework.security.authentication.AuthenticationTrustResolver;
+import org.springframework.security.authentication.AuthenticationTrustResolverImpl;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.AuthenticationEntryPoint;
@@ -78,6 +80,8 @@ public class SecurityContextHolderAwareRequestFilter extends GenericFilterBean {
 
     private List<LogoutHandler> logoutHandlers;
 
+    private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
+
     //~ Methods ========================================================================================================
 
     public void setRolePrefix(String rolePrefix) {
@@ -153,11 +157,26 @@ public class SecurityContextHolderAwareRequestFilter extends GenericFilterBean {
     @Override
     public void afterPropertiesSet() throws ServletException {
         super.afterPropertiesSet();
-        requestFactory = isServlet3() ? createServlet3Factory(rolePrefix) : new HttpServlet25RequestFactory(rolePrefix);
+        requestFactory = isServlet3() ? createServlet3Factory(rolePrefix) : new HttpServlet25RequestFactory(trustResolver, rolePrefix);
     }
 
-    private HttpServlet3RequestFactory createServlet3Factory(String rolePrefix) {
+    /**
+     * Sets the {@link AuthenticationTrustResolver} to be used. The default is
+     * {@link AuthenticationTrustResolverImpl}.
+     *
+     * @param trustResolver
+     *            the {@link AuthenticationTrustResolver} to use. Cannot be
+     *            null.
+     */
+    public void setTrustResolver(AuthenticationTrustResolver trustResolver) {
+        Assert.notNull(trustResolver, "trustResolver cannot be null");
+        this.trustResolver = trustResolver;
+    }
+
+
+    private HttpServletRequestFactory createServlet3Factory(String rolePrefix) {
         HttpServlet3RequestFactory factory = new HttpServlet3RequestFactory(rolePrefix);
+        factory.setTrustResolver(trustResolver);
         factory.setAuthenticationEntryPoint(authenticationEntryPoint);
         factory.setAuthenticationManager(authenticationManager);
         factory.setLogoutHandlers(logoutHandlers);

+ 25 - 3
web/src/main/java/org/springframework/security/web/servletapi/SecurityContextHolderAwareRequestWrapper.java

@@ -28,6 +28,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.util.Assert;
 
 
 /**
@@ -46,11 +47,12 @@ import org.springframework.security.core.userdetails.UserDetails;
  * @author Orlando Garcia Carmona
  * @author Ben Alex
  * @author Luke Taylor
+ * @author Rob Winch
  */
 public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequestWrapper {
     //~ Instance fields ================================================================================================
 
-    private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl();
+    private final AuthenticationTrustResolver trustResolver;
 
     /**
      * The prefix passed by the filter. It will be prepended to any supplied role values before
@@ -60,10 +62,30 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest
 
     //~ Constructors ===================================================================================================
 
+    /**
+     * Creates a new instance with {@link AuthenticationTrustResolverImpl}.
+     *
+     * @param request
+     * @param rolePrefix
+     */
     public SecurityContextHolderAwareRequestWrapper(HttpServletRequest request, String rolePrefix) {
-        super(request);
+        this(request, new AuthenticationTrustResolverImpl(), rolePrefix);
+    }
 
+    /**
+     * Creates a new instance
+     *
+     * @param request the original {@link HttpServletRequest}
+     * @param trustResolver
+     *            the {@link AuthenticationTrustResolver} to use. Cannot be
+     *            null.
+     * @param rolePrefix The prefix to be added to {@link #isUserInRole(String)} or null if no prefix.
+     */
+    public SecurityContextHolderAwareRequestWrapper(HttpServletRequest request, AuthenticationTrustResolver trustResolver, String rolePrefix) {
+        super(request);
+        Assert.notNull(trustResolver, "trustResolver cannot be null");
         this.rolePrefix = rolePrefix;
+        this.trustResolver = trustResolver;
     }
 
     //~ Methods ========================================================================================================
@@ -76,7 +98,7 @@ public class SecurityContextHolderAwareRequestWrapper extends HttpServletRequest
     private Authentication getAuthentication() {
         Authentication auth = SecurityContextHolder.getContext().getAuthentication();
 
-        if (!authenticationTrustResolver.isAnonymous(auth)) {
+        if (!trustResolver.isAnonymous(auth)) {
             return auth;
         }
 

+ 16 - 2
web/src/main/java/org/springframework/security/web/session/SessionManagementFilter.java

@@ -40,7 +40,7 @@ public class SessionManagementFilter extends GenericFilterBean {
 
     private final SecurityContextRepository securityContextRepository;
     private SessionAuthenticationStrategy sessionAuthenticationStrategy;
-    private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl();
+    private AuthenticationTrustResolver trustResolver = new AuthenticationTrustResolverImpl();
     private InvalidSessionStrategy invalidSessionStrategy = null;
     private AuthenticationFailureHandler failureHandler = new SimpleUrlAuthenticationFailureHandler();
 
@@ -70,7 +70,7 @@ public class SessionManagementFilter extends GenericFilterBean {
         if (!securityContextRepository.containsContext(request)) {
             Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
 
-            if (authentication != null && !authenticationTrustResolver.isAnonymous(authentication)) {
+            if (authentication != null && !trustResolver.isAnonymous(authentication)) {
              // The user has been authenticated during the current request, so call the session strategy
                 try {
                     sessionAuthenticationStrategy.onAuthentication(authentication, request, response);
@@ -136,4 +136,18 @@ public class SessionManagementFilter extends GenericFilterBean {
         Assert.notNull(failureHandler, "failureHandler cannot be null");
         this.failureHandler = failureHandler;
     }
+
+
+    /**
+     * Sets the {@link AuthenticationTrustResolver} to be used. The default is
+     * {@link AuthenticationTrustResolverImpl}.
+     *
+     * @param trustResolver
+     *            the {@link AuthenticationTrustResolver} to use. Cannot be
+     *            null.
+     */
+    public void setTrustResolver(AuthenticationTrustResolver trustResolver) {
+        Assert.notNull(trustResolver, "trustResolver cannot be null");
+        this.trustResolver = trustResolver;
+    }
 }

+ 62 - 1
web/src/test/java/org/springframework/security/web/access/expression/DefaultWebSecurityExpressionHandlerTests.java

@@ -1,22 +1,67 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
 package org.springframework.security.web.access.expression;
 
+import static org.fest.assertions.Assertions.assertThat;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
+import org.junit.After;
+import org.junit.Before;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.runners.MockitoJUnitRunner;
 import org.springframework.beans.factory.support.RootBeanDefinition;
 import org.springframework.context.support.StaticApplicationContext;
 import org.springframework.expression.EvaluationContext;
+import org.springframework.expression.Expression;
 import org.springframework.expression.ExpressionParser;
 import org.springframework.security.access.SecurityConfig;
+import org.springframework.security.authentication.AuthenticationTrustResolver;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.FilterInvocation;
 
+@RunWith(MockitoJUnitRunner.class)
 public class DefaultWebSecurityExpressionHandlerTests {
+    @Mock
+    private AuthenticationTrustResolver trustResolver;
+
+    @Mock
+    private Authentication authentication;
+
+    @Mock
+    private FilterInvocation invocation;
+
+    private DefaultWebSecurityExpressionHandler handler;
+
+    @Before
+    public void setup() {
+        handler = new DefaultWebSecurityExpressionHandler();
+    }
+
+    @After
+    public void cleanup() {
+        SecurityContextHolder.clearContext();
+    }
 
     @Test
     public void expressionPropertiesAreResolvedAgainsAppContextBeans() throws Exception {
-        DefaultWebSecurityExpressionHandler handler = new DefaultWebSecurityExpressionHandler();
         StaticApplicationContext appContext = new StaticApplicationContext();
         RootBeanDefinition bean = new RootBeanDefinition(SecurityConfig.class);
         bean.getConstructorArgumentValues().addGenericArgumentValue("ROLE_A");
@@ -29,4 +74,20 @@ public class DefaultWebSecurityExpressionHandlerTests {
         assertTrue(parser.parseExpression("@role.attribute == 'ROLE_A'").getValue(ctx, Boolean.class));
     }
 
+
+    @Test(expected = IllegalArgumentException.class)
+    public void setTrustResolverNull() {
+        handler.setTrustResolver(null);
+    }
+
+    @Test
+    public void createEvaluationContextCustomTrustResolver() {
+        handler.setTrustResolver(trustResolver);
+
+        Expression expression = handler.getExpressionParser().parseExpression("anonymous");
+        EvaluationContext context = handler.createEvaluationContext(authentication, invocation);
+        assertThat(expression.getValue(context, Boolean.class)).isFalse();
+
+        verify(trustResolver).isAnonymous(authentication);
+    }
 }

+ 34 - 5
web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java

@@ -13,13 +13,19 @@
 package org.springframework.security.web.context;
 
 import static org.fest.assertions.Assertions.assertThat;
-import static org.junit.Assert.*;
-import static org.springframework.security.web.context.HttpSessionSecurityContextRepository.*;
-import static org.mockito.Mockito.verify;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertSame;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Matchers.anyBoolean;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.reset;
-import static org.mockito.Matchers.*;
-import static org.powermock.api.mockito.PowerMockito.*;
+import static org.mockito.Mockito.verify;
+import static org.powermock.api.mockito.PowerMockito.mock;
+import static org.powermock.api.mockito.PowerMockito.spy;
+import static org.powermock.api.mockito.PowerMockito.when;
+import static org.springframework.security.web.context.HttpSessionSecurityContextRepository.SPRING_SECURITY_CONTEXT_KEY;
 
 import javax.servlet.ServletOutputStream;
 import javax.servlet.ServletRequest;
@@ -35,6 +41,7 @@ import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
+import org.springframework.security.authentication.AuthenticationTrustResolver;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
@@ -466,4 +473,26 @@ public class HttpSessionSecurityContextRepositoryTests {
         assertEquals(url, holder.getResponse().encodeUrl(url));
         assertEquals(url, holder.getResponse().encodeURL(url));
     }
+
+    @Test
+    public void saveContextCustomTrustResolver() {
+        SecurityContext contextToSave = SecurityContextHolder.createEmptyContext();
+        contextToSave.setAuthentication(testToken);
+        HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+        MockHttpServletRequest request = new MockHttpServletRequest();
+        HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, new MockHttpServletResponse());
+        repo.loadContext(holder);
+        AuthenticationTrustResolver trustResolver = mock(AuthenticationTrustResolver.class);
+        repo.setTrustResolver(trustResolver);
+
+        repo.saveContext(contextToSave, holder.getRequest(), holder.getResponse());
+
+        verify(trustResolver).isAnonymous(contextToSave.getAuthentication());
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void setTrustResolverNull() {
+        HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository();
+        repo.setTrustResolver(null);
+    }
 }

+ 38 - 1
web/src/test/java/org/springframework/security/web/session/SessionManagementFilterTests.java

@@ -1,3 +1,18 @@
+/*
+ * Copyright 2002-2013 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ */
 package org.springframework.security.web.session;
 
 import static org.junit.Assert.*;
@@ -13,10 +28,10 @@ import org.junit.Test;
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.security.authentication.AuthenticationTrustResolver;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
-import org.springframework.security.web.DefaultRedirectStrategy;
 import org.springframework.security.web.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.authentication.session.SessionAuthenticationException;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
@@ -24,6 +39,7 @@ import org.springframework.security.web.context.SecurityContextRepository;
 
 /**
  * @author Luke Taylor
+ * @author Rob Winch
  */
 public class SessionManagementFilterTests {
 
@@ -143,6 +159,27 @@ public class SessionManagementFilterTests {
         assertEquals("/timedOut", response.getRedirectedUrl());
     }
 
+    @Test
+    public void customAuthenticationTrustResolver() throws Exception {
+        AuthenticationTrustResolver trustResolver= mock(AuthenticationTrustResolver.class);
+        SecurityContextRepository repo = mock(SecurityContextRepository.class);
+        SessionManagementFilter filter = new SessionManagementFilter(repo);
+        filter.setTrustResolver(trustResolver);
+        HttpServletRequest request = new MockHttpServletRequest();
+        authenticateUser();
+
+        filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
+
+        verify(trustResolver).isAnonymous(any(Authentication.class));
+    }
+
+    @Test(expected = IllegalArgumentException.class)
+    public void setTrustResolverNull() {
+        SecurityContextRepository repo = mock(SecurityContextRepository.class);
+        SessionManagementFilter filter = new SessionManagementFilter(repo);
+        filter.setTrustResolver(null);
+    }
+
     private void authenticateUser() {
         SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass"));
     }