浏览代码

Add WebFlux CSRF Protection

Fixes gh-4734
Rob Winch 7 年之前
父节点
当前提交
8da2c7f657
共有 16 个文件被更改,包括 943 次插入12 次删除
  1. 4 0
      config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java
  2. 62 0
      config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java
  3. 2 0
      config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java
  4. 2 0
      config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java
  5. 33 0
      test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java
  6. 37 1
      test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java
  7. 1 0
      web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java
  8. 33 0
      web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java
  9. 48 0
      web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java
  10. 140 0
      web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java
  11. 77 0
      web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java
  12. 58 0
      web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java
  13. 122 0
      web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java
  14. 27 11
      web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java
  15. 185 0
      web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java
  16. 112 0
      web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java

+ 4 - 0
config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java

@@ -23,6 +23,10 @@ package org.springframework.security.config.web.server;
 public enum SecurityWebFiltersOrder {
 public enum SecurityWebFiltersOrder {
 	FIRST(Integer.MIN_VALUE),
 	FIRST(Integer.MIN_VALUE),
 	HTTP_HEADERS_WRITER,
 	HTTP_HEADERS_WRITER,
+	/**
+	 * {@link org.springframework.security.web.server.csrf.CsrfWebFilter}
+	 */
+	CSRF,
 	/**
 	/**
 	 * Instance of AuthenticationWebFilter
 	 * Instance of AuthenticationWebFilter
 	 */
 	 */

+ 62 - 0
config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java

@@ -44,11 +44,14 @@ import org.springframework.security.web.server.authorization.AuthorizationContex
 import org.springframework.security.web.server.authorization.AuthorizationWebFilter;
 import org.springframework.security.web.server.authorization.AuthorizationWebFilter;
 import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager;
 import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager;
 import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter;
 import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter;
+import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.ReactorContextWebFilter;
 import org.springframework.security.web.server.context.ReactorContextWebFilter;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.context.ServerSecurityContextRepository;
 import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
 import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
 import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository;
+import org.springframework.security.web.server.csrf.CsrfWebFilter;
+import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository;
 import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.CacheControlServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.CompositeServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter;
 import org.springframework.security.web.server.header.ContentTypeOptionsServerHttpHeadersWriter;
@@ -90,6 +93,8 @@ public class ServerHttpSecurity {
 
 
 	private HeaderBuilder headers;
 	private HeaderBuilder headers;
 
 
+	private CsrfBuilder csrf = new CsrfBuilder();
+
 	private HttpBasicBuilder httpBasic;
 	private HttpBasicBuilder httpBasic;
 
 
 	private FormLoginBuilder formLogin;
 	private FormLoginBuilder formLogin;
@@ -139,6 +144,13 @@ public class ServerHttpSecurity {
 		return this;
 		return this;
 	}
 	}
 
 
+	public CsrfBuilder csrf() {
+		if(this.csrf == null) {
+			this.csrf = new CsrfBuilder();
+		}
+		return this.csrf;
+	}
+
 	public HttpBasicBuilder httpBasic() {
 	public HttpBasicBuilder httpBasic() {
 		if(this.httpBasic == null) {
 		if(this.httpBasic == null) {
 			this.httpBasic = new HttpBasicBuilder();
 			this.httpBasic = new HttpBasicBuilder();
@@ -191,6 +203,9 @@ public class ServerHttpSecurity {
 		if(securityContextRepositoryWebFilter != null) {
 		if(securityContextRepositoryWebFilter != null) {
 			this.webFilters.add(securityContextRepositoryWebFilter);
 			this.webFilters.add(securityContextRepositoryWebFilter);
 		}
 		}
+		if(this.csrf != null) {
+			this.csrf.configure(this);
+		}
 		if(this.httpBasic != null) {
 		if(this.httpBasic != null) {
 			this.httpBasic.authenticationManager(this.authenticationManager);
 			this.httpBasic.authenticationManager(this.authenticationManager);
 			if(this.serverSecurityContextRepository != null) {
 			if(this.serverSecurityContextRepository != null) {
@@ -340,6 +355,53 @@ public class ServerHttpSecurity {
 		}
 		}
 	}
 	}
 
 
+	/**
+	 * @author Rob Winch
+	 * @since 5.0
+	 */
+	public class CsrfBuilder {
+		private CsrfWebFilter filter = new CsrfWebFilter();
+
+		public CsrfBuilder serverAccessDeniedHandler(
+			ServerAccessDeniedHandler serverAccessDeniedHandler) {
+			this.filter.setServerAccessDeniedHandler(serverAccessDeniedHandler);
+			return this;
+		}
+
+		public CsrfBuilder csrfTokenAttributeName(String csrfTokenAttributeName) {
+			Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null");
+			this.filter.setCsrfTokenAttributeName(csrfTokenAttributeName);
+			return this;
+		}
+
+		public CsrfBuilder serverCsrfTokenRepository(
+			ServerCsrfTokenRepository serverCsrfTokenRepository) {
+			this.filter.setServerCsrfTokenRepository(serverCsrfTokenRepository);
+			return this;
+		}
+
+		public CsrfBuilder requireCsrfProtectionMatcher(
+			ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
+			this.filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
+			return this;
+		}
+
+		public ServerHttpSecurity and() {
+			return ServerHttpSecurity.this;
+		}
+
+		public ServerHttpSecurity disable() {
+			ServerHttpSecurity.this.csrf = null;
+			return ServerHttpSecurity.this;
+		}
+
+		protected void configure(ServerHttpSecurity http) {
+			http.addFilterAt(this.filter, SecurityWebFiltersOrder.CSRF);
+		}
+
+		private CsrfBuilder() {}
+	}
+
 	/**
 	/**
 	 * @author Rob Winch
 	 * @author Rob Winch
 	 * @since 5.0
 	 * @since 5.0

+ 2 - 0
config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java

@@ -55,6 +55,7 @@ import java.nio.charset.StandardCharsets;
 import java.security.Principal;
 import java.security.Principal;
 
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.csrf;
 import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
 import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials;
 import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.basicAuthentication;
 import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.basicAuthentication;
 
 
@@ -213,6 +214,7 @@ public class EnableWebFluxSecurityTests {
 		data.add("username", "user");
 		data.add("username", "user");
 		data.add("password", "password");
 		data.add("password", "password");
 		client
 		client
+			.mutateWith(csrf())
 			.post()
 			.post()
 			.uri("/login")
 			.uri("/login")
 			.body(BodyInserters.fromFormData(data))
 			.body(BodyInserters.fromFormData(data))

+ 2 - 0
config/src/test/java/org/springframework/security/config/web/server/AuthorizeExchangeBuilderTests.java

@@ -32,6 +32,7 @@ public class AuthorizeExchangeBuilderTests {
 	@Test
 	@Test
 	public void antMatchersWhenMethodAndPatternsThenDiscriminatesByMethod() {
 	public void antMatchersWhenMethodAndPatternsThenDiscriminatesByMethod() {
 		this.http
 		this.http
+			.csrf().disable()
 			.authorizeExchange()
 			.authorizeExchange()
 				.pathMatchers(HttpMethod.POST, "/a", "/b").denyAll()
 				.pathMatchers(HttpMethod.POST, "/a", "/b").denyAll()
 				.anyExchange().permitAll();
 				.anyExchange().permitAll();
@@ -63,6 +64,7 @@ public class AuthorizeExchangeBuilderTests {
 	@Test
 	@Test
 	public void antMatchersWhenPatternsThenAnyMethod() {
 	public void antMatchersWhenPatternsThenAnyMethod() {
 		this.http
 		this.http
+			.csrf().disable()
 			.authorizeExchange()
 			.authorizeExchange()
 				.pathMatchers("/a", "/b").denyAll()
 				.pathMatchers("/a", "/b").denyAll()
 				.anyExchange().permitAll();
 				.anyExchange().permitAll();

+ 33 - 0
test/src/main/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurers.java

@@ -26,6 +26,10 @@ import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.web.server.csrf.CsrfToken;
+import org.springframework.security.web.server.csrf.CsrfWebFilter;
+import org.springframework.security.web.server.csrf.WebSessionServerCsrfTokenRepository;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.test.web.reactive.server.MockServerConfigurer;
 import org.springframework.test.web.reactive.server.MockServerConfigurer;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.test.web.reactive.server.WebTestClientConfigurer;
 import org.springframework.test.web.reactive.server.WebTestClientConfigurer;
@@ -107,6 +111,35 @@ public class SecurityMockServerConfigurers {
 		return new UserExchangeMutator(username);
 		return new UserExchangeMutator(username);
 	}
 	}
 
 
+	public static CsrfMutator csrf() {
+		return new CsrfMutator();
+	}
+
+	public static class CsrfMutator implements WebTestClientConfigurer, MockServerConfigurer {
+
+		@Override
+		public void afterConfigurerAdded(WebTestClient.Builder builder,
+			@Nullable WebHttpHandlerBuilder httpHandlerBuilder,
+			@Nullable ClientHttpConnector connector) {
+			CsrfWebFilter filter = new CsrfWebFilter();
+			filter.setRequireCsrfProtectionMatcher( e -> ServerWebExchangeMatcher.MatchResult.notMatch());
+			httpHandlerBuilder.filters( filters -> filters.add(0, filter));
+		}
+
+		@Override
+		public void afterConfigureAdded(
+			WebTestClient.MockServerSpec<?> serverSpec) {
+
+		}
+
+		@Override
+		public void beforeServerCreated(WebHttpHandlerBuilder builder) {
+
+		}
+
+		private CsrfMutator() {}
+	}
+
 	/**
 	/**
 	 * Updates the WebServerExchange using {@code {@link SecurityMockServerConfigurers#mockUser(UserDetails)}. Defaults to use a
 	 * Updates the WebServerExchange using {@code {@link SecurityMockServerConfigurers#mockUser(UserDetails)}. Defaults to use a
 	 * password of "password" and granted authorities of "ROLE_USER".
 	 * password of "password" and granted authorities of "ROLE_USER".

+ 37 - 1
test/src/test/java/org/springframework/security/test/web/reactive/server/SecurityMockServerConfigurersTests.java

@@ -18,15 +18,18 @@ package org.springframework.security.test.web.reactive.server;
 
 
 import org.junit.Test;
 import org.junit.Test;
 import org.springframework.http.HttpHeaders;
 import org.springframework.http.HttpHeaders;
+import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
 import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter;
+import org.springframework.security.web.server.csrf.CsrfWebFilter;
 import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.test.web.reactive.server.WebTestClient;
 
 
 import java.security.Principal;
 import java.security.Principal;
 
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*;
 import static org.springframework.security.test.web.reactive.server.SecurityMockServerConfigurers.*;
 
 
 /**
 /**
@@ -36,7 +39,7 @@ import static org.springframework.security.test.web.reactive.server.SecurityMock
 public class SecurityMockServerConfigurersTests extends AbstractMockServerConfigurersTests {
 public class SecurityMockServerConfigurersTests extends AbstractMockServerConfigurersTests {
 	WebTestClient client = WebTestClient
 	WebTestClient client = WebTestClient
 		.bindToController(controller)
 		.bindToController(controller)
-		.webFilter(new SecurityContextServerWebExchangeWebFilter())
+		.webFilter( new CsrfWebFilter(), new SecurityContextServerWebExchangeWebFilter())
 		.apply(springSecurity())
 		.apply(springSecurity())
 		.configureClient()
 		.configureClient()
 		.defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
 		.defaultHeader(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
@@ -144,4 +147,37 @@ public class SecurityMockServerConfigurersTests extends AbstractMockServerConfig
 
 
 		assertPrincipalCreatedFromUserDetails(actual, userBuilder.build());
 		assertPrincipalCreatedFromUserDetails(actual, userBuilder.build());
 	}
 	}
+
+	@Test
+	public void csrfWhenMutateWithThenDisablesCsrf() {
+		this.client
+			.post()
+			.exchange()
+			.expectStatus().isEqualTo(HttpStatus.FORBIDDEN)
+			.expectBody().consumeWith( b -> assertThat(new String(b.getResponseBody())).contains("CSRF"));
+
+		this.client
+			.mutateWith(csrf())
+			.post()
+			.exchange()
+			.expectStatus().isOk();
+
+	}
+
+	@Test
+	public void csrfWhenGlobalThenDisablesCsrf() {
+		this.client = WebTestClient
+			.bindToController(this.controller)
+			.webFilter(new CsrfWebFilter())
+			.apply(springSecurity())
+			.apply(csrf())
+			.configureClient()
+			.build();
+
+		this.client
+			.get()
+			.exchange()
+			.expectStatus().isOk();
+
+	}
 }
 }

+ 1 - 0
web/src/main/java/org/springframework/security/web/server/authentication/AuthenticationWebFilter.java

@@ -52,6 +52,7 @@ public class AuthenticationWebFilter implements WebFilter {
 	private ServerSecurityContextRepository serverSecurityContextRepository = NoOpServerSecurityContextRepository.getInstance();
 	private ServerSecurityContextRepository serverSecurityContextRepository = NoOpServerSecurityContextRepository.getInstance();
 
 
 	private ServerWebExchangeMatcher requiresAuthenticationMatcher = ServerWebExchangeMatchers.anyExchange();
 	private ServerWebExchangeMatcher requiresAuthenticationMatcher = ServerWebExchangeMatchers.anyExchange();
+
 	public AuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) {
 	public AuthenticationWebFilter(ReactiveAuthenticationManager authenticationManager) {
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
 		Assert.notNull(authenticationManager, "authenticationManager cannot be null");
 		this.authenticationManager = authenticationManager;
 		this.authenticationManager = authenticationManager;

+ 33 - 0
web/src/main/java/org/springframework/security/web/server/csrf/CsrfException.java

@@ -0,0 +1,33 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.security.access.AccessDeniedException;
+import org.springframework.security.web.csrf.CsrfToken;
+
+/**
+ * Thrown when an invalid or missing {@link CsrfToken} is found in the HttpServletRequest
+ *
+ * @author Rob Winch
+ * @since 3.2
+ */
+@SuppressWarnings("serial")
+public class CsrfException extends AccessDeniedException {
+
+	public CsrfException(String message) {
+		super(message);
+	}
+}

+ 48 - 0
web/src/main/java/org/springframework/security/web/server/csrf/CsrfToken.java

@@ -0,0 +1,48 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import java.io.Serializable;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public interface CsrfToken extends Serializable {
+
+	/**
+	 * Gets the HTTP header that the CSRF is populated on the response and can be placed
+	 * on requests instead of the parameter. Cannot be null.
+	 *
+	 * @return the HTTP header that the CSRF is populated on the response and can be
+	 * placed on requests instead of the parameter
+	 */
+	String getHeaderName();
+
+	/**
+	 * Gets the HTTP parameter name that should contain the token. Cannot be null.
+	 * @return the HTTP parameter name that should contain the token.
+	 */
+	String getParameterName();
+
+	/**
+	 * Gets the token value. Cannot be null.
+	 * @return the token value
+	 */
+	String getToken();
+
+}

+ 140 - 0
web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

@@ -0,0 +1,140 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.http.HttpMethod;
+import org.springframework.http.HttpStatus;
+import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
+import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilter;
+import org.springframework.web.server.WebFilterChain;
+import reactor.core.publisher.Mono;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * <p>
+ * Applies
+ * <a href="https://www.owasp.org/index.php/Cross-Site_Request_Forgery_(CSRF)" >CSRF</a>
+ * protection using a synchronizer token pattern. Developers are required to ensure that
+ * {@link CsrfWebFilter} is invoked for any request that allows state to change. Typically
+ * this just means that they should ensure their web application follows proper REST
+ * semantics (i.e. do not change state with the HTTP methods GET, HEAD, TRACE, OPTIONS).
+ * </p>
+ *
+ * <p>
+ * Typically the {@link ServerCsrfTokenRepository} implementation chooses to store the
+ * {@link CsrfToken} in {@link org.springframework.web.server.WebSession} with
+ * {@link WebSessionServerCsrfTokenRepository}. This is preferred to storing the token in
+ * a cookie which can be modified by a client application.
+ * </p>
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class CsrfWebFilter implements WebFilter {
+
+	private ServerWebExchangeMatcher requireCsrfProtectionMatcher = new DefaultRequireCsrfProtectionMatcher();
+
+	private ServerCsrfTokenRepository serverCsrfTokenRepository = new WebSessionServerCsrfTokenRepository();
+
+	private ServerAccessDeniedHandler serverAccessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
+
+	private String csrfTokenAttributeName = "csrf";
+
+	public void setServerAccessDeniedHandler(
+		ServerAccessDeniedHandler serverAccessDeniedHandler) {
+		Assert.notNull(serverAccessDeniedHandler, "serverAccessDeniedHandler");
+		this.serverAccessDeniedHandler = serverAccessDeniedHandler;
+	}
+
+	public void setCsrfTokenAttributeName(String csrfTokenAttributeName) {
+		Assert.notNull(csrfTokenAttributeName, "csrfTokenAttributeName cannot be null");
+		this.csrfTokenAttributeName = csrfTokenAttributeName;
+	}
+
+	public void setServerCsrfTokenRepository(
+		ServerCsrfTokenRepository serverCsrfTokenRepository) {
+		Assert.notNull(serverCsrfTokenRepository, "serverCsrfTokenRepository cannot be null");
+		this.serverCsrfTokenRepository = serverCsrfTokenRepository;
+	}
+
+	public void setRequireCsrfProtectionMatcher(
+		ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
+		Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
+		this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
+	}
+
+	@Override
+	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
+		return this.requireCsrfProtectionMatcher.matches(exchange)
+			.filter( matchResult -> matchResult.isMatch())
+			.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
+			.flatMap(m -> validateToken(exchange))
+			.flatMap(m -> continueFilterChain(exchange, chain))
+			.switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
+			.onErrorResume(CsrfException.class, e -> this.serverAccessDeniedHandler.handle(exchange, e));
+	}
+
+	private Mono<Void> validateToken(ServerWebExchange exchange) {
+		return this.serverCsrfTokenRepository.loadToken(exchange)
+			.switchIfEmpty(Mono.error(new CsrfException("CSRF Token has been associated to this client")))
+			.filterWhen(expected -> containsValidCsrfToken(exchange, expected))
+			.switchIfEmpty(Mono.error(new CsrfException("Invalid CSRF Token")))
+			.then();
+	}
+
+	private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
+		return exchange.getFormData()
+			.flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
+			.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
+			.map(actual -> actual.equals(expected.getToken()));
+	}
+
+	private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
+		return csrfToken(exchange)
+			.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
+			.doOnSuccess(csrfToken -> exchange.getAttributes().put(this.csrfTokenAttributeName, csrfToken))
+			.flatMap( t -> chain.filter(exchange))
+			.then();
+	}
+
+	private Mono<Mono<CsrfToken>> csrfToken(ServerWebExchange exchange) {
+		return this.serverCsrfTokenRepository.loadToken(exchange)
+			.switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange))
+			.as(Mono::just); // FIXME eager saving of CsrfToken with .as
+	}
+
+	private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
+		private static final Set<HttpMethod> ALLOWED_METHODS = new HashSet<>(
+			Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));
+
+		@Override
+		public Mono<MatchResult> matches(ServerWebExchange exchange) {
+			return Mono.just(exchange.getRequest())
+				.map(r -> r.getMethod())
+				.filter(m -> ALLOWED_METHODS.contains(m))
+				.flatMap(m -> MatchResult.notMatch())
+				.switchIfEmpty(MatchResult.match());
+		}
+	}
+}

+ 77 - 0
web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.util.Assert;
+
+/**
+ * A CSRF token that is used to protect against CSRF attacks.
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+@SuppressWarnings("serial")
+public final class DefaultCsrfToken implements CsrfToken {
+
+	private final String token;
+
+	private final String parameterName;
+
+	private final String headerName;
+
+	/**
+	 * Creates a new instance
+	 * @param headerName the HTTP header name to use
+	 * @param parameterName the HTTP parameter name to use
+	 * @param token the value of the token (i.e. expected value of the HTTP parameter of
+	 * parametername).
+	 */
+	public DefaultCsrfToken(String headerName, String parameterName, String token) {
+		Assert.hasLength(headerName, "headerName cannot be null or empty");
+		Assert.hasLength(parameterName, "parameterName cannot be null or empty");
+		Assert.hasLength(token, "token cannot be null or empty");
+		this.headerName = headerName;
+		this.parameterName = parameterName;
+		this.token = token;
+	}
+
+	/*
+	 * (non-Javadoc)
+	 *
+	 * @see org.springframework.security.web.csrf.CsrfToken#getHeaderName()
+	 */
+	public String getHeaderName() {
+		return this.headerName;
+	}
+
+	/*
+	 * (non-Javadoc)
+	 *
+	 * @see org.springframework.security.web.csrf.CsrfToken#getParameterName()
+	 */
+	public String getParameterName() {
+		return this.parameterName;
+	}
+
+	/*
+	 * (non-Javadoc)
+	 *
+	 * @see org.springframework.security.web.csrf.CsrfToken#getToken()
+	 */
+	public String getToken() {
+		return this.token;
+	}
+}

