|
@@ -37,6 +37,7 @@ import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequ
|
|
|
import org.springframework.security.web.SecurityFilterChain
|
|
|
import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy
|
|
|
import org.springframework.security.web.csrf.CsrfTokenRepository
|
|
|
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler
|
|
|
import org.springframework.security.web.csrf.DefaultCsrfToken
|
|
|
import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository
|
|
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher
|
|
@@ -46,6 +47,8 @@ import org.springframework.test.web.servlet.post
|
|
|
import org.springframework.web.bind.annotation.PostMapping
|
|
|
import org.springframework.web.bind.annotation.RestController
|
|
|
import org.springframework.web.servlet.config.annotation.EnableWebMvc
|
|
|
+import javax.servlet.http.HttpServletRequest
|
|
|
+import javax.servlet.http.HttpServletResponse
|
|
|
|
|
|
/**
|
|
|
* Tests for [CsrfDsl]
|
|
@@ -302,4 +305,54 @@ class CsrfDslTests {
|
|
|
fun test2() {
|
|
|
}
|
|
|
}
|
|
|
+
|
|
|
+ @Test
|
|
|
+ fun `CSRF when custom csrf token request handler then handler used`() {
|
|
|
+ this.spring.register(RequestHandlerConfig::class.java).autowire()
|
|
|
+ mockkObject(RequestHandlerConfig.HANDLER)
|
|
|
+ every { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) } returns Unit
|
|
|
+
|
|
|
+ this.mockMvc.get("/test1")
|
|
|
+
|
|
|
+ verify(exactly = 1) { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ fun `POST when custom csrf token request handler then handler used`() {
|
|
|
+ this.spring.register(RequestHandlerConfig::class.java).autowire()
|
|
|
+ mockkObject(RequestHandlerConfig.HANDLER)
|
|
|
+ every { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) } answers {
|
|
|
+ val request: HttpServletRequest = firstArg()
|
|
|
+ val response: HttpServletResponse = secondArg()
|
|
|
+ // Required for LazyCsrfTokenRepository
|
|
|
+ request.setAttribute(HttpServletResponse::class.java.name, response)
|
|
|
+ }
|
|
|
+ every { RequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) } returns "token"
|
|
|
+
|
|
|
+ this.mockMvc.post("/test2")
|
|
|
+
|
|
|
+ verify(exactly = 1) { RequestHandlerConfig.HANDLER.handle(any(), any(), any()) }
|
|
|
+ verify(exactly = 1) { RequestHandlerConfig.HANDLER.resolveCsrfTokenValue(any(), any()) }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Configuration
|
|
|
+ @EnableWebSecurity
|
|
|
+ open class RequestHandlerConfig {
|
|
|
+
|
|
|
+ companion object {
|
|
|
+ val HANDLER: CsrfTokenRequestHandler = CsrfTokenRequestHandler { request, response, _ ->
|
|
|
+ request.setAttribute(HttpServletResponse::class.java.name, response)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ @Bean
|
|
|
+ open fun filterChain(http: HttpSecurity): SecurityFilterChain {
|
|
|
+ http {
|
|
|
+ csrf {
|
|
|
+ csrfTokenRequestHandler = HANDLER
|
|
|
+ }
|
|
|
+ }
|
|
|
+ return http.build()
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|