浏览代码

Support custom filter in Server Kotlin DSL

Closes gh-8783
Evgeniy Cheban 5 年之前
父节点
当前提交
0a2006ebec

+ 76 - 0
config/src/main/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDsl.kt

@@ -20,6 +20,7 @@ import org.springframework.security.oauth2.client.registration.ReactiveClientReg
 import org.springframework.security.web.server.SecurityWebFilterChain
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher
 import org.springframework.web.server.ServerWebExchange
+import org.springframework.web.server.WebFilter
 
 /**
  * Configures [ServerHttpSecurity] using a [ServerHttpSecurity Kotlin DSL][ServerHttpSecurityDsl].
@@ -89,6 +90,81 @@ class ServerHttpSecurityDsl(private val http: ServerHttpSecurity, private val in
         this.http.securityMatcher(securityMatcher)
     }
 
+    /**
+     * Adds a [WebFilter] at a specific position.
+     *
+     * Example:
+     *
+     * ```
+     * @EnableWebFluxSecurity
+     * class SecurityConfig {
+     *
+     *  @Bean
+     *  fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+     *      return http {
+     *          addFilterAt(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE)
+     *       }
+     *   }
+     * }
+     * ```
+     *
+     * @param webFilter the [WebFilter] to add
+     * @param order the place to insert the [WebFilter]
+     */
+    fun addFilterAt(webFilter: WebFilter, order: SecurityWebFiltersOrder) {
+        this.http.addFilterAt(webFilter, order)
+    }
+
+    /**
+     * Adds a [WebFilter] before specific position.
+     *
+     * Example:
+     *
+     * ```
+     * @EnableWebFluxSecurity
+     * class SecurityConfig {
+     *
+     *  @Bean
+     *  fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+     *      return http {
+     *          addFilterBefore(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE)
+     *       }
+     *   }
+     * }
+     * ```
+     *
+     * @param webFilter the [WebFilter] to add
+     * @param order the place before which to insert the [WebFilter]
+     */
+    fun addFilterBefore(webFilter: WebFilter, order: SecurityWebFiltersOrder) {
+        this.http.addFilterBefore(webFilter, order)
+    }
+
+    /**
+     * Adds a [WebFilter] after specific position.
+     *
+     * Example:
+     *
+     * ```
+     * @EnableWebFluxSecurity
+     * class SecurityConfig {
+     *
+     *  @Bean
+     *  fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+     *      return http {
+     *          addFilterAfter(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE)
+     *       }
+     *   }
+     * }
+     * ```
+     *
+     * @param webFilter the [WebFilter] to add
+     * @param order the place after which to insert the [WebFilter]
+     */
+    fun addFilterAfter(webFilter: WebFilter, order: SecurityWebFiltersOrder) {
+        this.http.addFilterAfter(webFilter, order)
+    }
+
     /**
      * Enables form based authentication.
      *

+ 76 - 0
config/src/test/kotlin/org/springframework/security/config/web/server/ServerHttpSecurityDslTests.kt

@@ -16,6 +16,7 @@
 
 package org.springframework.security.config.web.server
 
+import org.assertj.core.api.Assertions.assertThat
 import org.junit.Rule
 import org.junit.Test
 import org.springframework.beans.factory.annotation.Autowired
@@ -26,6 +27,7 @@ import org.springframework.security.config.annotation.web.reactive.EnableWebFlux
 import org.springframework.security.config.test.SpringTestRule
 import org.springframework.security.web.header.writers.frameoptions.XFrameOptionsHeaderWriter
 import org.springframework.security.web.server.SecurityWebFilterChain
+import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter
 import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter
 import org.springframework.security.web.server.header.StrictTransportSecurityServerHttpHeadersWriter
 import org.springframework.security.web.server.header.XFrameOptionsServerHttpHeadersWriter
@@ -33,6 +35,10 @@ import org.springframework.security.web.server.header.XXssProtectionServerHttpHe
 import org.springframework.security.web.server.util.matcher.PathPatternParserServerWebExchangeMatcher
 import org.springframework.test.web.reactive.server.WebTestClient
 import org.springframework.web.reactive.config.EnableWebFlux
+import org.springframework.web.server.ServerWebExchange
+import org.springframework.web.server.WebFilter
+import org.springframework.web.server.WebFilterChain
+import reactor.core.publisher.Mono
 
 /**
  * Tests for [ServerHttpSecurityDsl]
@@ -123,4 +129,74 @@ class ServerHttpSecurityDslTests {
             }
         }
     }
+
+    @Test
+    fun `add filter at applies custom at specified filter position`() {
+        this.spring.register(CustomWebFilterAtConfig::class.java).autowire()
+        val filterChain = this.spring.context.getBean(SecurityWebFilterChain::class.java)
+        val filters = filterChain.webFilters.collectList().block()
+
+        assertThat(filters).last().isExactlyInstanceOf(CustomWebFilter::class.java)
+    }
+
+    @EnableWebFluxSecurity
+    @EnableWebFlux
+    open class CustomWebFilterAtConfig {
+        @Bean
+        open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+            return http {
+                addFilterAt(CustomWebFilter(), SecurityWebFiltersOrder.LAST)
+            }
+        }
+    }
+
+    @Test
+    fun `add filter before applies custom before specified filter position`() {
+        this.spring.register(CustomWebFilterBeforeConfig::class.java).autowire()
+        val filterChain = this.spring.context.getBean(SecurityWebFilterChain::class.java)
+        val filters: List<Class<out WebFilter>>? = filterChain.webFilters.map { it.javaClass }.collectList().block()
+
+        assertThat(filters).containsSubsequence(
+                CustomWebFilter::class.java,
+                SecurityContextServerWebExchangeWebFilter::class.java
+        )
+    }
+
+    @EnableWebFluxSecurity
+    @EnableWebFlux
+    open class CustomWebFilterBeforeConfig {
+        @Bean
+        open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+            return http {
+                addFilterBefore(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE)
+            }
+        }
+    }
+
+    @Test
+    fun `add filter after applies custom after specified filter position`() {
+        this.spring.register(CustomWebFilterAfterConfig::class.java).autowire()
+        val filterChain = this.spring.context.getBean(SecurityWebFilterChain::class.java)
+        val filters: List<Class<out WebFilter>>? = filterChain.webFilters.map { it.javaClass }.collectList().block()
+
+        assertThat(filters).containsSubsequence(
+                SecurityContextServerWebExchangeWebFilter::class.java,
+                CustomWebFilter::class.java
+        )
+    }
+
+    @EnableWebFluxSecurity
+    @EnableWebFlux
+    open class CustomWebFilterAfterConfig {
+        @Bean
+        open fun springWebFilterChain(http: ServerHttpSecurity): SecurityWebFilterChain {
+            return http {
+                addFilterAfter(CustomWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE)
+            }
+        }
+    }
+
+    class CustomWebFilter : WebFilter {
+        override fun filter(exchange: ServerWebExchange, chain: WebFilterChain): Mono<Void> = Mono.empty()
+    }
 }