+ 58 - 0
web/src/main/java/org/springframework/security/web/server/csrf/ServerCsrfTokenRepository.java

@@ -0,0 +1,58 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+
+/**
+ * An API to allow changing the method in which the expected {@link CsrfToken} is
+ * associated to the {@link ServerWebExchange}. For example, it may be stored in
+ * {@link org.springframework.web.server.WebSession}.
+ *
+ * @see WebSessionServerCsrfTokenRepository
+ *
+ * @author Rob Winch
+ * @since 5.0
+ *
+ */
+public interface ServerCsrfTokenRepository {
+
+	/**
+	 * Generates a {@link CsrfToken}
+	 *
+	 * @param exchange the {@link ServerWebExchange} to use
+	 * @return the {@link CsrfToken} that was generated. Cannot be null.
+	 */
+	Mono<CsrfToken> generateToken(ServerWebExchange exchange);
+
+	/**
+	 * Saves the {@link CsrfToken} using the {@link ServerWebExchange}. If the
+	 * {@link CsrfToken} is null, it is the same as deleting it.
+	 *
+	 * @param exchange the {@link ServerWebExchange} to use
+	 * @param token the {@link CsrfToken} to save or null to delete
+	 */
+	Mono<Void> saveToken(ServerWebExchange exchange, CsrfToken token);
+
+	/**
+	 * Loads the expected {@link CsrfToken} from the {@link ServerWebExchange}
+	 *
+	 * @param exchange the {@link ServerWebExchange} to use
+	 * @return the {@link CsrfToken} or null if none exists
+	 */
+	Mono<CsrfToken> loadToken(ServerWebExchange exchange);
+}

