瀏覽代碼

Constant Time Comparison for CSRF tokens

Closes gh-9291
Rob Winch 4 年之前
父節點
當前提交
40e027c56d

+ 21 - 1
web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

@@ -17,6 +17,7 @@
 package org.springframework.security.web.csrf;
 
 import java.io.IOException;
+import java.security.MessageDigest;
 import java.util.Arrays;
 import java.util.HashSet;
 
@@ -31,6 +32,7 @@ import org.apache.commons.logging.LogFactory;
 
 import org.springframework.core.log.LogMessage;
 import org.springframework.security.access.AccessDeniedException;
+import org.springframework.security.crypto.codec.Utf8;
 import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.access.AccessDeniedHandlerImpl;
 import org.springframework.security.web.util.UrlUtils;
@@ -119,7 +121,7 @@ public final class CsrfFilter extends OncePerRequestFilter {
 		if (actualToken == null) {
 			actualToken = request.getParameter(csrfToken.getParameterName());
 		}
-		if (!csrfToken.getToken().equals(actualToken)) {
+		if (!equalsConstantTime(csrfToken.getToken(), actualToken)) {
 			this.logger.debug(
 					LogMessage.of(() -> "Invalid CSRF token found for " + UrlUtils.buildFullRequestUrl(request)));
 			AccessDeniedException exception = (!missingToken) ? new InvalidCsrfTokenException(csrfToken, actualToken)
@@ -165,6 +167,24 @@ public final class CsrfFilter extends OncePerRequestFilter {
 		this.accessDeniedHandler = accessDeniedHandler;
 	}
 
+	/**
+	 * Constant time comparison to prevent against timing attacks.
+	 * @param expected
+	 * @param actual
+	 * @return
+	 */
+	private static boolean equalsConstantTime(String expected, String actual) {
+		byte[] expectedBytes = bytesUtf8(expected);
+		byte[] actualBytes = bytesUtf8(actual);
+		return MessageDigest.isEqual(expectedBytes, actualBytes);
+	}
+
+	private static byte[] bytesUtf8(String s) {
+		// need to check if Utf8.encode() runs in constant time (probably not).
+		// This may leak length of string.
+		return (s != null) ? Utf8.encode(s) : null;
+	}
+
 	private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
 
 		private final HashSet<String> allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));

+ 21 - 1
web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

@@ -16,6 +16,7 @@
 
 package org.springframework.security.web.server.csrf;
 
+import java.security.MessageDigest;
 import java.util.Arrays;
 import java.util.HashSet;
 import java.util.Set;
@@ -28,6 +29,7 @@ import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.http.codec.multipart.FormFieldPart;
 import org.springframework.http.server.reactive.ServerHttpRequest;
+import org.springframework.security.crypto.codec.Utf8;
 import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
 import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
 import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
@@ -139,7 +141,7 @@ public class CsrfWebFilter implements WebFilter {
 		return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
 				.switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
 				.switchIfEmpty(tokenFromMultipartData(exchange, expected))
-				.map((actual) -> actual.equals(expected.getToken()));
+				.map((actual) -> equalsConstantTime(actual, expected.getToken()));
 	}
 
 	private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
@@ -168,6 +170,24 @@ public class CsrfWebFilter implements WebFilter {
 		return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
 	}
 
+	/**
+	 * Constant time comparison to prevent against timing attacks.
+	 * @param expected
+	 * @param actual
+	 * @return
+	 */
+	private static boolean equalsConstantTime(String expected, String actual) {
+		byte[] expectedBytes = bytesUtf8(expected);
+		byte[] actualBytes = bytesUtf8(actual);
+		return MessageDigest.isEqual(expectedBytes, actualBytes);
+	}
+
+	private static byte[] bytesUtf8(String s) {
+		// need to check if Utf8.encode() runs in constant time (probably not).
+		// This may leak length of string.
+		return (s != null) ? Utf8.encode(s) : null;
+	}
+
 	private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
 		return this.csrfTokenRepository.generateToken(exchange)
 				.delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));