Browse Source

Cache Xor CsrfToken

Closes gh-11988
Steve Riesenberg 2 years ago
parent
commit
05e4a1dd20

+ 22 - 2
web/src/main/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandler.java

@@ -59,12 +59,12 @@ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestA
 	}
 
 	private Supplier<CsrfToken> deferCsrfTokenUpdate(Supplier<CsrfToken> csrfTokenSupplier) {
-		return () -> {
+		return new CachedCsrfTokenSupplier(() -> {
 			CsrfToken csrfToken = csrfTokenSupplier.get();
 			Assert.state(csrfToken != null, "csrfToken supplier returned null");
 			String updatedToken = createXoredCsrfToken(this.secureRandom, csrfToken.getToken());
 			return new DefaultCsrfToken(csrfToken.getHeaderName(), csrfToken.getParameterName(), updatedToken);
-		};
+		});
 	}
 
 	@Override
@@ -123,4 +123,24 @@ public final class XorCsrfTokenRequestAttributeHandler extends CsrfTokenRequestA
 		return xoredCsrf;
 	}
 
+	private static final class CachedCsrfTokenSupplier implements Supplier<CsrfToken> {
+
+		private final Supplier<CsrfToken> delegate;
+
+		private CsrfToken csrfToken;
+
+		private CachedCsrfTokenSupplier(Supplier<CsrfToken> delegate) {
+			this.delegate = delegate;
+		}
+
+		@Override
+		public CsrfToken get() {
+			if (this.csrfToken == null) {
+				this.csrfToken = this.delegate.get();
+			}
+			return this.csrfToken;
+		}
+
+	}
+
 }

+ 2 - 1
web/src/main/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandler.java

@@ -53,7 +53,8 @@ public final class XorServerCsrfTokenRequestAttributeHandler extends ServerCsrfT
 		Assert.notNull(exchange, "exchange cannot be null");
 		Assert.notNull(csrfToken, "csrfToken cannot be null");
 		Mono<CsrfToken> updatedCsrfToken = csrfToken.map((token) -> new DefaultCsrfToken(token.getHeaderName(),
-				token.getParameterName(), createXoredCsrfToken(this.secureRandom, token.getToken())));
+				token.getParameterName(), createXoredCsrfToken(this.secureRandom, token.getToken())))
+				.cast(CsrfToken.class).cache();
 		super.handle(exchange, updatedCsrfToken);
 	}
 

+ 9 - 0
web/src/test/java/org/springframework/security/web/csrf/XorCsrfTokenRequestAttributeHandlerTests.java

@@ -148,6 +148,15 @@ public class XorCsrfTokenRequestAttributeHandlerTests {
 		assertThat(csrfTokenAttribute.getToken()).isEqualTo(XOR_CSRF_TOKEN_VALUE);
 	}
 
+	@Test
+	public void handleWhenCsrfTokenRequestedTwiceThenCached() {
+		this.handler.handle(this.request, this.response, () -> this.token);
+
+		CsrfToken csrfTokenAttribute = (CsrfToken) this.request.getAttribute(CsrfToken.class.getName());
+		assertThat(csrfTokenAttribute.getToken()).isNotEqualTo(this.token.getToken());
+		assertThat(csrfTokenAttribute.getToken()).isEqualTo(csrfTokenAttribute.getToken());
+	}
+
 	@Test
 	public void resolveCsrfTokenValueWhenRequestIsNullThenThrowsIllegalArgumentException() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))

+ 11 - 0
web/src/test/java/org/springframework/security/web/server/csrf/XorServerCsrfTokenRequestAttributeHandlerTests.java

@@ -110,6 +110,17 @@ public class XorServerCsrfTokenRequestAttributeHandlerTests {
 		verify(this.secureRandom).nextBytes(anyByteArray());
 	}
 
+	@Test
+	public void handleWhenCsrfTokenRequestedTwiceThenCached() {
+		this.handler.handle(this.exchange, Mono.just(this.token));
+		Mono<CsrfToken> csrfTokenAttribute = this.exchange.getAttribute(CsrfToken.class.getName());
+		assertThat(csrfTokenAttribute).isNotNull();
+		CsrfToken csrfToken1 = csrfTokenAttribute.block();
+		CsrfToken csrfToken2 = csrfTokenAttribute.block();
+		assertThat(csrfToken1.getToken()).isNotEqualTo(this.token.getToken());
+		assertThat(csrfToken1.getToken()).isEqualTo(csrfToken2.getToken());
+	}
+
 	@Test
 	public void resolveCsrfTokenValueWhenExchangeIsNullThenThrowsIllegalArgumentException() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.handler.resolveCsrfTokenValue(null, this.token))