+ 122 - 0
web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java

@@ -0,0 +1,122 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.server.csrf;
+
+import org.springframework.util.Assert;
+import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebSession;
+import reactor.core.publisher.Mono;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpSession;
+import java.util.Map;
+import java.util.UUID;
+
+/**
+ * A {@link ServerCsrfTokenRepository} that stores the {@link CsrfToken} in the
+ * {@link HttpSession}.
+ *
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class WebSessionServerCsrfTokenRepository
+	implements ServerCsrfTokenRepository {
+	private static final String DEFAULT_CSRF_PARAMETER_NAME = "_csrf";
+
+	private static final String DEFAULT_CSRF_HEADER_NAME = "X-CSRF-TOKEN";
+
+	private static final String DEFAULT_CSRF_TOKEN_ATTR_NAME = WebSessionServerCsrfTokenRepository.class
+			.getName().concat(".CSRF_TOKEN");
+
+	private String parameterName = DEFAULT_CSRF_PARAMETER_NAME;
+
+	private String headerName = DEFAULT_CSRF_HEADER_NAME;
+
+	private String sessionAttributeName = DEFAULT_CSRF_TOKEN_ATTR_NAME;
+
+	@Override
+	public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
+		return Mono.defer(() -> Mono.just(createCsrfToken()))
+			.flatMap(token -> save(exchange, token));
+	}
+
+	@Override
+	public Mono<Void> saveToken(ServerWebExchange exchange, CsrfToken token) {
+		return save(exchange, token)
+			.then();
+	}
+
+	private Mono<CsrfToken> save(ServerWebExchange exchange, CsrfToken token) {
+		return exchange.getSession()
+			.map(WebSession::getAttributes)
+			.flatMap( attrs -> save(attrs, token));
+	}
+
+	private Mono<CsrfToken> save(Map<String,Object> attributes, CsrfToken token) {
+		if(token == null) {
+			attributes.remove(this.sessionAttributeName);
+		} else {
+			attributes.put(this.sessionAttributeName, token);
+		}
+		return Mono.justOrEmpty(token);
+	}
+
+	@Override
+	public Mono<CsrfToken> loadToken(ServerWebExchange exchange) {
+		return exchange.getSession()
+			.filter( s -> s.getAttributes().containsKey(this.sessionAttributeName))
+			.map(s -> s.getAttribute(this.sessionAttributeName));
+	}
+
+	/**
+	 * Sets the {@link HttpServletRequest} parameter name that the {@link CsrfToken} is
+	 * expected to appear on
+	 * @param parameterName the new parameter name to use
+	 */
+	public void setParameterName(String parameterName) {
+		Assert.hasLength(parameterName, "parameterName cannot be null or empty");
+		this.parameterName = parameterName;
+	}
+
+	/**
+	 * Sets the header name that the {@link CsrfToken} is expected to appear on and the
+	 * header that the response will contain the {@link CsrfToken}.
+	 *
+	 * @param headerName the new header name to use
+	 */
+	public void setHeaderName(String headerName) {
+		Assert.hasLength(headerName, "headerName cannot be null or empty");
+		this.headerName = headerName;
+	}
+
+	/**
+	 * Sets the {@link HttpSession} attribute name that the {@link CsrfToken} is stored in
+	 * @param sessionAttributeName the new attribute name to use
+	 */
+	public void setSessionAttributeName(String sessionAttributeName) {
+		Assert.hasLength(sessionAttributeName,
+				"sessionAttributename cannot be null or empty");
+		this.sessionAttributeName = sessionAttributeName;
+	}
+
+	private CsrfToken createCsrfToken() {
+		return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
+	}
+
+	private String createNewToken() {
+		return UUID.randomUUID().toString();
+	}
+}

