Browse Source

CsrfWebFilter places Mono<CsrfToken>

Fixes: gh-4855
Rob Winch 7 years ago
parent
commit
d55db837e1

+ 6 - 5
config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java

@@ -33,6 +33,7 @@ import org.springframework.test.web.reactive.server.WebTestClient;
 import org.springframework.web.bind.annotation.GetMapping;
 import org.springframework.web.bind.annotation.ResponseBody;
 import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -314,9 +315,9 @@ public class FormLoginTests {
 	public static class CustomLoginPageController {
 		@ResponseBody
 		@GetMapping("/login")
-		public String login(ServerWebExchange exchange) {
-			CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
-			return
+		public Mono<String> login(ServerWebExchange exchange) {
+			Mono<CsrfToken> token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty());
+			return token.map(t ->
 				"<!DOCTYPE html>\n"
 				+ "<html lang=\"en\">\n"
 				+ "  <head>\n"
@@ -338,12 +339,12 @@ public class FormLoginTests {
 				+ "          <label for=\"password\" class=\"sr-only\">Password</label>\n"
 				+ "          <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
 				+ "        </p>\n"
-				+ "        <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n"
+				+ "        <input type=\"hidden\" name=\"" + t.getParameterName() + "\" value=\"" + t.getToken() + "\">\n"
 				+ "        <button type=\"submit\">Sign in</button>\n"
 				+ "      </form>\n"
 				+ "    </div>\n"
 				+ "  </body>\n"
-				+ "</html>";
+				+ "</html>");
 		}
 	}
 }

+ 0 - 2
config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java

@@ -26,7 +26,6 @@ import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverB
 import org.springframework.security.test.web.reactive.server.WebTestClientBuilder;
 import org.springframework.security.web.server.SecurityWebFilterChain;
 import org.springframework.security.web.server.WebFilterChainProxy;
-import org.springframework.security.web.server.csrf.CsrfToken;
 import org.springframework.security.web.server.savedrequest.NoOpServerRequestCache;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.web.reactive.server.WebTestClient;
@@ -126,7 +125,6 @@ public class RequestCacheTests {
 		@ResponseBody
 		@GetMapping("/secured")
 		public String login(ServerWebExchange exchange) {
-			CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
 			return
 				"<!DOCTYPE html>\n"
 					+ "<html lang=\"en\">\n"

+ 38 - 0
samples/javaconfig/webflux-form/src/main/java/sample/CsrfControllerAdvice.java

@@ -0,0 +1,38 @@
+/*
+ * Copyright 2002-2017 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package sample;
+
+import org.springframework.security.web.server.csrf.CsrfToken;
+import org.springframework.web.bind.annotation.ControllerAdvice;
+import org.springframework.web.bind.annotation.ModelAttribute;
+import org.springframework.web.server.ServerWebExchange;
+import reactor.core.publisher.Mono;
+
+import static org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor.DEFAULT_CSRF_ATTR_NAME;
+
+/**
+ * @author Rob Winch
+ * @since 5.0
+ */
+@ControllerAdvice
+public class CsrfControllerAdvice {
+	@ModelAttribute
+	public Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
+		Mono<CsrfToken> csrfToken = exchange.getAttribute(CsrfToken.class.getName());
+		return csrfToken.doOnSuccess(token -> exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, token));
+	}
+}

+ 5 - 1
web/src/main/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessor.java

@@ -30,6 +30,10 @@ import java.util.regex.Pattern;
  * @since 5.0
  */
 public class CsrfRequestDataValueProcessor implements RequestDataValueProcessor {
+	/**
+	 * The default request attribute to look for a {@link CsrfToken}.
+	 */
+	public static final String DEFAULT_CSRF_ATTR_NAME = "_csrf";
 
 	private static final Pattern DISABLE_CSRF_TOKEN_PATTERN = Pattern
 		.compile("(?i)^(GET|HEAD|TRACE|OPTIONS)$");
@@ -62,7 +66,7 @@ public class CsrfRequestDataValueProcessor implements RequestDataValueProcessor
 			exchange.getAttributes().remove(DISABLE_CSRF_TOKEN_ATTR);
 			return Collections.emptyMap();
 		}
-		CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
+		CsrfToken token = exchange.getAttribute(DEFAULT_CSRF_ATTR_NAME);
 		if(token == null) {
 			return Collections.emptyMap();
 		}

+ 10 - 6
web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

@@ -47,12 +47,16 @@ import java.util.Set;
  * {@link WebSessionServerCsrfTokenRepository}. This is preferred to storing the token in
  * a cookie which can be modified by a client application.
  * </p>
+ * <p>
+ * The {@code Mono&lt;CsrfToken&gt;} is exposes as a request attribute with the name of
+ * {@code CsrfToken.class.getName()}. If the token is new it will automatically be saved
+ * at the time it is subscribed.
+ * </p>
  *
  * @author Rob Winch
  * @since 5.0
  */
 public class CsrfWebFilter implements WebFilter {
-
 	private ServerWebExchangeMatcher requireCsrfProtectionMatcher = new DefaultRequireCsrfProtectionMatcher();
 
 	private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
@@ -105,11 +109,11 @@ public class CsrfWebFilter implements WebFilter {
 	}
 
 	private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
-		return csrfToken(exchange)
-			.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
-			.doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken))
-			.flatMap( t -> chain.filter(exchange))
-			.then();
+		return Mono.defer(() ->{
+			Mono<CsrfToken> csrfToken = csrfToken(exchange);
+			exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
+			return chain.filter(exchange);
+		});
 	}
 
 	private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {

+ 3 - 68
web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java

@@ -17,7 +17,6 @@ package org.springframework.security.web.server.csrf;
 
 import org.springframework.util.Assert;
 import org.springframework.web.server.ServerWebExchange;
-import org.springframework.web.server.WebSession;
 import reactor.core.publisher.Mono;
 
 import javax.servlet.http.HttpServletRequest;
@@ -49,20 +48,15 @@ public class WebSessionServerCsrfTokenRepository
 
 	@Override
 	public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
-		return exchange.getSession()
-			.map(WebSession::getAttributes)
-			.map(this::createCsrfToken);
+		return Mono.fromCallable(() -> createCsrfToken());
 	}
 
 	@Override
 	public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
