فهرست منبع

Fix Adding Filter Relative to Custom Filter

Closes gh-9787
Marcus Hert da Coregio 4 سال پیش
والد
کامیت
ac371d5de6

+ 12 - 2
config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -112,8 +112,18 @@ final class FilterOrderRegistration {
 		put(SwitchUserFilter.class, order.next());
 	}
 
-	private void put(Class<? extends Filter> filter, int position) {
+	/**
+	 * Register a {@link Filter} with its specific position. If the {@link Filter} was
+	 * already registered before, the position previously defined is not going to be
+	 * overriden
+	 * @param filter the {@link Filter} to register
+	 * @param position the position to associate with the {@link Filter}
+	 */
+	void put(Class<? extends Filter> filter, int position) {
 		String className = filter.getName();
+		if (this.filterToOrder.containsKey(className)) {
+			return;
+		}
 		this.filterToOrder.put(className, position);
 	}
 

+ 1 - 0
config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java

@@ -2566,6 +2566,7 @@ public final class HttpSecurity extends AbstractConfiguredSecurityBuilder<Defaul
 	private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class<? extends Filter> registeredFilter) {
 		int order = this.filterOrders.getOrder(registeredFilter) + offset;
 		this.filters.add(new OrderedFilter(filter, order));
+		this.filterOrders.put(filter.getClass(), order);
 		return this;
 	}
 

+ 75 - 0
config/src/test/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistrationTests.java

@@ -0,0 +1,75 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.web.builders;
+
+import java.io.IOException;
+
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+
+import org.junit.Test;
+
+import org.springframework.security.web.access.channel.ChannelProcessingFilter;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class FilterOrderRegistrationTests {
+
+	private final FilterOrderRegistration filterOrderRegistration = new FilterOrderRegistration();
+
+	@Test
+	public void putWhenNewFilterThenInsertCorrect() {
+		int position = 153;
+		this.filterOrderRegistration.put(MyFilter.class, position);
+		Integer order = this.filterOrderRegistration.getOrder(MyFilter.class);
+		assertThat(order).isEqualTo(position);
+	}
+
+	@Test
+	public void putWhenCustomFilterAlreadyExistsThenDoesNotOverride() {
+		int position = 160;
+		this.filterOrderRegistration.put(MyFilter.class, position);
+		this.filterOrderRegistration.put(MyFilter.class, 173);
+		Integer order = this.filterOrderRegistration.getOrder(MyFilter.class);
+		assertThat(order).isEqualTo(position);
+	}
+
+	@Test
+	public void putWhenPredefinedFilterThenDoesNotOverride() {
+		int position = 100;
+		Integer predefinedFilterOrderBefore = this.filterOrderRegistration.getOrder(ChannelProcessingFilter.class);
+		this.filterOrderRegistration.put(MyFilter.class, position);
+		Integer myFilterOrder = this.filterOrderRegistration.getOrder(MyFilter.class);
+		Integer predefinedFilterOrderAfter = this.filterOrderRegistration.getOrder(ChannelProcessingFilter.class);
+		assertThat(myFilterOrder).isEqualTo(position);
+		assertThat(predefinedFilterOrderAfter).isEqualTo(predefinedFilterOrderBefore).isEqualTo(position);
+	}
+
+	static class MyFilter implements Filter {
+
+		@Override
+		public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
+				throws IOException, ServletException {
+			filterChain.doFilter(servletRequest, servletResponse);
+		}
+
+	}
+
+}

+ 144 - 1
config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpSecurityAddFilterTest.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2020 the original author or authors.
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -30,14 +30,18 @@ import org.assertj.core.api.ListAssert;
 import org.junit.Rule;
 import org.junit.Test;
 
+import org.springframework.context.annotation.Bean;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.web.FilterChainProxy;
+import org.springframework.security.web.SecurityFilterChain;
 import org.springframework.security.web.access.ExceptionTranslationFilter;
 import org.springframework.security.web.access.channel.ChannelProcessingFilter;
 import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
+import org.springframework.security.web.context.SecurityContextPersistenceFilter;
 import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
+import org.springframework.security.web.header.HeaderWriterFilter;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -70,6 +74,46 @@ public class HttpSecurityAddFilterTest {
 				ExceptionTranslationFilter.class);
 	}
 
