|
@@ -16,26 +16,28 @@
|
|
|
|
|
|
package org.springframework.security.web.server.csrf;
|
|
package org.springframework.security.web.server.csrf;
|
|
|
|
|
|
|
|
+import java.security.MessageDigest;
|
|
|
|
+import java.util.Arrays;
|
|
|
|
+import java.util.HashSet;
|
|
|
|
+import java.util.Set;
|
|
|
|
+
|
|
|
|
+import reactor.core.publisher.Mono;
|
|
|
|
+
|
|
import org.springframework.http.HttpHeaders;
|
|
import org.springframework.http.HttpHeaders;
|
|
import org.springframework.http.HttpMethod;
|
|
import org.springframework.http.HttpMethod;
|
|
import org.springframework.http.HttpStatus;
|
|
import org.springframework.http.HttpStatus;
|
|
import org.springframework.http.MediaType;
|
|
import org.springframework.http.MediaType;
|
|
import org.springframework.http.codec.multipart.FormFieldPart;
|
|
import org.springframework.http.codec.multipart.FormFieldPart;
|
|
import org.springframework.http.server.reactive.ServerHttpRequest;
|
|
import org.springframework.http.server.reactive.ServerHttpRequest;
|
|
|
|
+import org.springframework.security.crypto.codec.Utf8;
|
|
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
|
|
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
|
|
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
|
|
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
|
import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
|
|
|
|
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher.MatchResult;
|
|
import org.springframework.util.Assert;
|
|
import org.springframework.util.Assert;
|
|
import org.springframework.web.server.ServerWebExchange;
|
|
import org.springframework.web.server.ServerWebExchange;
|
|
import org.springframework.web.server.WebFilter;
|
|
import org.springframework.web.server.WebFilter;
|
|
import org.springframework.web.server.WebFilterChain;
|
|
import org.springframework.web.server.WebFilterChain;
|
|
-import reactor.core.publisher.Mono;
|
|
|
|
-
|
|
|
|
-import java.util.Arrays;
|
|
|
|
-import java.util.HashSet;
|
|
|
|
-import java.util.Set;
|
|
|
|
-
|
|
|
|
-import static java.lang.Boolean.TRUE;
|
|
|
|
|
|
|
|
/**
|
|
/**
|
|
* <p>
|
|
* <p>
|
|
@@ -64,13 +66,14 @@ import static java.lang.Boolean.TRUE;
|
|
* @since 5.0
|
|
* @since 5.0
|
|
*/
|
|
*/
|
|
public class CsrfWebFilter implements WebFilter {
|
|
public class CsrfWebFilter implements WebFilter {
|
|
|
|
+
|
|
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();
|
|
public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();
|
|
|
|
|
|
/**
|
|
/**
|
|
- * The attribute name to use when marking a given request as one that should not be filtered.
|
|
|
|
|
|
+ * The attribute name to use when marking a given request as one that should not be
|
|
|
|
+ * filtered.
|
|
*
|
|
*
|
|
- * To use, set the attribute on your {@link ServerWebExchange}:
|
|
|
|
- * <pre>
|
|
|
|
|
|
+ * To use, set the attribute on your {@link ServerWebExchange}: <pre>
|
|
* CsrfWebFilter.skipExchange(exchange);
|
|
* CsrfWebFilter.skipExchange(exchange);
|
|
* </pre>
|
|
* </pre>
|
|
*/
|
|
*/
|
|
@@ -80,32 +83,31 @@ public class CsrfWebFilter implements WebFilter {
|
|
|
|
|
|
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
|
|
private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
|
|
|
|
|
|
- private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(HttpStatus.FORBIDDEN);
|
|
|
|
|
|
+ private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
|
|
|
|
+ HttpStatus.FORBIDDEN);
|
|
|
|
|
|
private boolean isTokenFromMultipartDataEnabled;
|
|
private boolean isTokenFromMultipartDataEnabled;
|
|
|
|
|
|
- public void setAccessDeniedHandler(
|
|
|
|
- ServerAccessDeniedHandler accessDeniedHandler) {
|
|
|
|
|
|
+ public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
|
|
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
|
|
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
|
|
this.accessDeniedHandler = accessDeniedHandler;
|
|
this.accessDeniedHandler = accessDeniedHandler;
|
|
}
|
|
}
|
|
|
|
|
|
- public void setCsrfTokenRepository(
|
|
|
|
- ServerCsrfTokenRepository csrfTokenRepository) {
|
|
|
|
|
|
+ public void setCsrfTokenRepository(ServerCsrfTokenRepository csrfTokenRepository) {
|
|
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
|
|
Assert.notNull(csrfTokenRepository, "csrfTokenRepository cannot be null");
|
|
this.csrfTokenRepository = csrfTokenRepository;
|
|
this.csrfTokenRepository = csrfTokenRepository;
|
|
}
|
|
}
|
|
|
|
|
|
- public void setRequireCsrfProtectionMatcher(
|
|
|
|
- ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
|
|
|
|
|
|
+ public void setRequireCsrfProtectionMatcher(ServerWebExchangeMatcher requireCsrfProtectionMatcher) {
|
|
Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
|
|
Assert.notNull(requireCsrfProtectionMatcher, "requireCsrfProtectionMatcher cannot be null");
|
|
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
|
|
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
|
|
}
|
|
}
|
|
|
|
|
|
/**
|
|
/**
|
|
- * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token from the body of multipart
|
|
|
|
- * data requests.
|
|
|
|
- * @param tokenFromMultipartDataEnabled true if should read from multipart form body, else false. Default is false
|
|
|
|
|
|
+ * Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token
|
|
|
|
+ * from the body of multipart data requests.
|
|
|
|
+ * @param tokenFromMultipartDataEnabled true if should read from multipart form body,
|
|
|
|
+ * else false. Default is false
|
|
*/
|
|
*/
|
|
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
|
|
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
|
|
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
|
|
this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
|
|
@@ -113,38 +115,33 @@ public class CsrfWebFilter implements WebFilter {
|
|
|
|
|
|
@Override
|
|
@Override
|
|
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
|
|
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
|
|
- if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
|
|
|
|
|
|
+ if (Boolean.TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
|
|
return chain.filter(exchange).then(Mono.empty());
|
|
return chain.filter(exchange).then(Mono.empty());
|
|
}
|
|
}
|
|
-
|
|
|
|
- return this.requireCsrfProtectionMatcher.matches(exchange)
|
|
|
|
- .filter( matchResult -> matchResult.isMatch())
|
|
|
|
- .filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
|
|
|
|
- .flatMap(m -> validateToken(exchange))
|
|
|
|
- .flatMap(m -> continueFilterChain(exchange, chain))
|
|
|
|
- .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
|
|
|
|
- .onErrorResume(CsrfException.class, e -> this.accessDeniedHandler
|
|
|
|
- .handle(exchange, e));
|
|
|
|
|
|
+ return this.requireCsrfProtectionMatcher.matches(exchange).filter(MatchResult::isMatch)
|
|
|
|
+ .filter((matchResult) -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
|
|
|
|
+ .flatMap((m) -> validateToken(exchange)).flatMap((m) -> continueFilterChain(exchange, chain))
|
|
|
|
+ .switchIfEmpty(continueFilterChain(exchange, chain).then(Mono.empty()))
|
|
|
|
+ .onErrorResume(CsrfException.class, (ex) -> this.accessDeniedHandler.handle(exchange, ex));
|
|
}
|
|
}
|
|
|
|
|
|
public static void skipExchange(ServerWebExchange exchange) {
|
|
public static void skipExchange(ServerWebExchange exchange) {
|
|
- exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE);
|
|
|
|
|
|
+ exchange.getAttributes().put(SHOULD_NOT_FILTER, Boolean.TRUE);
|
|
}
|
|
}
|
|
|
|
|
|
private Mono<Void> validateToken(ServerWebExchange exchange) {
|
|
private Mono<Void> validateToken(ServerWebExchange exchange) {
|
|
return this.csrfTokenRepository.loadToken(exchange)
|
|
return this.csrfTokenRepository.loadToken(exchange)
|
|
- .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found"))))
|
|
|
|
- .filterWhen(expected -> containsValidCsrfToken(exchange, expected))
|
|
|
|
- .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token"))))
|
|
|
|
- .then();
|
|
|
|
|
|
+ .switchIfEmpty(
|
|
|
|
+ Mono.defer(() -> Mono.error(new CsrfException("An expected CSRF token cannot be found"))))
|
|
|
|
+ .filterWhen((expected) -> containsValidCsrfToken(exchange, expected))
|
|
|
|
+ .switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("Invalid CSRF Token")))).then();
|
|
}
|
|
}
|
|
|
|
|
|
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
|
|
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
|
|
- return exchange.getFormData()
|
|
|
|
- .flatMap(data -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
|
|
|
|
- .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
|
|
|
|
- .switchIfEmpty(tokenFromMultipartData(exchange, expected))
|
|
|
|
- .map(actual -> actual.equals(expected.getToken()));
|
|
|
|
|
|
+ return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
|
|
|
|
+ .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
|
|
|
|
+ .switchIfEmpty(tokenFromMultipartData(exchange, expected))
|
|
|
|
+ .map((actual) -> equalsConstantTime(actual, expected.getToken()));
|
|
}
|
|
}
|
|
|
|
|
|
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
|
|
private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
|
|
@@ -157,14 +154,12 @@ public class CsrfWebFilter implements WebFilter {
|
|
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
|
|
if (!contentType.includes(MediaType.MULTIPART_FORM_DATA)) {
|
|
return Mono.empty();
|
|
return Mono.empty();
|
|
}
|
|
}
|
|
- return exchange.getMultipartData()
|
|
|
|
- .map(d -> d.getFirst(expected.getParameterName()))
|
|
|
|
- .cast(FormFieldPart.class)
|
|
|
|
- .map(FormFieldPart::value);
|
|
|
|
|
|
+ return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
|
|
|
|
+ .map(FormFieldPart::value);
|
|
}
|
|
}
|
|
|
|
|
|
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
|
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
|
- return Mono.defer(() ->{
|
|
|
|
|
|
+ return Mono.defer(() -> {
|
|
Mono<CsrfToken> csrfToken = csrfToken(exchange);
|
|
Mono<CsrfToken> csrfToken = csrfToken(exchange);
|
|
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
|
|
exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
|
|
return chain.filter(exchange);
|
|
return chain.filter(exchange);
|
|
@@ -172,26 +167,44 @@ public class CsrfWebFilter implements WebFilter {
|
|
}
|
|
}
|
|
|
|
|
|
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
|
|
private Mono<CsrfToken> csrfToken(ServerWebExchange exchange) {
|
|
- return this.csrfTokenRepository.loadToken(exchange)
|
|
|
|
- .switchIfEmpty(generateToken(exchange));
|
|
|
|
|
|
+ return this.csrfTokenRepository.loadToken(exchange).switchIfEmpty(generateToken(exchange));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /**
|
|
|
|
+ * Constant time comparison to prevent against timing attacks.
|
|
|
|
+ * @param expected
|
|
|
|
+ * @param actual
|
|
|
|
+ * @return
|
|
|
|
+ */
|
|
|
|
+ private static boolean equalsConstantTime(String expected, String actual) {
|
|
|
|
+ byte[] expectedBytes = bytesUtf8(expected);
|
|
|
|
+ byte[] actualBytes = bytesUtf8(actual);
|
|
|
|
+ return MessageDigest.isEqual(expectedBytes, actualBytes);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private static byte[] bytesUtf8(String s) {
|
|
|
|
+ // need to check if Utf8.encode() runs in constant time (probably not).
|
|
|
|
+ // This may leak length of string.
|
|
|
|
+ return (s != null) ? Utf8.encode(s) : null;
|
|
}
|
|
}
|
|
|
|
|
|
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
|
private Mono<CsrfToken> generateToken(ServerWebExchange exchange) {
|
|
return this.csrfTokenRepository.generateToken(exchange)
|
|
return this.csrfTokenRepository.generateToken(exchange)
|
|
- .delayUntil(token -> this.csrfTokenRepository.saveToken(exchange, token));
|
|
|
|
|
|
+ .delayUntil((token) -> this.csrfTokenRepository.saveToken(exchange, token));
|
|
}
|
|
}
|
|
|
|
|
|
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
|
|
private static class DefaultRequireCsrfProtectionMatcher implements ServerWebExchangeMatcher {
|
|
|
|
+
|
|
private static final Set<HttpMethod> ALLOWED_METHODS = new HashSet<>(
|
|
private static final Set<HttpMethod> ALLOWED_METHODS = new HashSet<>(
|
|
- Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));
|
|
|
|
|
|
+ Arrays.asList(HttpMethod.GET, HttpMethod.HEAD, HttpMethod.TRACE, HttpMethod.OPTIONS));
|
|
|
|
|
|
@Override
|
|
@Override
|
|
public Mono<MatchResult> matches(ServerWebExchange exchange) {
|
|
public Mono<MatchResult> matches(ServerWebExchange exchange) {
|
|
- return Mono.just(exchange.getRequest())
|
|
|
|
- .flatMap(r -> Mono.justOrEmpty(r.getMethod()))
|
|
|
|
- .filter(m -> ALLOWED_METHODS.contains(m))
|
|
|
|
- .flatMap(m -> MatchResult.notMatch())
|
|
|
|
- .switchIfEmpty(MatchResult.match());
|
|
|
|
|
|
+ return Mono.just(exchange.getRequest()).flatMap((r) -> Mono.justOrEmpty(r.getMethod()))
|
|
|
|
+ .filter(ALLOWED_METHODS::contains).flatMap((m) -> MatchResult.notMatch())
|
|
|
|
+ .switchIfEmpty(MatchResult.match());
|
|
}
|
|
}
|
|
|
|
+
|
|
}
|
|
}
|
|
|
|
+
|
|
}
|
|
}
|