瀏覽代碼

Add csrfTokenRequestResolver to CsrfDsl

Closes gh-11952
Steve Riesenberg 2 年之前
父節點
當前提交
1d706ae13d

+ 3 - 0
config/src/main/kotlin/org/springframework/security/config/web/servlet/CsrfDsl.kt

@@ -20,6 +20,7 @@ import org.springframework.security.config.annotation.web.builders.HttpSecurity
 import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy
 import org.springframework.security.web.csrf.CsrfTokenRepository
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler
 import org.springframework.security.web.util.matcher.RequestMatcher
 import javax.servlet.http.HttpServletRequest
 
@@ -39,6 +40,7 @@ class CsrfDsl {
     var csrfTokenRepository: CsrfTokenRepository? = null
     var requireCsrfProtectionMatcher: RequestMatcher? = null
     var sessionAuthenticationStrategy: SessionAuthenticationStrategy? = null
+    var csrfTokenRequestHandler: CsrfTokenRequestHandler? = null
 
     private var ignoringAntMatchers: Array<out String>? = null
     private var ignoringRequestMatchers: Array<out RequestMatcher>? = null
@@ -89,6 +91,7 @@ class CsrfDsl {
             csrfTokenRepository?.also { csrf.csrfTokenRepository(csrfTokenRepository) }
             requireCsrfProtectionMatcher?.also { csrf.requireCsrfProtectionMatcher(requireCsrfProtectionMatcher) }
             sessionAuthenticationStrategy?.also { csrf.sessionAuthenticationStrategy(sessionAuthenticationStrategy) }
+            csrfTokenRequestHandler?.also { csrf.csrfTokenRequestHandler(csrfTokenRequestHandler) }
             ignoringAntMatchers?.also { csrf.ignoringAntMatchers(*ignoringAntMatchers!!) }
             ignoringRequestMatchers?.also { csrf.ignoringRequestMatchers(*ignoringRequestMatchers!!) }
             ignoringRequestMatchersPatterns?.also { csrf.ignoringRequestMatchers(*ignoringRequestMatchersPatterns!!) }

+ 53 - 0
config/src/test/kotlin/org/springframework/security/config/web/servlet/CsrfDslTests.kt

@@ -37,6 +37,7 @@ import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequ
 import org.springframework.security.web.SecurityFilterChain
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy
 import org.springframework.security.web.csrf.CsrfTokenRepository
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler
 import org.springframework.security.web.csrf.DefaultCsrfToken
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher
@@ -46,6 +47,8 @@ import org.springframework.test.web.servlet.post
 import org.springframework.web.bind.annotation.PostMapping
 import org.springframework.web.bind.annotation.RestController
 import org.springframework.web.servlet.config.annotation.EnableWebMvc
+import javax.servlet.http.HttpServletRequest
+import javax.servlet.http.HttpServletResponse
 
 /**
  * Tests for [CsrfDsl]
@@ -302,4 +305,54 @@ class CsrfDslTests {
         fun test2() {
         }
     }
+
+    @Test
+    fun `CSRF when custom csrf token request handler then handler used`() {
+        this.spring.register(RequestHandlerConfig::class.java).autowire()
+        mockkObject(RequestHandlerConfig.HANDLER)
+        every { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) } returns Unit
+
+        this.mockMvc.get("/test1")
+
+        verify(exactly = 1) { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) }
+    }
+
+    @Test
+    fun `POST when custom csrf token request handler then handler used`() {
+        this.spring.register(RequestHandlerConfig::class.java).autowire()
+        mockkObject(RequestHandlerConfig.HANDLER)
+        every { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) } answers {
+            val request: HttpServletRequest = firstArg()
+            val response: HttpServletResponse = secondArg()
+            // Required for LazyCsrfTokenRepository
+            request.setAttribute(HttpServletResponse::class.java.name, response)
+        }
+        every { RequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) } returns "token"
+
+        this.mockMvc.post("/test2")
+
+        verify(exactly = 1) { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) }
+        verify(exactly = 1) { RequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) }
+    }
+
+    @Configuration
+    @EnableWebSecurity
+    open class RequestHandlerConfig {
+
+        companion object {
+            val HANDLER: CsrfTokenRequestHandler = CsrfTokenRequestHandler { request, response, _ ->
+                request.setAttribute(HttpServletResponse::class.java.name, response)
+            }
+        }
+
+        @Bean
+        open fun filterChain(http: HttpSecurity): SecurityFilterChain {
+            http {
+                csrf {
+                    csrfTokenRequestHandler = HANDLER
+                }
+            }
+            return http.build()
+        }
+    }
 }