+ 27 - 11
web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java

@@ -23,6 +23,7 @@ import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.http.MediaType;
 import org.springframework.http.server.reactive.ServerHttpResponse;
 import org.springframework.http.server.reactive.ServerHttpResponse;
+import org.springframework.security.web.server.csrf.CsrfToken;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers;
 import org.springframework.util.MultiValueMap;
 import org.springframework.util.MultiValueMap;
@@ -50,21 +51,31 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
 	}
 	}
 
 
 	private Mono<Void> render(ServerWebExchange exchange) {
 	private Mono<Void> render(ServerWebExchange exchange) {
-		MultiValueMap<String, String> queryParams = exchange.getRequest()
-			.getQueryParams();
-		boolean isError = queryParams.containsKey("error");
-		boolean isLogoutSuccess = queryParams.containsKey("logout");
 		ServerHttpResponse result = exchange.getResponse();
 		ServerHttpResponse result = exchange.getResponse();
-		result.setStatusCode(HttpStatus.FOUND);
+		result.setStatusCode(HttpStatus.OK);
 		result.getHeaders().setContentType(MediaType.TEXT_HTML);
 		result.getHeaders().setContentType(MediaType.TEXT_HTML);
-		byte[] bytes = createPage(isError, isLogoutSuccess);
-		DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
-		DataBuffer buffer = bufferFactory.wrap(bytes);
-		return result.writeWith(Mono.just(buffer))
-			.doOnError( error -> DataBufferUtils.release(buffer));
+		return result.writeWith(createBuffer(exchange));
+//			.doOnError( error -> DataBufferUtils.release(buffer));
+	}
+
+	private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
+		MultiValueMap<String, String> queryParams = exchange.getRequest()
+			.getQueryParams();
+		Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
+			.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
+		return token
+			.map(LoginPageGeneratingWebFilter::csrfToken)
+			.defaultIfEmpty("")
+			.map(csrfTokenHtmlInput -> {
+				boolean isError = queryParams.containsKey("error");
+				boolean isLogoutSuccess = queryParams.containsKey("logout");
+				byte[] bytes = createPage(isError, isLogoutSuccess, csrfTokenHtmlInput);
+				DataBufferFactory bufferFactory = exchange.getResponse().bufferFactory();
+				return bufferFactory.wrap(bytes);
+			});
 	}
 	}
 
 
