Sfoglia il codice sorgente

Add multipart configuration to CSRF Kotlin DSL

Fixes gh-8602
Eleftheria Stein 5 anni fa
parent
commit
61060b3a4f

+ 5 - 0
config/src/main/kotlin/org/springframework/security/config/web/server/ServerCsrfDsl.kt

@@ -17,6 +17,7 @@
 package org.springframework.security.config.web.server
 
 import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler
+import org.springframework.security.web.server.csrf.CsrfWebFilter
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher
 
@@ -30,12 +31,15 @@ import org.springframework.security.web.server.util.matcher.ServerWebExchangeMat
  * @property csrfTokenRepository the [ServerCsrfTokenRepository] used to persist the CSRF token.
  * @property requireCsrfProtectionMatcher the [ServerWebExchangeMatcher] used to determine when CSRF protection
  * is enabled.
+ * @property tokenFromMultipartDataEnabled if true, the [CsrfWebFilter] should try to resolve the actual CSRF
+ * token from the body of multipart data requests.
  */
 @ServerSecurityMarker
 class ServerCsrfDsl {
     var accessDeniedHandler: ServerAccessDeniedHandler? = null
     var csrfTokenRepository: ServerCsrfTokenRepository? = null
     var requireCsrfProtectionMatcher: ServerWebExchangeMatcher? = null
+    var tokenFromMultipartDataEnabled: Boolean? = null
 
     private var disabled = false
 
@@ -51,6 +55,7 @@ class ServerCsrfDsl {
             accessDeniedHandler?.also { csrf.accessDeniedHandler(accessDeniedHandler) }
             csrfTokenRepository?.also { csrf.csrfTokenRepository(csrfTokenRepository) }
             requireCsrfProtectionMatcher?.also { csrf.requireCsrfProtectionMatcher(requireCsrfProtectionMatcher) }
+            tokenFromMultipartDataEnabled?.also { csrf.tokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled!!) }
             if (disabled) {
                 csrf.disable()
             }

+ 74 - 0
config/src/test/kotlin/org/springframework/security/config/web/server/ServerCsrfDslTests.kt

@@ -217,4 +217,78 @@ class ServerCsrfDslTests {
             }
         }
     }
+
+    @Test
+    fun `csrf when multipart form data and not enabled then denied`() {
+        `when`(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.loadToken(any()))
+                .thenReturn(Mono.just(this.token))
+        `when`(MultipartFormDataNotEnabledConfig.TOKEN_REPOSITORY.generateToken(any()))
+                .thenReturn(Mono.just(this.token))
+        this.spring.register(MultipartFormDataNotEnabledConfig::class.java).autowire()
+
+        this.client.post()
+                .uri("/")
+                .contentType(MediaType.MULTIPART_FORM_DATA)
+                .body(fromMultipartData(this.token.parameterName, this.token.token))
+                .exchange()
+                .expectStatus().isForbidden
+    }
+
+    @EnableWebFluxSecurity
+    @EnableWebFlux
+    open class MultipartFormDataNotEnabledConfig {
+        companion object {
+            var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java)
+        }
+
+        @Bean
+        open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+            return http {
+                csrf {
+                    csrfTokenRepository = TOKEN_REPOSITORY
+                }
+            }
+        }
+    }
+
+    @Test
+    fun `csrf when multipart form data and enabled then granted`() {
+        `when`(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.loadToken(any()))
+                .thenReturn(Mono.just(this.token))
+        `when`(MultipartFormDataEnabledConfig.TOKEN_REPOSITORY.generateToken(any()))
+                .thenReturn(Mono.just(this.token))
+        this.spring.register(MultipartFormDataEnabledConfig::class.java).autowire()
+
+        this.client.post()
+                .uri("/")
+                .contentType(MediaType.MULTIPART_FORM_DATA)
+                .body(fromMultipartData(this.token.parameterName, this.token.token))
+                .exchange()
+                .expectStatus().isOk
+    }
+
+    @EnableWebFluxSecurity
+    @EnableWebFlux
+    open class MultipartFormDataEnabledConfig {
+        companion object {
+            var TOKEN_REPOSITORY: ServerCsrfTokenRepository = mock(ServerCsrfTokenRepository::class.java)
+        }
+
+        @Bean
+        open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+            return http {
+                csrf {
+                    csrfTokenRepository = TOKEN_REPOSITORY
+                    tokenFromMultipartDataEnabled = true
+                }
+            }
+        }
+
+        @RestController
+        internal class TestController {
+            @PostMapping("/")
+            fun home() {
+            }
+        }
+    }
 }