+	@Test
+	public void addFilterAfterWhenAfterCustomFilterThenOrderCorrect() {
+		this.spring.register(MyOtherFilterRelativeToMyFilterAfterConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
+				MyOtherFilter.class);
+	}
+
+	@Test
+	public void addFilterBeforeWhenBeforeCustomFilterThenOrderCorrect() {
+		this.spring.register(MyOtherFilterRelativeToMyFilterBeforeConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(MyOtherFilter.class, MyFilter.class,
+				WebAsyncManagerIntegrationFilter.class);
+	}
+
+	@Test
+	public void addFilterAtWhenAtCustomFilterThenOrderCorrect() {
+		this.spring.register(MyOtherFilterRelativeToMyFilterAtConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
+				MyOtherFilter.class, SecurityContextPersistenceFilter.class);
+	}
+
+	@Test
+	public void addFilterBeforeWhenCustomFilterDifferentPlacesThenOrderCorrect() {
+		this.spring.register(MyOtherFilterBeforeToMyFilterMultipleAfterConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyOtherFilter.class,
+				MyFilter.class, ExceptionTranslationFilter.class);
+	}
+
+	@Test
+	public void addFilterBeforeAndAfterWhenCustomFiltersDifferentPlacesThenOrderCorrect() {
+		this.spring.register(MyAnotherFilterRelativeToMyCustomFiltersMultipleConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(HeaderWriterFilter.class, MyFilter.class, MyOtherFilter.class,
+				MyOtherFilter.class, MyAnotherFilter.class, MyFilter.class, ExceptionTranslationFilter.class);
+	}
+
 	private ListAssert<Class<?>> assertThatFilters() {
 		FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class);
 		List<Class<?>> filters = filterChain.getFilters("/").stream().map(Object::getClass)
@@ -87,6 +131,26 @@ public class HttpSecurityAddFilterTest {
 
 	}
 
+	static class MyOtherFilter implements Filter {
+
+		@Override
+		public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
+				throws IOException, ServletException {
+			filterChain.doFilter(servletRequest, servletResponse);
+		}
+
+	}
+
+	static class MyAnotherFilter implements Filter {
+
+		@Override
+		public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
+				throws IOException, ServletException {
+			filterChain.doFilter(servletRequest, servletResponse);
+		}
+
+	}
+
 	@EnableWebSecurity
 	static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter {
 
@@ -129,4 +193,83 @@ public class HttpSecurityAddFilterTest {
 
 	}
 
+	@EnableWebSecurity
+	static class MyOtherFilterRelativeToMyFilterAfterConfig {
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
+					.addFilterAfter(new MyOtherFilter(), MyFilter.class);
+			// @formatter:on
+			return http.build();
+		}
+
+	}
+
+	@EnableWebSecurity
+	static class MyOtherFilterRelativeToMyFilterBeforeConfig {
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
+					.addFilterBefore(new MyOtherFilter(), MyFilter.class);
+			// @formatter:on
+			return http.build();
+		}
+
+	}
+
+	@EnableWebSecurity
+	static class MyOtherFilterRelativeToMyFilterAtConfig {
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterAt(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
+					.addFilterAt(new MyOtherFilter(), MyFilter.class);
+			// @formatter:on
+			return http.build();
+		}
+
+	}
+
+	@EnableWebSecurity
+	static class MyOtherFilterBeforeToMyFilterMultipleAfterConfig {
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
+					.addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class)
+					.addFilterBefore(new MyOtherFilter(), MyFilter.class);
+			// @formatter:on
+			return http.build();
+		}
+
+	}
+
+	@EnableWebSecurity
+	static class MyAnotherFilterRelativeToMyCustomFiltersMultipleConfig {
+
+		@Bean
+		SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterAfter(new MyFilter(), HeaderWriterFilter.class)
+					.addFilterBefore(new MyOtherFilter(), ExceptionTranslationFilter.class)
+					.addFilterAfter(new MyOtherFilter(), MyFilter.class)
+					.addFilterAt(new MyAnotherFilter(), MyOtherFilter.class)
+					.addFilterAfter(new MyFilter(), MyAnotherFilter.class);
+			// @formatter:on
+			return http.build();
+		}
+
+	}
+
 }