Răsfoiți Sursa

Automatically add CsrfServerLogoutHandler if csrf enabled

The configuration DSL should automatically add CsrfServerLogoutHandler if csrf is enabled

Fixes gh-5337
Eric Deandrea 7 ani în urmă
părinte
comite
b060ec050a

+ 28 - 2
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -27,6 +27,7 @@ import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.function.Function;
 
 import reactor.core.publisher.Mono;
@@ -92,7 +93,9 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati
 import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler;
 import org.springframework.security.web.server.authentication.ServerFormLoginAuthenticationConverter;
 import org.springframework.security.web.server.authentication.ServerHttpBasicAuthenticationConverter;
+import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
+import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
 import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler;
 import org.springframework.security.web.server.authorization.AuthorizationContext;
@@ -106,8 +109,10 @@ import org.springframework.security.web.server.context.ReactorContextWebFilter;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
+import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
 import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
+import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository;
 import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.ContentSecurityPolicyServerHttpHeadersWriter;
@@ -1538,6 +1543,7 @@ public class ServerHttpSecurity {
 	 */
 	public class CsrfSpec {
 		private CsrfWebFilter filter = new CsrfWebFilter();
+		private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
 
 		private boolean specifiedRequireCsrfProtectionMatcher;
 
@@ -1563,7 +1569,7 @@ public class ServerHttpSecurity {
 		 */
 		public CsrfSpec csrfTokenRepository(
 			ServerCsrfTokenRepository csrfTokenRepository) {
-			this.filter.setCsrfTokenRepository(csrfTokenRepository);
+			this.csrfTokenRepository = csrfTokenRepository;
 			return this;
 		}
 
@@ -1600,6 +1606,10 @@ public class ServerHttpSecurity {
 		}
 
 		protected void configure(ServerHttpSecurity http) {
+			Optional.ofNullable(this.csrfTokenRepository).ifPresent(serverCsrfTokenRepository -> {
+				this.filter.setCsrfTokenRepository(serverCsrfTokenRepository);
+				http.logout().logoutHandler(new CsrfServerLogoutHandler(serverCsrfTokenRepository));
+			});
 			http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF);
 		}
 
@@ -2332,6 +2342,7 @@ public class ServerHttpSecurity {
 	 */
 	public final class LogoutSpec {
 		private LogoutWebFilter logoutWebFilter = new LogoutWebFilter();
+		private List<ServerLogoutHandler> logoutHandlers = new ArrayList<>(Arrays.asList(new SecurityContextServerLogoutHandler()));
 
 		/**
 		 * Configures the logout handler. Default is {@code SecurityContextServerLogoutHandler}
@@ -2339,7 +2350,10 @@ public class ServerHttpSecurity {
 		 * @return the {@link LogoutSpec} to configure
 		 */
 		public LogoutSpec logoutHandler(ServerLogoutHandler logoutHandler) {
-			this.logoutWebFilter.setLogoutHandler(logoutHandler);
+			if (logoutHandler != null) {
+				this.logoutHandlers.add(logoutHandler);
+			}
+
 			return this;
 		}
 
@@ -2387,7 +2401,19 @@ public class ServerHttpSecurity {
 			return and();
 		}
 
+		private Optional<ServerLogoutHandler> createLogoutHandler() {
+			if (this.logoutHandlers.isEmpty()) {
+				return Optional.empty();
+			}
+			else if (this.logoutHandlers.size() == 1) {
+				return Optional.of(this.logoutHandlers.get(0));
+			}
+
+			return Optional.of(new DelegatingServerLogoutHandler(this.logoutHandlers));
+		}
+
 		protected void configure(ServerHttpSecurity http) {
+			createLogoutHandler().ifPresent(this.logoutWebFilter::setLogoutHandler);
 			http.addFilterAt(this.logoutWebFilter, SecurityWebFiltersOrder.LOGOUT);
 		}
 

+ 71 - 7
config/src/test/java/org/springframework/security/config/web/server/ServerHttpSecurityTests.java

@@ -16,12 +16,27 @@
 
 package org.springframework.security.config.web.server;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.when;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
 import org.apache.http.HttpHeaders;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
+
+import reactor.core.publisher.Mono;
+import reactor.test.publisher.TestPublisher;
+
 import org.springframework.security.authentication.ReactiveAuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder;
@@ -29,21 +44,23 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
+import org.springframework.security.web.server.authentication.logout.DelegatingServerLogoutHandler;
+import org.springframework.security.web.server.authentication.logout.LogoutWebFilter;
+import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler;
+import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
+import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler;
+import org.springframework.security.web.server.csrf.CsrfWebFilter;
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
+import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.reactive.server.EntityExchangeResult;
 import org.springframework.test.web.reactive.server.FluxExchangeResult;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.server.ServerWebExchange;
-import reactor.core.publisher.Mono;
-import reactor.test.publisher.TestPublisher;
-
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.mockito.BDDMockito.given;
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.when;
+import org.springframework.web.server.WebFilter;
 
 /**
  * @author Rob Winch
@@ -55,6 +72,8 @@ public class ServerHttpSecurityTests {
 	private ServerSecurityContextRepository contextRepository;
 	@Mock
 	private ReactiveAuthenticationManager authenticationManager;
+	@Mock
+	private ServerCsrfTokenRepository csrfTokenRepository;
 
 	private ServerHttpSecurity http;
 
@@ -134,6 +153,51 @@ public class ServerHttpSecurityTests {
 				.expectBody(String.class).isEqualTo("/foo/bar");
 	}
 
+	@Test
+	public void csrfServerLogoutHandlerNotAppliedIfCsrfIsntEnabled() {
+		SecurityWebFilterChain securityWebFilterChain = this.http.csrf().disable().build();
+
+		assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
+				.isNotPresent();
+
+		Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
+				.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
+
+		assertThat(logoutHandler)
+				.get()
+				.isExactlyInstanceOf(SecurityContextServerLogoutHandler.class);
+	}
+
+	@Test
+	public void csrfServerLogoutHandlerAppliedIfCsrfIsEnabled() {
+		SecurityWebFilterChain securityWebFilterChain = this.http.csrf().csrfTokenRepository(this.csrfTokenRepository).and().build();
+
+		assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class))
+				.get()
+				.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository"))
+				.isEqualTo(this.csrfTokenRepository);
+
+		Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class)
+				.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler"));
+
+		assertThat(logoutHandler)
+				.get()
+				.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
+				.extracting(delegatingLogoutHandler ->
+						((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream()
+								.map(ServerLogoutHandler::getClass)
+								.collect(Collectors.toList()))
+				.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class));
+	}
+
+	private <T extends WebFilter> Optional<T> getWebFilter(SecurityWebFilterChain filterChain, Class<T> filterClass) {
+		return (Optional<T>) filterChain.getWebFilters()
+				.filter(Objects::nonNull)
+				.filter(filter -> filter.getClass().isAssignableFrom(filterClass))
+				.singleOrEmpty()
+				.blockOptional();
+	}
+
 	private WebTestClient buildClient() {
 		WebFilterChainProxy springSecurityFilterChain = new WebFilterChainProxy(
 			this.http.build());

+ 13 - 9
web/src/main/java/org/springframework/security/web/server/authentication/logout/DelegatingServerLogoutHandler.java

@@ -18,16 +18,17 @@ package org.springframework.security.web.server.authentication.logout;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
+import java.util.Objects;
 import java.util.stream.Collectors;
-import java.util.stream.Stream;
+
+import reactor.core.publisher.Mono;
 
 import org.springframework.security.core.Authentication;
 import org.springframework.security.web.server.WebFilterExchange;
 import org.springframework.util.Assert;
 
-import reactor.core.publisher.Mono;
-
 /**
  * Delegates to a collection of {@link ServerLogoutHandler} implementations.
  *
@@ -35,21 +36,24 @@ import reactor.core.publisher.Mono;
  * @since 5.1
  */
 public class DelegatingServerLogoutHandler implements ServerLogoutHandler {
-	private final List<ServerLogoutHandler> delegates;
+	private final List<ServerLogoutHandler> delegates = new ArrayList<>();
 
 	public DelegatingServerLogoutHandler(ServerLogoutHandler... delegates) {
 		Assert.notEmpty(delegates, "delegates cannot be null or empty");
-		this.delegates = Arrays.asList(delegates);
+		this.delegates.addAll(Arrays.asList(delegates));
 	}
 
-	public DelegatingServerLogoutHandler(List<ServerLogoutHandler> delegates) {
+	public DelegatingServerLogoutHandler(Collection<ServerLogoutHandler> delegates) {
 		Assert.notEmpty(delegates, "delegates cannot be null or empty");
-		this.delegates = new ArrayList<>(delegates);
+		this.delegates.addAll(delegates);
 	}
 
 	@Override
 	public Mono<Void> logout(WebFilterExchange exchange, Authentication authentication) {
-		Stream<Mono<Void>> results = this.delegates.stream().map(delegate -> delegate.logout(exchange, authentication));
-		return Mono.when(results.collect(Collectors.toList()));
+		return Mono.when(this.delegates.stream()
+				.filter(Objects::nonNull)
+				.map(delegate -> delegate.logout(exchange, authentication))
+				.collect(Collectors.toList())
+		);
 	}
 }

+ 7 - 3
web/src/main/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilter.java

@@ -16,17 +16,17 @@
 
 package org.springframework.security.web.server.authentication.logout;
 
-import org.springframework.http.HttpMethod;
-import org.springframework.security.core.context.ReactiveSecurityContextHolder;
-import org.springframework.util.Assert;
 import reactor.core.publisher.Mono;
 
+import org.springframework.http.HttpMethod;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.web.server.WebFilterExchange;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
+import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
@@ -85,6 +85,10 @@ public class LogoutWebFilter implements WebFilter {
 		this.logoutSuccessHandler = logoutSuccessHandler;
 	}
 
+	/**
+	 * Sets the {@link ServerLogoutHandler}. The default is {@link SecurityContextServerLogoutHandler}.
+	 * @param logoutHandler The handler to use
+	 */
 	public void setLogoutHandler(ServerLogoutHandler logoutHandler) {
 		Assert.notNull(logoutHandler, "logoutHandler must not be null");
 		this.logoutHandler = logoutHandler;

+ 86 - 0
web/src/test/java/org/springframework/security/web/server/authentication/logout/LogoutWebFilterTests.java

@@ -0,0 +1,86 @@
+/*
+ * Copyright 2002-2018 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.authentication.logout;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.stream.Collectors;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+
+import org.springframework.test.util.ReflectionTestUtils;
+
+/**
+ * @author Eric Deandrea
+ * @since  5.1
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class LogoutWebFilterTests {
+	@Mock
+	private ServerLogoutHandler handler1;
+
+	@Mock
+	private ServerLogoutHandler handler2;
+
+	@Mock
+	private ServerLogoutHandler handler3;
+
+	private LogoutWebFilter logoutWebFilter = new LogoutWebFilter();
+
+	@Test
+	public void defaultLogoutHandler() {
+		assertThat(getLogoutHandler())
+				.isNotNull()
+				.isExactlyInstanceOf(SecurityContextServerLogoutHandler.class);
+	}
+
+	@Test
+	public void singleLogoutHandler() {
+		this.logoutWebFilter.setLogoutHandler(this.handler1);
+		this.logoutWebFilter.setLogoutHandler(this.handler2);
+
+		assertThat(getLogoutHandler())
+				.isNotNull()
+				.isInstanceOf(ServerLogoutHandler.class)
+				.isNotInstanceOf(SecurityContextServerLogoutHandler.class)
+				.extracting(ServerLogoutHandler::getClass)
+				.isEqualTo(this.handler2.getClass());
+	}
+
+	@Test
+	public void multipleLogoutHandlers() {
+		this.logoutWebFilter.setLogoutHandler(new DelegatingServerLogoutHandler(this.handler1, this.handler2, this.handler3));
+
+		assertThat(getLogoutHandler())
+				.isNotNull()
+				.isExactlyInstanceOf(DelegatingServerLogoutHandler.class)
+				.extracting(delegatingLogoutHandler -> ((Collection<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates"))
+						.stream()
+						.map(ServerLogoutHandler::getClass)
+						.collect(Collectors.toList()))
+				.isEqualTo(Arrays.asList(this.handler1.getClass(), this.handler2.getClass(), this.handler3.getClass()));
+	}
+
+	private ServerLogoutHandler getLogoutHandler() {
+		return (ServerLogoutHandler) ReflectionTestUtils.getField(this.logoutWebFilter, LogoutWebFilter.class, "logoutHandler");
+	}
+}