فهرست منبع

Fix HttpSecurity.addFilter* Ordering

Closes gh-9633
Rob Winch 4 سال پیش
والد
کامیت
adf3e94c9f

+ 3 - 75
config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterComparator.java → config/src/main/java/org/springframework/security/config/annotation/web/builders/FilterOrderRegistration.java

@@ -15,7 +15,6 @@
  */
 package org.springframework.security.config.annotation.web.builders;
 
-import java.io.Serializable;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.Map;
@@ -53,14 +52,12 @@ import org.springframework.web.filter.CorsFilter;
  * @author Rob Winch
  * @since 3.2
  */
-
-@SuppressWarnings("serial")
-final class FilterComparator implements Comparator<Filter>, Serializable {
+final class FilterOrderRegistration {
 	private static final int INITIAL_ORDER = 100;
 	private static final int ORDER_STEP = 100;
 	private final Map<String, Integer> filterToOrder = new HashMap<>();
 
-	FilterComparator() {
+	FilterOrderRegistration() {
 		Step order = new Step(INITIAL_ORDER, ORDER_STEP);
 		put(ChannelProcessingFilter.class, order.next());
 		put(ConcurrentSessionFilter.class, order.next());
@@ -111,75 +108,6 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
 		put(SwitchUserFilter.class, order.next());
 	}
 
-	public int compare(Filter lhs, Filter rhs) {
-		Integer left = getOrder(lhs.getClass());
-		Integer right = getOrder(rhs.getClass());
-		return left - right;
-	}
-
-	/**
-	 * Determines if a particular {@link Filter} is registered to be sorted
-	 *
-	 * @param filter
-	 * @return
-	 */
-	public boolean isRegistered(Class<? extends Filter> filter) {
-		return getOrder(filter) != null;
-	}
-
-	/**
-	 * Registers a {@link Filter} to exist after a particular {@link Filter} that is
-	 * already registered.
-	 * @param filter the {@link Filter} to register
-	 * @param afterFilter the {@link Filter} that is already registered and that
-	 * {@code filter} should be placed after.
-	 */
-	public void registerAfter(Class<? extends Filter> filter,
-			Class<? extends Filter> afterFilter) {
-		Integer position = getOrder(afterFilter);
-		if (position == null) {
-			throw new IllegalArgumentException(
-					"Cannot register after unregistered Filter " + afterFilter);
-		}
-
-		put(filter, position + 1);
-	}
-
-	/**
-	 * Registers a {@link Filter} to exist at a particular {@link Filter} position
-	 * @param filter the {@link Filter} to register
-	 * @param atFilter the {@link Filter} that is already registered and that
-	 * {@code filter} should be placed at.
-	 */
-	public void registerAt(Class<? extends Filter> filter,
-			Class<? extends Filter> atFilter) {
-		Integer position = getOrder(atFilter);
-		if (position == null) {
-			throw new IllegalArgumentException(
-					"Cannot register after unregistered Filter " + atFilter);
-		}
-
-		put(filter, position);
-	}
-
-	/**
-	 * Registers a {@link Filter} to exist before a particular {@link Filter} that is
-	 * already registered.
-	 * @param filter the {@link Filter} to register
-	 * @param beforeFilter the {@link Filter} that is already registered and that
-	 * {@code filter} should be placed before.
-	 */
-	public void registerBefore(Class<? extends Filter> filter,
-			Class<? extends Filter> beforeFilter) {
-		Integer position = getOrder(beforeFilter);
-		if (position == null) {
-			throw new IllegalArgumentException(
-					"Cannot register after unregistered Filter " + beforeFilter);
-		}
-
-		put(filter, position - 1);
-	}
-
 	private void put(Class<? extends Filter> filter, int position) {
 		String className = filter.getName();
 		filterToOrder.put(className, position);
@@ -192,7 +120,7 @@ final class FilterComparator implements Comparator<Filter>, Serializable {
 	 * @param clazz the {@link Filter} class to determine the sort order
 	 * @return the sort order or null if not defined
 	 */
-	private Integer getOrder(Class<?> clazz) {
+	Integer getOrder(Class<?> clazz) {
 		while (clazz != null) {
 			Integer result = filterToOrder.get(clazz.getName());
 			if (result != null) {

+ 64 - 17
config/src/main/java/org/springframework/security/config/annotation/web/builders/HttpSecurity.java

@@ -16,6 +16,8 @@
 package org.springframework.security.config.annotation.web.builders;
 
 import org.springframework.context.ApplicationContext;
+import org.springframework.core.OrderComparator;
+import org.springframework.core.Ordered;
 import org.springframework.http.HttpMethod;
 import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.AuthenticationProvider;
@@ -78,10 +80,16 @@ import org.springframework.web.cors.CorsConfiguration;
 import org.springframework.web.filter.CorsFilter;
 import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
+
 import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
 
 /**
@@ -125,9 +133,9 @@ public final class HttpSecurity extends
 		implements SecurityBuilder<DefaultSecurityFilterChain>,
 		HttpSecurityBuilder<HttpSecurity> {
 	private final RequestMatcherConfigurer requestMatcherConfigurer;
-	private List<Filter> filters = new ArrayList<>();
+	private List<OrderedFilter> filters = new ArrayList<>();
 	private RequestMatcher requestMatcher = AnyRequestMatcher.INSTANCE;
-	private FilterComparator comparator = new FilterComparator();
+	private FilterOrderRegistration filterOrders = new FilterOrderRegistration();
 
 	/**
 	 * Creates a new instance
@@ -2528,8 +2536,12 @@ public final class HttpSecurity extends
 
 	@Override
 	protected DefaultSecurityFilterChain performBuild() {
-		filters.sort(comparator);
-		return new DefaultSecurityFilterChain(requestMatcher, filters);
+		this.filters.sort(OrderComparator.INSTANCE);
+		List<Filter> sortedFilters = new ArrayList<>(this.filters.size());
+		for (Filter filter : this.filters) {
+			sortedFilters.add(((OrderedFilter) filter).filter);
+		}
+		return new DefaultSecurityFilterChain(this.requestMatcher, sortedFilters);
 	}
 
 	/*
@@ -2570,8 +2582,7 @@ public final class HttpSecurity extends
 	 * .servlet.Filter, java.lang.Class)
 	 */
 	public HttpSecurity addFilterAfter(Filter filter, Class<? extends Filter> afterFilter) {
-		comparator.registerAfter(filter.getClass(), afterFilter);
-		return addFilter(filter);
+		return addFilterAtOffsetOf(filter, 1, afterFilter);
 	}
 
 	/*
@@ -2583,8 +2594,13 @@ public final class HttpSecurity extends
 	 */
 	public HttpSecurity addFilterBefore(Filter filter,
 			Class<? extends Filter> beforeFilter) {
-		comparator.registerBefore(filter.getClass(), beforeFilter);
-		return addFilter(filter);
+		return addFilterAtOffsetOf(filter, -1, beforeFilter);
+	}
+
+	private HttpSecurity addFilterAtOffsetOf(Filter filter, int offset, Class<? extends Filter> registeredFilter) {
+		int order = this.filterOrders.getOrder(registeredFilter) + offset;
+		this.filters.add(new OrderedFilter(filter, order));
+		return this;
 	}
 
 	/*
@@ -2595,14 +2611,12 @@ public final class HttpSecurity extends
 	 * servlet.Filter)
 	 */
 	public HttpSecurity addFilter(Filter filter) {
-		Class<? extends Filter> filterClass = filter.getClass();
-		if (!comparator.isRegistered(filterClass)) {
-			throw new IllegalArgumentException(
-					"The Filter class "
-							+ filterClass.getName()
-							+ " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead.");
+		Integer order = this.filterOrders.getOrder(filter.getClass());
+		if (order == null) {
+			throw new IllegalArgumentException("The Filter class " + filter.getClass().getName()
+					+ " does not have a registered order and cannot be added without a specified order. Consider using addFilterBefore or addFilterAfter instead.");
 		}
-		this.filters.add(filter);
+		this.filters.add(new OrderedFilter(filter, order));
 		return this;
 	}
 
@@ -2626,8 +2640,7 @@ public final class HttpSecurity extends
 	 * @return the {@link HttpSecurity} for further customizations
 	 */
 	public HttpSecurity addFilterAt(Filter filter, Class<? extends Filter> atFilter) {
-		this.comparator.registerAt(filter.getClass(), atFilter);
-		return addFilter(filter);
+		return addFilterAtOffsetOf(filter, 0, atFilter);
 	}
 
 	/**
@@ -3023,4 +3036,38 @@ public final class HttpSecurity extends
 		}
 		return apply(configurer);
 	}
+
+	/*
+	 * A Filter that implements Ordered to be sorted. After sorting occurs, the original
+	 * filter is what is used by FilterChainProxy
+	 */
+	private static final class OrderedFilter implements Ordered, Filter {
+
+		private final Filter filter;
+
+		private final int order;
+
+		private OrderedFilter(Filter filter, int order) {
+			this.filter = filter;
+			this.order = order;
+		}
+
+		@Override
+		public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
+				throws IOException, ServletException {
+			this.filter.doFilter(servletRequest, servletResponse, filterChain);
+		}
+
+		@Override
+		public int getOrder() {
+			return this.order;
+		}
+
+		@Override
+		public String toString() {
+			return "OrderedFilter{" + "filter=" + this.filter + ", order=" + this.order + '}';
+		}
+
+	}
+
 }

+ 132 - 0
config/src/test/java/org/springframework/security/config/annotation/web/builders/HttpSecurityAddFilterTest.java

@@ -0,0 +1,132 @@
+/*
+ * Copyright 2002-2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.web.builders;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import javax.servlet.Filter;
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+
+import org.assertj.core.api.ListAssert;
+import org.junit.Rule;
+import org.junit.Test;
+
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.config.test.SpringTestRule;
+import org.springframework.security.web.FilterChainProxy;
+import org.springframework.security.web.access.ExceptionTranslationFilter;
+import org.springframework.security.web.access.channel.ChannelProcessingFilter;
+import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
+import org.springframework.security.web.context.request.async.WebAsyncManagerIntegrationFilter;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+public class HttpSecurityAddFilterTest {
+
+	@Rule
+	public final SpringTestRule spring = new SpringTestRule();
+
+	@Test
+	public void addFilterAfterWhenSameFilterDifferentPlacesThenOrderCorrect() {
+		this.spring.register(MyFilterMultipleAfterConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(WebAsyncManagerIntegrationFilter.class, MyFilter.class,
+				ExceptionTranslationFilter.class, MyFilter.class);
+	}
+
+	@Test
+	public void addFilterBeforeWhenSameFilterDifferentPlacesThenOrderCorrect() {
+		this.spring.register(MyFilterMultipleBeforeConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
+				ExceptionTranslationFilter.class);
+	}
+
+	@Test
+	public void addFilterAtWhenSameFilterDifferentPlacesThenOrderCorrect() {
+		this.spring.register(MyFilterMultipleAtConfig.class).autowire();
+
+		assertThatFilters().containsSubsequence(MyFilter.class, WebAsyncManagerIntegrationFilter.class, MyFilter.class,
+				ExceptionTranslationFilter.class);
+	}
+
+	private ListAssert<Class<?>> assertThatFilters() {
+		FilterChainProxy filterChain = this.spring.getContext().getBean(FilterChainProxy.class);
+		List<Class<?>> filters = filterChain.getFilters("/").stream().map(Object::getClass)
+				.collect(Collectors.toList());
+		return assertThat(filters);
+	}
+
+	public static class MyFilter implements Filter {
+
+		@Override
+		public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
+				throws IOException, ServletException {
+			filterChain.doFilter(servletRequest, servletResponse);
+		}
+
+	}
+
+	@EnableWebSecurity
+	static class MyFilterMultipleAfterConfig extends WebSecurityConfigurerAdapter {
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterAfter(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
+					.addFilterAfter(new MyFilter(), ExceptionTranslationFilter.class);
+			// @formatter:on
+		}
+
+	}
+
+	@EnableWebSecurity
+	static class MyFilterMultipleBeforeConfig extends WebSecurityConfigurerAdapter {
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterBefore(new MyFilter(), WebAsyncManagerIntegrationFilter.class)
+					.addFilterBefore(new MyFilter(), ExceptionTranslationFilter.class);
+			// @formatter:on
+		}
+
+	}
+
+	@EnableWebSecurity
+	static class MyFilterMultipleAtConfig extends WebSecurityConfigurerAdapter {
+
+		@Override
+		protected void configure(HttpSecurity http) throws Exception {
+			// @formatter:off
+			http
+					.addFilterAt(new MyFilter(), ChannelProcessingFilter.class)
+					.addFilterAt(new MyFilter(), UsernamePasswordAuthenticationFilter.class);
+			// @formatter:on
+		}
+
+	}
+
+}