瀏覽代碼

WebSessionServerCsrfTokenRepository saves on getToken

Fixes gh-4801
Rob Winch 7 年之前
父節點
當前提交
7622826b69

+ 5 - 5
config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java

@@ -316,9 +316,9 @@ public class FormLoginTests {
 	public static class CustomLoginPageController {
 		@ResponseBody
 		@GetMapping("/login")
-		public Mono<String> login(ServerWebExchange exchange) {
-			Mono<CsrfToken> token = exchange.getAttribute(CsrfToken.class.getName());
-			return token.map(t ->
+		public String login(ServerWebExchange exchange) {
+			CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
+			return
 				"<!DOCTYPE html>\n"
 				+ "<html lang=\"en\">\n"
 				+ "  <head>\n"
@@ -340,12 +340,12 @@ public class FormLoginTests {
 				+ "          <label for=\"password\" class=\"sr-only\">Password</label>\n"
 				+ "          <input type=\"password\" id=\"password\" name=\"password\" placeholder=\"Password\" required>\n"
 				+ "        </p>\n"
-				+ "        <input type=\"hidden\" name=\"" + t.getParameterName() + "\" value=\"" + t.getToken() + "\">\n"
+				+ "        <input type=\"hidden\" name=\"" + token.getParameterName() + "\" value=\"" + token.getToken() + "\">\n"
 				+ "        <button type=\"submit\">Sign in</button>\n"
 				+ "      </form>\n"
 				+ "    </div>\n"
 				+ "  </body>\n"
-				+ "</html>");
+				+ "</html>";
 		}
 
 	}

+ 8 - 3
web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

@@ -106,14 +106,19 @@ public class CsrfWebFilter implements WebFilter {
 	private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
 		return csrfToken(exchange)
 			.doOnSuccess(csrfToken -> exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken))
+			.doOnSuccess(csrfToken -> exchange.getAttributes().put(csrfToken.getParameterName(), csrfToken))
 			.flatMap( t -> chain.filter(exchange))
 			.then();
 	}
 
-	private Mono<Mono<CsrfToken>> csrfToken(ServerWebExchange exchange) {
+	private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
 		return this.serverCsrfTokenRepository.loadToken(exchange)
-			.switchIfEmpty(this.serverCsrfTokenRepository.generateToken(exchange))
-			.as(Mono::just); // FIXME eager saving of CsrfToken with .as
+			.switchIfEmpty(generateToken(exchange));
+	}
+
+	private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
+		return this.serverCsrfTokenRepository.generateToken(exchange)
+			.flatMap(token -> this.serverCsrfTokenRepository.saveToken(exchange, token));
 	}
 
 	private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {

+ 24 - 0
web/src/main/java/org/springframework/security/web/server/csrf/DefaultCsrfToken.java

@@ -74,4 +74,28 @@ public final class DefaultCsrfToken implements CsrfToken {
 	public String getToken() {
 		return this.token;
 	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o)
+			return true;
+		if (o == null || !(o instanceof CsrfToken))
+			return false;
+
+		CsrfToken that = (CsrfToken) o;
+
+		if (!getToken().equals(that.getToken()))
+			return false;
+		if (!getParameterName().equals(that.getParameterName()))
+			return false;
+		return getHeaderName().equals(that.getHeaderName());
+	}
+
+	@Override
+	public int hashCode() {
+		int result = getToken().hashCode();
+		result = 31 * result + getParameterName().hashCode();
+		result = 31 * result + getHeaderName().hashCode();
+		return result;
+	}
 }

+ 66 - 2
web/src/main/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepository.java

@@ -49,12 +49,16 @@ public class WebSessionServerCsrfTokenRepository
 
 	@Override
 	public Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
-		return Mono.defer(() -> Mono.just(createCsrfToken()))
-			.flatMap(token -> saveToken(exchange, token));
+		return exchange.getSession()
+			.map(WebSession::getAttributes)
+			.map(this::createCsrfToken);
 	}
 
 	@Override
 	public Mono<CsrfToken> saveToken(ServerWebExchange exchange, CsrfToken token) {
+		if(token != null) {
+			return Mono.just(token);
+		}
 		return exchange.getSession()
 			.map(WebSession::getAttributes)
 			.flatMap( attrs -> save(attrs, token));
@@ -113,6 +117,11 @@ public class WebSessionServerCsrfTokenRepository
 		this.sessionAttributeName = sessionAttributeName;
 	}
 
+
+	private CsrfToken createCsrfToken(Map<String,Object> attributes) {
+		return new LazyCsrfToken(attributes, createCsrfToken());
+	}
+
 	private CsrfToken createCsrfToken() {
 		return new DefaultCsrfToken(this.headerName, this.parameterName, createNewToken());
 	}
@@ -120,4 +129,59 @@ public class WebSessionServerCsrfTokenRepository
 	private String createNewToken() {
 		return UUID.randomUUID().toString();
 	}
