|
@@ -16,6 +16,7 @@
|
|
|
|
|
|
package org.springframework.security.web.server.csrf;
|
|
|
|
|
|
+import java.security.MessageDigest;
|
|
|
import java.util.Arrays;
|
|
|
import java.util.HashSet;
|
|
|
import java.util.Set;
|
|
@@ -28,6 +29,7 @@ import org.springframework.http.HttpStatus;
|
|
|
import org.springframework.http.MediaType;
|
|
|
import org.springframework.http.codec.multipart.FormFieldPart;
|
|
|
import org.springframework.http.server.reactive.ServerHttpRequest;
|
|
|
+import org.springframework.security.crypto.codec.Utf8;
|
|
|
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
|
|
|
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
|
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
|
@@ -139,7 +141,7 @@ public class CsrfWebFilter implements WebFilter {
|
|
|
return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
|
|
|
.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
|
|
|
.switchIfEmpty(tokenFromMultipartData(exchange, expected))
|
|
|
- .map((actual) -> actual.equals(expected.getToken()));
|
|
|
+ .map((actual) -> equalsConstantTime(actual, expected.getToken()));
|
|
|
}
|
|
|
|
|
|
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
|
|
@@ -168,6 +170,24 @@ public class CsrfWebFilter implements WebFilter {
|
|
|
return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
|
|
|
}
|
|
|
|
|
|
+ /**
|
|
|
+ * Constant time comparison to prevent against timing attacks.
|
|
|
+ * @param expected
|
|
|
+ * @param actual
|
|
|
+ * @return
|
|
|
+ */
|
|
|
+ private static boolean equalsConstantTime(String expected, String actual) {
|
|
|
+ byte[] expectedBytes = bytesUtf8(expected);
|
|
|
+ byte[] actualBytes = bytesUtf8(actual);
|
|
|
+ return MessageDigest.isEqual(expectedBytes, actualBytes);
|
|
|
+ }
|
|
|
+
|
|
|
+ private static byte[] bytesUtf8(String s) {
|
|
|
+ // need to check if Utf8.encode() runs in constant time (probably not).
|
|
|
+ // This may leak length of string.
|
|
|
+ return (s != null) ? Utf8.encode(s) : null;
|
|
|
+ }
|
|
|
+
|
|
|
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
|
|
return this.csrfTokenRepository.generateToken(exchange)
|
|
|
.delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));
|