|
@@ -1,5 +1,5 @@
|
|
/*
|
|
/*
|
|
- * Copyright 2002-2021 the original author or authors.
|
|
|
|
|
|
+ * Copyright 2002-2022 the original author or authors.
|
|
*
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -23,12 +23,8 @@ import java.util.Set;
|
|
|
|
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.core.publisher.Mono;
|
|
|
|
|
|
-import org.springframework.http.HttpHeaders;
|
|
|
|
import org.springframework.http.HttpMethod;
|
|
import org.springframework.http.HttpMethod;
|
|
import org.springframework.http.HttpStatus;
|
|
import org.springframework.http.HttpStatus;
|
|
-import org.springframework.http.MediaType;
|
|
|
|
-import org.springframework.http.codec.multipart.FormFieldPart;
|
|
|
|
-import org.springframework.http.server.reactive.ServerHttpRequest;
|
|
|
|
import org.springframework.security.crypto.codec.Utf8;
|
|
import org.springframework.security.crypto.codec.Utf8;
|
|
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
|
|
import org.springframework.security.web.server.authorization.HttpStatusServerAccessDeniedHandler;
|
|
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
|
|
import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler;
|
|
@@ -63,6 +59,7 @@ import org.springframework.web.server.WebFilterChain;
|
|
*
|
|
*
|
|
* @author Rob Winch
|
|
* @author Rob Winch
|
|
* @author Parikshit Dutta
|
|
* @author Parikshit Dutta
|
|
|
|
+ * @author Steve Riesenberg
|
|
* @since 5.0
|
|
* @since 5.0
|
|
*/
|
|
*/
|
|
public class CsrfWebFilter implements WebFilter {
|
|
public class CsrfWebFilter implements WebFilter {
|
|
@@ -86,7 +83,7 @@ public class CsrfWebFilter implements WebFilter {
|
|
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
|
|
private ServerAccessDeniedHandler accessDeniedHandler = new HttpStatusServerAccessDeniedHandler(
|
|
HttpStatus.FORBIDDEN);
|
|
HttpStatus.FORBIDDEN);
|
|
|
|
|
|
- private boolean isTokenFromMultipartDataEnabled;
|
|
|
|
|
|
+ private ServerCsrfTokenRequestHandler requestHandler = new ServerCsrfTokenRequestAttributeHandler();
|
|
|
|
|
|
public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
|
|
public void setAccessDeniedHandler(ServerAccessDeniedHandler accessDeniedHandler) {
|
|
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
|
|
Assert.notNull(accessDeniedHandler, "accessDeniedHandler");
|
|
@@ -103,14 +100,34 @@ public class CsrfWebFilter implements WebFilter {
|
|
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
|
|
this.requireCsrfProtectionMatcher = requireCsrfProtectionMatcher;
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ /**
|
|
|
|
+ * Specifies a {@link ServerCsrfTokenRequestHandler} that is used to make the
|
|
|
|
+ * {@code CsrfToken} available as an exchange attribute.
|
|
|
|
+ * <p>
|
|
|
|
+ * The default is {@link ServerCsrfTokenRequestAttributeHandler}.
|
|
|
|
+ * @param requestHandler the {@link ServerCsrfTokenRequestHandler} to use
|
|
|
|
+ * @since 5.8
|
|
|
|
+ */
|
|
|
|
+ public void setRequestHandler(ServerCsrfTokenRequestHandler requestHandler) {
|
|
|
|
+ Assert.notNull(requestHandler, "requestHandler cannot be null");
|
|
|
|
+ this.requestHandler = requestHandler;
|
|
|
|
+ }
|
|
|
|
+
|
|
/**
|
|
/**
|
|
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token
|
|
* Specifies if the {@code CsrfWebFilter} should try to resolve the actual CSRF token
|
|
* from the body of multipart data requests.
|
|
* from the body of multipart data requests.
|
|
* @param tokenFromMultipartDataEnabled true if should read from multipart form body,
|
|
* @param tokenFromMultipartDataEnabled true if should read from multipart form body,
|
|
* else false. Default is false
|
|
* else false. Default is false
|
|
|
|
+ * @deprecated Use
|
|
|
|
+ * {@link ServerCsrfTokenRequestAttributeHandler#setTokenFromMultipartDataEnabled(boolean)}
|
|
|
|
+ * instead
|
|
*/
|
|
*/
|
|
|
|
+ @Deprecated
|
|
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
|
|
public void setTokenFromMultipartDataEnabled(boolean tokenFromMultipartDataEnabled) {
|
|
- this.isTokenFromMultipartDataEnabled = tokenFromMultipartDataEnabled;
|
|
|
|
|
|
+ if (this.requestHandler instanceof ServerCsrfTokenRequestAttributeHandler) {
|
|
|
|
+ ((ServerCsrfTokenRequestAttributeHandler) this.requestHandler)
|
|
|
|
+ .setTokenFromMultipartDataEnabled(tokenFromMultipartDataEnabled);
|
|
|
|
+ }
|
|
}
|
|
}
|
|
|
|
|
|
@Override
|
|
@Override
|
|
@@ -138,30 +155,14 @@ public class CsrfWebFilter implements WebFilter {
|
|
}
|
|
}
|
|
|
|
|
|
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
|
|
private Mono<Boolean> containsValidCsrfToken(ServerWebExchange exchange, CsrfToken expected) {
|
|
- return exchange.getFormData().flatMap((data) -> Mono.justOrEmpty(data.getFirst(expected.getParameterName())))
|
|
|
|
- .switchIfEmpty(Mono.justOrEmpty(exchange.getRequest().getHeaders().getFirst(expected.getHeaderName())))
|
|
|
|
- .switchIfEmpty(tokenFromMultipartData(exchange, expected))
|
|
|
|
|
|
+ return this.requestHandler.resolveCsrfTokenValue(exchange, expected)
|
|
.map((actual) -> equalsConstantTime(actual, expected.getToken()));
|
|
.map((actual) -> equalsConstantTime(actual, expected.getToken()));
|
|
}
|
|
}
|
|
|
|
|
|
- private Mono<String> tokenFromMultipartData(ServerWebExchange exchange, CsrfToken expected) {
|
|
|
|
- if (!this.isTokenFromMultipartDataEnabled) {
|
|
|
|
- return Mono.empty();
|
|
|
|
- }
|
|
|
|
- ServerHttpRequest request = exchange.getRequest();
|
|
|
|
- HttpHeaders headers = request.getHeaders();
|
|
|
|
- MediaType contentType = headers.getContentType();
|
|
|
|
- if (!MediaType.MULTIPART_FORM_DATA.isCompatibleWith(contentType)) {
|
|
|
|
- return Mono.empty();
|
|
|
|
- }
|
|
|
|
- return exchange.getMultipartData().map((d) -> d.getFirst(expected.getParameterName())).cast(FormFieldPart.class)
|
|
|
|
- .map(FormFieldPart::value);
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
|
private Mono<Void> continueFilterChain(ServerWebExchange exchange, WebFilterChain chain) {
|
|
return Mono.defer(() -> {
|
|
return Mono.defer(() -> {
|
|
Mono<CsrfToken> csrfToken = csrfToken(exchange);
|
|
Mono<CsrfToken> csrfToken = csrfToken(exchange);
|
|
- exchange.getAttributes().put(CsrfToken.class.getName(), csrfToken);
|
|
|
|
|
|
+ this.requestHandler.handle(exchange, csrfToken);
|
|
return chain.filter(exchange);
|
|
return chain.filter(exchange);
|
|
});
|
|
});
|
|
}
|
|
}
|