-		if(token != null) {
-			return Mono.just(token);
-		}
 		return exchange.getSession()
-			.doOnSuccess(session -> putToken(session.getAttributes(), token))
+			.doOnNext(session -> putToken(session.getAttributes(), token))
 			.flatMap(session -> session.changeSessionId())
-			.flatMap(r -> Mono.justOrEmpty(token));
+			.then(Mono.justOrEmpty(token));
 	}
 
 	private void putToken(Map<String, Object> attributes, CsrfToken token) {
@@ -111,11 +105,6 @@ public class WebSessionServerCsrfTokenRepository
 		this.sessionAttributeName = sessionAttributeName;
 	}
 
-
-	private CsrfToken createCsrfToken(Map<String, Object> attributes) {
-		return new LazyCsrfToken(attributes, createCsrfToken());
-	}
-
 	private CsrfToken createCsrfToken() {
 		return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
 	}
@@ -124,58 +113,4 @@ public class WebSessionServerCsrfTokenRepository
 		return UUID.randomUUID().toString();
 	}
 
-	private class LazyCsrfToken implements CsrfToken {
-		private final Map<String, Object> attributes;
-		private final CsrfToken delegate;
-
-		private LazyCsrfToken(Map<String, Object> attributes, CsrfToken delegate) {
-			this.attributes = attributes;
-			this.delegate = delegate;
-		}
-
-		@Override
-		public String getHeaderName() {
-			return this.delegate.getHeaderName();
-		}
-
-		@Override
-		public String getParameterName() {
-			return this.delegate.getParameterName();
-		}
-
-		@Override
-		public String getToken() {
-			putToken(this.attributes, this.delegate);
-			return this.delegate.getToken();
-		}
-
-		@Override
-		public boolean equals(Object o) {
-			if (this == o)
-				return true;
-			if (o == null || !(o instanceof CsrfToken))
-				return false;
-
-			CsrfToken that = (CsrfToken) o;
-
-			if (!getToken().equals(that.getToken()))
-				return false;
-			if (!getParameterName().equals(that.getParameterName()))
-				return false;
-			return getHeaderName().equals(that.getHeaderName());
-		}
-
-		@Override
-		public int hashCode() {
-			int result = getToken().hashCode();
-			result = 31 * result + getParameterName().hashCode();
-			result = 31 * result + getHeaderName().hashCode();
-			return result;
-		}
-
-		@Override
-		public String toString() {
-			return "LazyCsrfToken{" + "delegate=" + this.delegate + '}';
-		}
-	}
 }

+ 2 - 2
web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java

@@ -60,8 +60,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
 	private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
 		MultiValueMap<String, String> queryParams = exchange.getRequest()
 			.getQueryParams();