+
+	private class LazyCsrfToken implements CsrfToken {
+		private final Map<String,Object> attributes;
+		private final CsrfToken delegate;
+
+		private LazyCsrfToken(Map<String, Object> attributes, CsrfToken delegate) {
+			this.attributes = attributes;
+			this.delegate = delegate;
+		}
+
+		@Override
+		public String getHeaderName() {
+			return this.delegate.getHeaderName();
+		}
+
+		@Override
+		public String getParameterName() {
+			return this.delegate.getParameterName();
+		}
+
+		@Override
+		public String getToken() {
+			putToken(this.attributes, this.delegate);
+			return this.delegate.getToken();
+		}
+
+		@Override
+		public boolean equals(Object o) {
+			if (this == o)
+				return true;
+			if (o == null || !(o instanceof CsrfToken))
+				return false;
+
+			CsrfToken that = (CsrfToken) o;
+
+			if (!getToken().equals(that.getToken()))
+				return false;
+			if (!getParameterName().equals(that.getParameterName()))
+				return false;
+			return getHeaderName().equals(that.getHeaderName());
+		}
+
+		@Override
+		public int hashCode() {
+			int result = getToken().hashCode();
+			result = 31 * result + getParameterName().hashCode();
+			result = 31 * result + getHeaderName().hashCode();
+			return result;
+		}
+
+		@Override
+		public String toString() {
+			return "LazyCsrfToken{" + "delegate=" + this.delegate + '}';
+		}
+	}
 }

+ 2 - 3
web/src/main/java/org/springframework/security/web/server/ui/LoginPageGeneratingWebFilter.java

@@ -61,9 +61,8 @@ public class LoginPageGeneratingWebFilter implements WebFilter {
 	private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
 		MultiValueMap<String, String> queryParams = exchange.getRequest()
 			.getQueryParams();
-		Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
-			.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
-		return token
+		CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
+		return Mono.justOrEmpty(token)
 			.map(LoginPageGeneratingWebFilter::csrfToken)
 			.defaultIfEmpty("")
 			.map(csrfTokenHtmlInput -> {

+ 2 - 3
web/src/main/java/org/springframework/security/web/server/ui/LogoutPageGeneratingWebFilter.java

@@ -58,9 +58,8 @@ public class LogoutPageGeneratingWebFilter implements WebFilter {
 	}
 
 	private Mono<DataBuffer> createBuffer(ServerWebExchange exchange) {
-		Mono<CsrfToken> token = (Mono<CsrfToken>) exchange.getAttributes()
-			.getOrDefault(CsrfToken.class.getName(), Mono.<CsrfToken>empty());
-		return token
+		CsrfToken token = exchange.getAttribute(CsrfToken.class.getName());
+		return Mono.justOrEmpty(token)
 			.map(LogoutPageGeneratingWebFilter::csrfToken)
 			.defaultIfEmpty("")
 			.map(csrfTokenHtmlInput -> {

+ 16 - 24
web/src/test/java/org/springframework/security/web/server/csrf/WebSessionServerCsrfTokenRepositoryTests.java

@@ -37,7 +37,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
 	private MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
 
 	@Test
-	public void generateTokenWhenNoSubscriptionThenNoSession() {
+	public void generateTokenThenNoSession() {
 		Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
 
 		Mono<Boolean> isSessionStarted = this.exchange.getSession()
@@ -49,43 +49,34 @@ public class WebSessionServerCsrfTokenRepositoryTests {
 	}
 
 	@Test
-	public void generateTokenWhenSubscriptionThenAddsToSession() {
+	public void generateTokenWhenSubscriptionThenNoSession() {
 		Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
 
-		StepVerifier.create(result)
-			.consumeNextWith( t -> assertThat(t).isNotNull())
-			.verifyComplete();
-
-		WebSession session = this.exchange.getSession().block();
-		Map<String, Object> attributes = session.getAttributes();
-
-		assertThat(session.isStarted()).isTrue();
-		assertThat(attributes).hasSize(1);
-		assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
+		Mono<Boolean> isSessionStarted = this.exchange.getSession()
+			.map(WebSession::isStarted);
 
+		StepVerifier.create(isSessionStarted)
+			.expectNext(false)
+			.verifyComplete();
 	}
 
 	@Test
-	public void saveTokenWhenSetSessionAttributeNameAndSubscriptionThenAddsToSession() {
-		CsrfToken token = new DefaultCsrfToken("h","p", "t");
-		String attrName = "ATTR";
-		this.repository.setSessionAttributeName(attrName);
-		Mono<CsrfToken> result = this.repository.saveToken(this.exchange, token);
-
-		StepVerifier.create(result)
-			.consumeNextWith(n -> assertThat(n).isEqualTo(token))
-			.verifyComplete();
+	public void generateTokenWhenGetTokenThenAddsToSession() {
+		Mono<CsrfToken> result = this.repository.generateToken(this.exchange);
+		result.block().getToken();
 
 		WebSession session = this.exchange.getSession().block();
+		Map<String, Object> attributes = session.getAttributes();
 
 		assertThat(session.isStarted()).isTrue();
-		assertThat(session.<WebSession>getAttribute(attrName)).isEqualTo(token);
+		assertThat(attributes).hasSize(1);
+		assertThat(attributes.values().iterator().next()).isInstanceOf(CsrfToken.class);
 	}
 
 	@Test
 	public void saveTokenWhenNullThenDeletes() {
-		CsrfToken token = new DefaultCsrfToken("h","p", "t");
-		this.repository.saveToken(this.exchange, token).block();
+		CsrfToken token = this.repository.generateToken(this.exchange).block();
+		token.getToken();
 
 		Mono<CsrfToken> result = this.repository.saveToken(this.exchange, null);
 		StepVerifier.create(result)
@@ -99,6 +90,7 @@ public class WebSessionServerCsrfTokenRepositoryTests {
 	@Test
 	public void generateTokenAndLoadTokenDeleteTokenWhenNullThenDeletes() {
 		CsrfToken generate = this.repository.generateToken(this.exchange).block();
+		generate.getToken();
 
 		CsrfToken load = this.repository.loadToken(this.exchange).block();
 		assertThat(load).isEqualTo(generate);