-	private static byte[] createPage(boolean isError, boolean isLogoutSuccess) {
+	private static byte[] createPage(boolean isError, boolean isLogoutSuccess, String csrfTokenHtmlInput) {
 		String page =  "<!DOCTYPE html>\n"
 		String page =  "<!DOCTYPE html>\n"
 			+ "<html lang=\"en\">\n"
 			+ "<html lang=\"en\">\n"
 			+ "  <head>\n"
 			+ "  <head>\n"
@@ -90,6 +101,7 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
 			+ "          <label for=\"password\" class=\"sr-only\">Password</label>\n"
 			+ "          <label for=\"password\" class=\"sr-only\">Password</label>\n"
 			+ "          <input type=\"password\" id=\"password\" name=\"password\" class=\"form-control\" placeholder=\"Password\" required>\n"
 			+ "          <input type=\"password\" id=\"password\" name=\"password\" class=\"form-control\" placeholder=\"Password\" required>\n"
 			+ "        </p>\n"
 			+ "        </p>\n"
+			+ csrfTokenHtmlInput
 			+ "        <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
 			+ "        <button class=\"btn btn-lg btn-primary btn-block\" type=\"submit\">Sign in</button>\n"
 			+ "      </form>\n"
 			+ "      </form>\n"
 			+ "    </div>\n"
 			+ "    </div>\n"
@@ -99,6 +111,10 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
 		return page.getBytes(Charset.defaultCharset());
 		return page.getBytes(Charset.defaultCharset());
 	}
 	}
 
 
