Procházet zdrojové kódy

Add Reactive HttpSecurity.addWebFilterAt

Fixes gh-4542
Rob Winch před 8 roky
rodič
revize
b3bd5ba946

+ 66 - 23
config/src/main/java/org/springframework/security/config/web/server/HttpSecurity.java

@@ -21,12 +21,15 @@ import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
 
+import org.springframework.core.Ordered;
+import org.springframework.core.annotation.AnnotationAwareOrderComparator;
 import org.springframework.http.MediaType;
 import org.springframework.security.web.server.DelegatingAuthenticationEntryPoint;
 import org.springframework.security.web.server.authentication.AuthenticationFailureHandler;
 import org.springframework.security.web.server.authentication.logout.LogoutWebFiter;
 import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher;
-import org.springframework.security.web.util.matcher.MediaTypeRequestMatcher;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilterChain;
 import reactor.core.publisher.Mono;
 
 import org.springframework.http.HttpMethod;
@@ -94,6 +97,8 @@ public class HttpSecurity {
 
 	private List<DelegateEntry> defaultEntryPoints = new ArrayList<>();
 
+	private List<WebFilter> webFilters = new ArrayList<>();
+
 	/**
 	 * The ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance.
 	 *
@@ -106,6 +111,11 @@ public class HttpSecurity {
 		return this;
 	}
 
+	public HttpSecurity addFilterAt(WebFilter webFilter, SecurityWebFiltersOrder order) {
+		this.webFilters.add(new OrderedWebFilter(webFilter, order.getOrder()));
+		return this;
+	}
+
 	/**
 	 * Gets the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance.
 	 * @return the ServerExchangeMatcher that determines which requests apply to this HttpSecurity instance.
@@ -154,20 +164,19 @@ public class HttpSecurity {
 	}
 
 	public SecurityWebFilterChain build() {
-		List<WebFilter> filters = new ArrayList<>();
 		if(this.headers != null) {
-			filters.add(this.headers.build());
+			this.webFilters.add(this.headers.build());
 		}
-		SecurityContextRepositoryWebFilter securityContextRepositoryWebFilter = securityContextRepositoryWebFilter();
+		WebFilter securityContextRepositoryWebFilter = securityContextRepositoryWebFilter();
 		if(securityContextRepositoryWebFilter != null) {
-			filters.add(securityContextRepositoryWebFilter);
+			this.webFilters.add(securityContextRepositoryWebFilter);
 		}
 		if(this.httpBasic != null) {
 			this.httpBasic.authenticationManager(this.authenticationManager);
 			if(this.securityContextRepository != null) {
 				this.httpBasic.securityContextRepository(this.securityContextRepository);
 			}
-			filters.add(this.httpBasic.build());
+			this.webFilters.add(this.httpBasic.build());
 		}
 		if(this.formLogin != null) {
 			this.formLogin.authenticationManager(this.authenticationManager);
@@ -175,22 +184,24 @@ public class HttpSecurity {
 				this.formLogin.securityContextRepository(this.securityContextRepository);
 			}
 			if(this.formLogin.authenticationEntryPoint == null) {
-				filters.add(new LoginPageGeneratingWebFilter());
+				this.webFilters.add(new OrderedWebFilter(new LoginPageGeneratingWebFilter(), SecurityWebFiltersOrder.LOGIN_PAGE_GENERATING.getOrder()));
 			}
-			filters.add(this.formLogin.build());
-			filters.add(new LogoutWebFiter());
+			this.webFilters.add(this.formLogin.build());
+			this.webFilters
+				.add(new OrderedWebFilter(new LogoutWebFiter(), SecurityWebFiltersOrder.LOGOUT.getOrder()));
 		}
-		filters.add(new AuthenticationReactorContextFilter());
+		this.webFilters.add(new OrderedWebFilter(new AuthenticationReactorContextFilter(), SecurityWebFiltersOrder.AUTHENTICATION_CONTEXT.getOrder()));
 		if(this.authorizeExchangeBuilder != null) {
 			AuthenticationEntryPoint authenticationEntryPoint = getAuthenticationEntryPoint();
 			ExceptionTranslationWebFilter exceptionTranslationWebFilter = new ExceptionTranslationWebFilter();
 			if(authenticationEntryPoint != null) {
 				exceptionTranslationWebFilter.setAuthenticationEntryPoint(authenticationEntryPoint);
 			}
-			filters.add(exceptionTranslationWebFilter);
-			filters.add(this.authorizeExchangeBuilder.build());
+			this.webFilters.add(new OrderedWebFilter(exceptionTranslationWebFilter, SecurityWebFiltersOrder.EXCEPTION_TRANSLATION.getOrder()));
+			this.webFilters.add(this.authorizeExchangeBuilder.build());
 		}
-		return new MatcherSecurityWebFilterChain(getSecurityMatcher(), filters);
+		AnnotationAwareOrderComparator.sort(this.webFilters);
+		return new MatcherSecurityWebFilterChain(getSecurityMatcher(), this.webFilters);
 	}
 
 	private AuthenticationEntryPoint getAuthenticationEntryPoint() {
@@ -209,10 +220,13 @@ public class HttpSecurity {
 		return new HttpSecurity();
 	}
 
-	private SecurityContextRepositoryWebFilter securityContextRepositoryWebFilter() {
+	private WebFilter securityContextRepositoryWebFilter() {
 		SecurityContextRepository repository = this.securityContextRepository;
-		return repository == null ? null :
-			new SecurityContextRepositoryWebFilter(repository);
+		if(repository == null) {
+			return null;
+		}
+		WebFilter result = new SecurityContextRepositoryWebFilter(repository);
+		return new OrderedWebFilter(result, SecurityWebFiltersOrder.SECURITY_CONTEXT_REPOSITORY.getOrder());
 	}
 
 	private HttpSecurity() {}
@@ -253,7 +267,8 @@ public class HttpSecurity {
 			if(this.matcher != null) {
 				throw new IllegalStateException("The matcher " + this.matcher + " does not have an access rule defined");
 			}
-			return new AuthorizationWebFilter(this.managerBldr.build());
+			AuthorizationWebFilter result = new AuthorizationWebFilter(this.managerBldr.build());
+			return new OrderedWebFilter(result, SecurityWebFiltersOrder.AUTHORIZATION.getOrder());
 		}
 
 		public final class Access {
@@ -318,7 +333,7 @@ public class HttpSecurity {
 			return HttpSecurity.this;
 		}
 
-		protected AuthenticationWebFilter build() {
+		protected WebFilter build() {
 			MediaTypeServerWebExchangeMatcher restMatcher = new MediaTypeServerWebExchangeMatcher(
 				MediaType.APPLICATION_ATOM_XML,
 				MediaType.APPLICATION_FORM_URLENCODED, MediaType.APPLICATION_JSON,
@@ -333,7 +348,7 @@ public class HttpSecurity {
 			if(this.securityContextRepository != null) {
 				authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
 			}
-			return authenticationFilter;
+			return new OrderedWebFilter(authenticationFilter, SecurityWebFiltersOrder.HTTP_BASIC.getOrder());
 		}
 
 		private HttpBasicBuilder() {}
@@ -395,7 +410,7 @@ public class HttpSecurity {
 			return HttpSecurity.this;
 		}
 
-		protected AuthenticationWebFilter build() {
+		protected WebFilter build() {
 			if(this.authenticationEntryPoint == null) {
 				loginPage("/login");
 			}
@@ -410,7 +425,7 @@ public class HttpSecurity {
 			authenticationFilter.setAuthenticationConverter(new FormLoginAuthenticationConverter());
 			authenticationFilter.setAuthenticationSuccessHandler(new RedirectAuthenticationSuccessHandler("/"));
 			authenticationFilter.setSecurityContextRepository(this.securityContextRepository);
-			return authenticationFilter;
+			return new OrderedWebFilter(authenticationFilter, SecurityWebFiltersOrder.FORM_LOGIN.getOrder());
 		}
 
 		private FormLoginBuilder() {
@@ -454,9 +469,10 @@ public class HttpSecurity {
 			return new HstsSpec();
 		}
 
-		protected HttpHeaderWriterWebFilter build() {
+		protected WebFilter build() {
 			HttpHeadersWriter writer = new CompositeHttpHeadersWriter(this.writers);
-			return new HttpHeaderWriterWebFilter(writer);
+			HttpHeaderWriterWebFilter result = new HttpHeaderWriterWebFilter(writer);
+			return new OrderedWebFilter(result, SecurityWebFiltersOrder.HTTP_HEADERS_WRITER.getOrder());
 		}
 
 		public XssProtectionSpec xssProtection() {
@@ -520,4 +536,31 @@ public class HttpSecurity {
 					this.frameOptions, this.xss));
 		}
 	}
+
+	private static class OrderedWebFilter implements WebFilter, Ordered {
+		private final WebFilter webFilter;
+		private final int order;
+
+		public OrderedWebFilter(WebFilter webFilter, int order) {
+			this.webFilter = webFilter;
+			this.order = order;
+		}
+
+		@Override
+		public Mono<Void> filter(ServerWebExchange exchange,
+			WebFilterChain chain) {
+			return this.webFilter.filter(exchange, chain);
+		}
+
+		@Override
+		public int getOrder() {
+			return this.order;
+		}
+
+		@Override
+		public String toString() {
+			return "OrderedWebFilter{" + "webFilter=" + this.webFilter + ", order=" + this.order
+				+ '}';
+		}
+	}
 }

+ 58 - 0
config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.config.web.server;
+
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public enum SecurityWebFiltersOrder {
+	FIRST(Integer.MIN_VALUE),
+	HTTP_HEADERS_WRITER,
+	SECURITY_CONTEXT_REPOSITORY,
+	LOGIN_PAGE_GENERATING,
+	/**
+	 * Instance of AuthenticationWebFilter
+	 */
+	HTTP_BASIC,
+	/**
+	 * Instance of AuthenticationWebFilter
+	 */
+	FORM_LOGIN,
+	AUTHENTICATION,
+	LOGOUT,
+	AUTHENTICATION_CONTEXT,
+	EXCEPTION_TRANSLATION,
+	AUTHORIZATION,
+	LAST(Integer.MAX_VALUE);
+
+	private static final int INTERVAL = 100;
+
+	private final int order;
+
+	private SecurityWebFiltersOrder() {
+		this.order = ordinal() * INTERVAL;
+	}
+
+	private SecurityWebFiltersOrder(int order) {
+		this.order = order;
+	}
+
+	public int getOrder() {
+		return this.order;
+	}
+}