소스 검색

SEC-2495: CSRF disables logout on GET

Rob Winch 11 년 전
부모
커밋
04a527d4ec

+ 18 - 4
config/src/main/java/org/springframework/security/config/http/LogoutBeanDefinitionParser.java

@@ -46,10 +46,12 @@ class LogoutBeanDefinitionParser implements BeanDefinitionParser {
 
     final String rememberMeServices;
     private ManagedList<BeanMetadataElement> logoutHandlers = new ManagedList<BeanMetadataElement>();
+    private boolean csrfEnabled;
 
     public LogoutBeanDefinitionParser(String rememberMeServices, BeanMetadataElement csrfLogoutHandler) {
         this.rememberMeServices = rememberMeServices;
-        if(csrfLogoutHandler != null) {
+        this.csrfEnabled = csrfLogoutHandler != null;
+        if(this.csrfEnabled) {
             logoutHandlers.add(csrfLogoutHandler);
         }
     }
@@ -78,10 +80,9 @@ class LogoutBeanDefinitionParser implements BeanDefinitionParser {
         if (!StringUtils.hasText(logoutUrl)) {
             logoutUrl = DEF_LOGOUT_URL;
         }
-        BeanDefinitionBuilder matcherBuilder = BeanDefinitionBuilder.rootBeanDefinition("org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter$FilterProcessUrlRequestMatcher");
-        matcherBuilder.addConstructorArgValue(logoutUrl);
 
-        builder.addPropertyValue("logoutRequestMatcher", matcherBuilder.getBeanDefinition());
+
+        builder.addPropertyValue("logoutRequestMatcher", getLogoutRequestMatcher(logoutUrl));
 
         if (StringUtils.hasText(successHandlerRef)) {
             if (StringUtils.hasText(logoutSuccessUrl)) {
@@ -117,6 +118,19 @@ class LogoutBeanDefinitionParser implements BeanDefinitionParser {
         return builder.getBeanDefinition();
     }
 
+    private BeanDefinition getLogoutRequestMatcher(String logoutUrl) {
+        if(this.csrfEnabled) {
+            BeanDefinitionBuilder matcherBuilder = BeanDefinitionBuilder.rootBeanDefinition("org.springframework.security.web.util.matcher.AntPathRequestMatcher");
+            matcherBuilder.addConstructorArgValue(logoutUrl);
+            matcherBuilder.addConstructorArgValue("POST");
+            return matcherBuilder.getBeanDefinition();
+        } else {
+            BeanDefinitionBuilder matcherBuilder = BeanDefinitionBuilder.rootBeanDefinition("org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter$FilterProcessUrlRequestMatcher");
+            matcherBuilder.addConstructorArgValue(logoutUrl);
+            return matcherBuilder.getBeanDefinition();
+        }
+    }
+
     ManagedList<BeanMetadataElement> getLogoutHandlers() {
         return logoutHandlers;
     }

+ 49 - 13
config/src/test/groovy/org/springframework/security/config/http/CsrfConfigTests.groovy

@@ -12,27 +12,29 @@
  */
 package org.springframework.security.config.http
 
-import static org.mockito.Mockito.*
 import static org.mockito.Matchers.*
+import static org.mockito.Mockito.*
 
-import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequest
 import javax.servlet.http.HttpServletResponse
 
-import org.spockframework.compiler.model.WhenBlock;
 import org.springframework.mock.web.MockFilterChain
 import org.springframework.mock.web.MockHttpServletRequest
 import org.springframework.mock.web.MockHttpServletResponse
-import org.springframework.security.access.AccessDeniedException;
-import org.springframework.security.config.annotation.web.configurers.CsrfConfigurerTests.CsrfTokenRepositoryConfig;
-import org.springframework.security.config.annotation.web.configurers.CsrfConfigurerTests.RequireCsrfProtectionMatcherConfig
-import org.springframework.security.web.access.AccessDeniedHandler;
+import org.springframework.security.access.AccessDeniedException
+import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
+import org.springframework.security.core.Authentication
+import org.springframework.security.core.authority.AuthorityUtils
+import org.springframework.security.core.context.SecurityContextImpl
+import org.springframework.security.web.access.AccessDeniedHandler
+import org.springframework.security.web.context.HttpRequestResponseHolder
+import org.springframework.security.web.context.HttpSessionSecurityContextRepository
 import org.springframework.security.web.csrf.CsrfFilter
-import org.springframework.security.web.csrf.CsrfToken;
-import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.DefaultCsrfToken;
-import org.springframework.security.web.servlet.support.csrf.CsrfRequestDataValueProcessor
+import org.springframework.security.web.csrf.CsrfToken
+import org.springframework.security.web.csrf.CsrfTokenRepository
+import org.springframework.security.web.csrf.DefaultCsrfToken
 import org.springframework.security.web.util.matcher.RequestMatcher
-import org.springframework.web.servlet.support.RequestDataValueProcessor;
+import org.springframework.web.servlet.support.RequestDataValueProcessor
 
 import spock.lang.Unroll
 
@@ -203,6 +205,7 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             }
             mockBean(RequestMatcher,'matcher')
             createAppContext()
+            request.method = 'POST'
             RequestMatcher matcher = appContext.getBean("matcher",RequestMatcher)
         when:
             when(matcher.matches(any(HttpServletRequest))).thenReturn(false)
@@ -272,10 +275,43 @@ class CsrfConfigTests extends AbstractHttpConfigTests {
             when(repo.loadToken(any(HttpServletRequest))).thenReturn(token)
             request.setParameter(token.parameterName,token.token)
             request.method = "POST"
-            request.requestURI = "/j_spring_security_logout"
+            request.servletPath = "/j_spring_security_logout"
         when:
             springSecurityFilterChain.doFilter(request,response,chain)
         then:
             verify(repo).saveToken(eq(null),any(HttpServletRequest), any(HttpServletResponse))
     }
+
+        def "SEC-2495: csrf disables logout on GET"() {
+            setup:
+                httpAutoConfig {
+                    'csrf'()
+                }
+                createAppContext()
+                login()
+                request.method = "GET"
+                request.requestURI = "/j_spring_security_logout"
+            when:
+                springSecurityFilterChain.doFilter(request,response,chain)
+            then:
+                getAuthentication(request) != null
+        }
+
+
+        def login(String username="user", String role="ROLE_USER") {
+            login(new UsernamePasswordAuthenticationToken(username, null, AuthorityUtils.createAuthorityList(role)))
+        }
+
+        def login(Authentication auth) {
+            HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository()
+            HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(request, response)
+            repo.loadContext(requestResponseHolder)
+            repo.saveContext(new SecurityContextImpl(authentication:auth), requestResponseHolder.request, requestResponseHolder.response)
+        }
+
+        def getAuthentication(HttpServletRequest request) {
+            HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository()
+            HttpRequestResponseHolder requestResponseHolder = new HttpRequestResponseHolder(request, response)
+            repo.loadContext(requestResponseHolder)?.authentication
+        }
 }