+	private static String csrfToken(CsrfToken token) {
+		return "          <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n";
+	}
+
 	private static String createError(boolean isError) {
 	private static String createError(boolean isError) {
 		return isError ? "<div class=\"alert alert-danger\" role=\"alert\">Invalid credentials</div>" : "";
 		return isError ? "<div class=\"alert alert-danger\" role=\"alert\">Invalid credentials</div>" : "";
 	}
 	}

+ 185 - 0
web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

@@ -0,0 +1,185 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.mockito.Mock;
+import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.util.LinkedMultiValueMap;
+import org.springframework.util.MultiValueMap;
+import org.springframework.web.reactive.function.BodyInserters;
+import org.springframework.web.server.WebFilterChain;
+import org.springframework.web.server.WebSession;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.test.publisher.PublisherProbe;
+
+import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.when;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+@RunWith(MockitoJUnitRunner.class)
+public class CsrfWebFilterTests {
+	@Mock
+	private WebFilterChain chain;
+	@Mock
+	private ServerCsrfTokenRepository repository;
+
+	private CsrfToken token = new DefaultCsrfToken("csrf", "CSRF", "a");
+
+	private CsrfWebFilter csrfFilter = new CsrfWebFilter();
+
+	private MockServerWebExchange get = MockServerWebExchange.from(
+		MockServerHttpRequest.get("/"));
+
+	private MockServerWebExchange post = MockServerWebExchange.from(
+		MockServerHttpRequest.post("/"));
+
+	@Test
+	public void filterWhenGetThenSessionNotCreatedAndChainContinues() {
+		PublisherProbe<Void> chainResult = PublisherProbe.empty();
+		when(this.chain.filter(this.get)).thenReturn(chainResult.mono());
+
+		Mono<Void> result = this.csrfFilter.filter(this.get, this.chain);
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		Mono<Boolean> isSessionStarted = this.get.getSession()
+			.map(WebSession::isStarted);
+		StepVerifier.create(isSessionStarted)
+			.expectNext(false)
+			.verifyComplete();
+
+		chainResult.assertWasSubscribed();
+	}
+
+	@Test
+	public void filterWhenPostAndNoTokenThenCsrfException() {
+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
+	}
+
+	@Test
+	public void filterWhenPostAndEstablishedCsrfTokenAndRequestMissingTokenThenCsrfException() {
+		this.csrfFilter.setServerCsrfTokenRepository(this.repository);
+		when(this.repository.loadToken(any()))
+			.thenReturn(Mono.just(this.token));
+		when(this.repository.generateToken(any()))
+			.thenReturn(Mono.just(this.token));
+
+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
+
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
+	}
+
+	@Test
+	public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamInvalidTokenThenCsrfException() {
+		this.csrfFilter.setServerCsrfTokenRepository(this.repository);
+		when(this.repository.loadToken(any()))
+			.thenReturn(Mono.just(this.token));
+		when(this.repository.generateToken(any()))
+			.thenReturn(Mono.just(this.token));
+		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+			.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
+
+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
+	}
+
+	@Test
+	public void filterWhenPostAndEstablishedCsrfTokenAndRequestParamValidTokenThenContinues() {
+		PublisherProbe<Void> chainResult = PublisherProbe.empty();
+		when(this.chain.filter(any())).thenReturn(chainResult.mono());
+
+		this.csrfFilter.setServerCsrfTokenRepository(this.repository);
+		when(this.repository.loadToken(any()))
+			.thenReturn(Mono.just(this.token));
+		when(this.repository.generateToken(any()))
+			.thenReturn(Mono.just(this.token));
+		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+			.contentType(MediaType.APPLICATION_FORM_URLENCODED)
+			.body(this.token.getParameterName() + "="+this.token.getToken()));
+
+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		chainResult.assertWasSubscribed();
+	}
+
+	@Test
+	public void filterWhenPostAndEstablishedCsrfTokenAndHeaderInvalidTokenThenCsrfException() {
+		this.csrfFilter.setServerCsrfTokenRepository(this.repository);
+		when(this.repository.loadToken(any()))
+			.thenReturn(Mono.just(this.token));
+		when(this.repository.generateToken(any()))
+			.thenReturn(Mono.just(this.token));
+		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+			.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
+
+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		assertThat(this.post.getResponse().getStatusCode()).isEqualTo(HttpStatus.FORBIDDEN);
+	}
+
+	@Test
+	public void filterWhenPostAndEstablishedCsrfTokenAndHeaderValidTokenThenContinues() {
+		PublisherProbe<Void> chainResult = PublisherProbe.empty();
+		when(this.chain.filter(any())).thenReturn(chainResult.mono());
+
+		this.csrfFilter.setServerCsrfTokenRepository(this.repository);
+		when(this.repository.loadToken(any()))
+			.thenReturn(Mono.just(this.token));
+		when(this.repository.generateToken(any()))
+			.thenReturn(Mono.just(this.token));
+		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+			.header(this.token.getHeaderName(), this.token.getToken()));
+
+		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		chainResult.assertWasSubscribed();
+	}
+}

