|
@@ -34,6 +34,8 @@ import org.junit.runner.RunWith;
|
|
import org.mockito.Mock;
|
|
import org.mockito.Mock;
|
|
import org.mockito.junit.MockitoJUnitRunner;
|
|
import org.mockito.junit.MockitoJUnitRunner;
|
|
|
|
|
|
|
|
+import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
|
|
|
|
+import org.springframework.web.server.WebFilterChain;
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.test.publisher.TestPublisher;
|
|
import reactor.test.publisher.TestPublisher;
|
|
|
|
|
|
@@ -190,6 +192,30 @@ public class ServerHttpSecurityTests {
|
|
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
|
|
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ @Test
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
|
+ public void addFilterAfterIsApplied(){
|
|
|
|
+ SecurityWebFilterChain securityWebFilterChain = this.http.addFilterAfter(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build();
|
|
|
|
+ List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block();
|
|
|
|
+
|
|
|
|
+ assertThat(filters).isNotNull()
|
|
|
|
+ .isNotEmpty()
|
|
|
|
+ .containsSequence(SecurityContextServerWebExchangeWebFilter.class, TestWebFilter.class);
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
|
+ public void addFilterBeforeIsApplied(){
|
|
|
|
+ SecurityWebFilterChain securityWebFilterChain = this.http.addFilterBefore(new TestWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE).build();
|
|
|
|
+ List filters = securityWebFilterChain.getWebFilters().map(WebFilter::getClass).collectList().block();
|
|
|
|
+
|
|
|
|
+ assertThat(filters).isNotNull()
|
|
|
|
+ .isNotEmpty()
|
|
|
|
+ .containsSequence(TestWebFilter.class, SecurityContextServerWebExchangeWebFilter.class);
|
|
|
|
+
|
|
|
|
+ }
|
|
|
|
+
|
|
private <T extends WebFilter> Optional<T> getWebFilter(SecurityWebFilterChain filterChain, Class<T> filterClass) {
|
|
private <T extends WebFilter> Optional<T> getWebFilter(SecurityWebFilterChain filterChain, Class<T> filterClass) {
|
|
return (Optional<T>) filterChain.getWebFilters()
|
|
return (Optional<T>) filterChain.getWebFilters()
|
|
.filter(Objects::nonNull)
|
|
.filter(Objects::nonNull)
|
|
@@ -214,4 +240,12 @@ public class ServerHttpSecurityTests {
|
|
.map(e -> e.getRequest().getPath().pathWithinApplication().value());
|
|
.map(e -> e.getRequest().getPath().pathWithinApplication().value());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ private static class TestWebFilter implements WebFilter {
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
|
|
|
|
+ return chain.filter(exchange);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
}
|
|
}
|