-		CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
-		return Mono.justOrEmpty(token)
+		Mono<CsrfToken> token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty());
+		return token
 			.map(LoginPageGeneratingWebFilter::csrfToken)
 			.defaultIfEmpty("")
 			.map(csrfTokenHtmlInput -> {

+ 2 - 2
web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java

@@ -57,8 +57,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter {
 	}
 
 	private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
-		CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
-		return Mono.justOrEmpty(token)
+		Mono<CsrfToken> token = exchange.getAttributeOrDefault(CsrfToken.class.getName(), Mono.empty());
+		return token
 			.map(LogoutPageGeneratingWebFilter::csrfToken)
 			.defaultIfEmpty("")
 			.map(csrfTokenHtmlInput -> {

+ 3 - 2
web/src/test/java/org/springframework/security/web/reactive/result/view/CsrfRequestDataValueProcessorTests.java

@@ -30,6 +30,7 @@ import java.util.HashMap;
 import java.util.Map;
 
 import static org.assertj.core.api.Assertions.*;
+import static org.springframework.security.web.reactive.result.view.CsrfRequestDataValueProcessor.DEFAULT_CSRF_ATTR_NAME;
 
 /**
  * @author Rob Winch
@@ -46,7 +47,7 @@ public class CsrfRequestDataValueProcessorTests {
 	@Before
 	public void setup() {
 		this.expected.put(this.token.getParameterName(), this.token.getToken());
-		this.exchange.getAttributes().put(CsrfToken.class.getName(), this.token);
+		this.exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, this.token);
 	}
 
 	@Test
@@ -122,7 +123,7 @@ public class CsrfRequestDataValueProcessorTests {
 	@Test
 	public void createGetExtraHiddenFieldsHasCsrfToken() {
 		CsrfToken token = new DefaultCsrfToken("1", "a", "b");
-		this.exchange.getAttributes().put(CsrfToken.class.getName(), token);
+		this.exchange.getAttributes().put(DEFAULT_CSRF_ATTR_NAME, token);
 		Map<String, String> expected = new HashMap<String, String>();
 		expected.put(token.getParameterName(), token.getToken());
 

+ 0 - 6
web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

@@ -89,8 +89,6 @@ public class CsrfWebFilterTests {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		when(this.repository.loadToken(any()))
 			.thenReturn(Mono.just(this.token));
-		when(this.repository.generateToken(any()))
-			.thenReturn(Mono.just(this.token));
 
 		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
 
@@ -106,8 +104,6 @@ public class CsrfWebFilterTests {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		when(this.repository.loadToken(any()))
 			.thenReturn(Mono.just(this.token));
-		when(this.repository.generateToken(any()))
-			.thenReturn(Mono.just(this.token));
 		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
 			.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
 
@@ -146,8 +142,6 @@ public class CsrfWebFilterTests {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		when(this.repository.loadToken(any()))
 			.thenReturn(Mono.just(this.token));
-		when(this.repository.generateToken(any()))
-			.thenReturn(Mono.just(this.token));
 		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
 			.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
 

+ 4 - 20
web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java

@@ -61,9 +61,10 @@ public class WebSessionServerCsrfTokenRepositoryTests {
 	}
 
 	@Test
-	public void generateTokenWhenGetTokenThenAddsToSession() {
-		Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
-		result.block().getToken();
+	public void saveTokenWhenDefaultThenAddsToSession() {
+		Mono<CsrfToken> result = this.repository.generateToken(this.exchange)
+			.delayUntil(t-> this.repository.saveToken(this.exchange, t));
+		result.block();
 
 		WebSession session = this.exchange.getSession().block();
 		Map<String, Object> attributes = session.getAttributes();
@@ -76,7 +77,6 @@ public class WebSessionServerCsrfTokenRepositoryTests {
 	@Test
 	public void saveTokenWhenNullThenDeletes() {
 		CsrfToken token = this.repository.generateToken(this.exchange).block();
-		token.getToken();
 
 		Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
 		StepVerifier.create(result)
@@ -87,22 +87,6 @@ public class WebSessionServerCsrfTokenRepositoryTests {
 		assertThat(session.getAttributes()).isEmpty();
 	}
 
-	@Test
-	public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
-		CsrfToken generate = this.repository.generateToken(this.exchange).block();
-		generate.getToken();
-
-		CsrfToken load = this.repository.loadToken(this.exchange).block();
-		assertThat(load).isEqualTo(generate);
-
-		this.repository.saveToken(this.exchange, null).block();
-		WebSession session = this.exchange.getSession().block();
-		assertThat(session.getAttributes()).isEmpty();
-
-		load = this.repository.loadToken(this.exchange).block();
-		assertThat(load).isNull();
-	}
-
 	@Test
 	public void saveTokenChangeSessionId() {
 		String originalSessionId = this.exchange.getSession().block().getId();