+ 112 - 0
web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java

@@ -0,0 +1,112 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server.csrf;
+
+import org.junit.Test;
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.web.server.WebSession;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+
+import java.util.Map;
+
+import static org.assertj.core.api.Assertions.*;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+public class WebSessionServerCsrfTokenRepositoryTests {
+	private WebSessionServerCsrfTokenRepository repository = new WebSessionServerCsrfTokenRepository();
+
+	private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
+
+	@Test
+	public void generateTokenWhenNoSubscriptionThenNoSession() {
+		Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
+
+		Mono<Boolean> isSessionStarted = this.exchange.getSession()
+			.map(WebSession::isStarted);
+
+		StepVerifier.create(isSessionStarted)
+			.expectNext(false)
+			.verifyComplete();
+	}
+
+	@Test
+	public void generateTokenWhenSubscriptionThenAddsToSession() {
+		Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
+
+		StepVerifier.create(result)
+			.consumeNextWith( t -> assertThat(t).isNotNull())
+			.verifyComplete();
+
+		WebSession session = this.exchange.getSession().block();
+		Map<String, Object> attributes = session.getAttributes();
+
+		assertThat(session.isStarted()).isTrue();
+		assertThat(attributes).hasSize(1);
+		assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
+
+	}
+
+	@Test
+	public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() {
+		CsrfToken token = new DefaultCsrfToken("h","p", "t");
+		String attrName = "ATTR";
+		this.repository.setSessionAttributeName(attrName);
+		Mono<Void> result = this.repository.saveToken(this.exchange, token);
+
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		WebSession session = this.exchange.getSession().block();
+
+		assertThat(session.isStarted()).isTrue();
+		assertThat(session.<WebSession>getAttribute(attrName)).isEqualTo(token);
+	}
+
+	@Test
+	public void saveTokenWhenNullThenDeletes() {
+		CsrfToken token = new DefaultCsrfToken("h","p", "t");
+		this.repository.saveToken(this.exchange, token).block();
+
+		Mono<Void> result = this.repository.saveToken(this.exchange, null);
+		StepVerifier.create(result)
+			.verifyComplete();
+
+		WebSession session = this.exchange.getSession().block();
+
+		assertThat(session.getAttributes()).isEmpty();
+	}
+
+	@Test
+	public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
+		CsrfToken generate = this.repository.generateToken(this.exchange).block();
+
+		CsrfToken load = this.repository.loadToken(this.exchange).block();
+		assertThat(load).isEqualTo(generate);
+
+		this.repository.saveToken(this.exchange, null).block();
+		WebSession session = this.exchange.getSession().block();
+		assertThat(session.getAttributes()).isEmpty();
+
+		load = this.repository.loadToken(this.exchange).block();
+		assertThat(load).isNull();
+	}
+}