Browse Source

Add Request-level CSRF Skip

Fixes gh-7367
Josh Cummings 6 năm trước cách đây
mục cha
commit
aa12748c9b

+ 21 - 0
web/src/main/java/org/springframework/security/web/csrf/CsrfFilter.java

@@ -35,6 +35,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.web.filter.OncePerRequestFilter;
 
+import static java.lang.Boolean.TRUE;
+
 /**
  * <p>
  * Applies
@@ -63,6 +65,16 @@ public final class CsrfFilter extends OncePerRequestFilter {
 	 */
 	public static final RequestMatcher DEFAULT_CSRF_MATCHER = new DefaultRequiresCsrfMatcher();
 
+	/**
+	 * The attribute name to use when marking a given request as one that should not be filtered.
+	 *
+	 * To use, set the attribute on your {@link HttpServletRequest}:
+	 * <pre>
+	 * 	CsrfFilter.skipRequest(request);
+	 * </pre>
+	 */
+	private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfFilter.class.getName();
+
 	private final Log logger = LogFactory.getLog(getClass());
 	private final CsrfTokenRepository tokenRepository;
 	private RequestMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
@@ -73,6 +85,11 @@ public final class CsrfFilter extends OncePerRequestFilter {
 		this.tokenRepository = csrfTokenRepository;
 	}
 
+	@Override
+	protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
+		return TRUE.equals(request.getAttribute(SHOULD_NOT_FILTER));
+	}
+
 	/*
 	 * (non-Javadoc)
 	 *
@@ -124,6 +141,10 @@ public final class CsrfFilter extends OncePerRequestFilter {
 		filterChain.doFilter(request, response);
 	}
 
+	public static void skipRequest(HttpServletRequest request) {
+		request.setAttribute(SHOULD_NOT_FILTER, TRUE);
+	}
+
 	/**
 	 * Specifies a {@link RequestMatcher} that is used to determine if CSRF protection
 	 * should be applied. If the {@link RequestMatcher} returns true for a given request,

+ 20 - 0
web/src/main/java/org/springframework/security/web/server/csrf/CsrfWebFilter.java

@@ -32,6 +32,8 @@ import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
 
+import static java.lang.Boolean.TRUE;
+
 /**
  * <p>
  * Applies
@@ -60,6 +62,16 @@ import org.springframework.web.server.WebFilterChain;
 public class CsrfWebFilter implements WebFilter {
 	public static final ServerWebExchangeMatcher DEFAULT_CSRF_MATCHER = new DefaultRequireCsrfProtectionMatcher();
 
+	/**
+	 * The attribute name to use when marking a given request as one that should not be filtered.
+	 *
+	 * To use, set the attribute on your {@link ServerWebExchange}:
+	 * <pre>
+	 * 	CsrfWebFilter.skipExchange(exchange);
+	 * </pre>
+	 */
+	private static final String SHOULD_NOT_FILTER = "SHOULD_NOT_FILTER" + CsrfWebFilter.class.getName();
+
 	private ServerWebExchangeMatcher requireCsrfProtectionMatcher = DEFAULT_CSRF_MATCHER;
 
 	private ServerCsrfTokenRepository csrfTokenRepository = new WebSessionServerCsrfTokenRepository();
@@ -86,6 +98,10 @@ public class CsrfWebFilter implements WebFilter {
 
 	@Override
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
+		if (TRUE.equals(exchange.getAttribute(SHOULD_NOT_FILTER))) {
+			return chain.filter(exchange).then(Mono.empty());
+		}
+
 		return this.requireCsrfProtectionMatcher.matches(exchange)
 			.filter( matchResult -> matchResult.isMatch())
 			.filter( matchResult -> !exchange.getAttributes().containsKey(CsrfToken.class.getName()))
@@ -96,6 +112,10 @@ public class CsrfWebFilter implements WebFilter {
 				.handle(exchange, e));
 	}
 
+	public static void skipExchange(ServerWebExchange exchange) {
+		exchange.getAttributes().put(SHOULD_NOT_FILTER, TRUE);
+	}
+
 	private Mono<Void> validateToken(ServerWebExchange exchange) {
 		return this.csrfTokenRepository.loadToken(exchange)
 			.switchIfEmpty(Mono.defer(() -> Mono.error(new CsrfException("CSRF Token has been associated to this client"))))

+ 19 - 0
web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

@@ -31,6 +31,7 @@ import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 
+import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.security.web.access.AccessDeniedHandler;
@@ -39,6 +40,8 @@ import org.springframework.security.web.util.matcher.RequestMatcher;
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.Mockito.lenient;
+import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
@@ -390,6 +393,22 @@ public class CsrfFilterTests {
 		verifyZeroInteractions(this.filterChain);
 	}
 
+	@Test
+	public void doFilterWhenSkipRequestInvokedThenSkips()
+			throws Exception {
+
+		CsrfTokenRepository repository = mock(CsrfTokenRepository.class);
+		CsrfFilter filter = new CsrfFilter(repository);
+
+		lenient().when(repository.loadToken(any(HttpServletRequest.class))).thenReturn(this.token);
+
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		CsrfFilter.skipRequest(request);
+		filter.doFilter(request, new MockHttpServletResponse(), new MockFilterChain());
+
+		verifyZeroInteractions(repository);
+	}
+
 	@Test(expected = IllegalArgumentException.class)
 	public void setRequireCsrfProtectionMatcherNull() {
 		this.filter.setRequireCsrfProtectionMatcher(null);

+ 29 - 9
web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

@@ -20,19 +20,24 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
+import reactor.core.publisher.Mono;
+import reactor.test.StepVerifier;
+import reactor.test.publisher.PublisherProbe;
+
 import org.springframework.http.HttpStatus;
 import org.springframework.http.MediaType;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher;
 import org.springframework.web.server.WebFilterChain;
 import org.springframework.web.server.WebSession;
-import reactor.core.publisher.Mono;
-import reactor.test.StepVerifier;
-import reactor.test.publisher.PublisherProbe;
 
 import static org.assertj.core.api.AssertionsForInterfaceTypes.assertThat;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyZeroInteractions;
 import static org.mockito.Mockito.when;
+import static org.springframework.mock.web.server.MockServerWebExchange.from;
 
 /**
  * @author Rob Winch
@@ -49,10 +54,10 @@ public class CsrfWebFilterTests {
 
 	private CsrfWebFilter csrfFilter = new CsrfWebFilter();
 
-	private MockServerWebExchange get = MockServerWebExchange.from(
+	private MockServerWebExchange get = from(
 		MockServerHttpRequest.get("/"));
 
-	private MockServerWebExchange post = MockServerWebExchange.from(
+	private MockServerWebExchange post = from(
 		MockServerHttpRequest.post("/"));
 
 	@Test
@@ -104,7 +109,7 @@ public class CsrfWebFilterTests {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		when(this.repository.loadToken(any()))
 			.thenReturn(Mono.just(this.token));
-		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+		this.post = from(MockServerHttpRequest.post("/")
 			.body(this.token.getParameterName() + "="+this.token.getToken()+"INVALID"));
 
 		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@@ -125,7 +130,7 @@ public class CsrfWebFilterTests {
 			.thenReturn(Mono.just(this.token));
 		when(this.repository.generateToken(any()))
 			.thenReturn(Mono.just(this.token));
-		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+		this.post = from(MockServerHttpRequest.post("/")
 			.contentType(MediaType.APPLICATION_FORM_URLENCODED)
 			.body(this.token.getParameterName() + "="+this.token.getToken()));
 
@@ -142,7 +147,7 @@ public class CsrfWebFilterTests {
 		this.csrfFilter.setCsrfTokenRepository(this.repository);
 		when(this.repository.loadToken(any()))
 			.thenReturn(Mono.just(this.token));
-		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+		this.post = from(MockServerHttpRequest.post("/")
 			.header(this.token.getHeaderName(), this.token.getToken()+"INVALID"));
 
 		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@@ -163,7 +168,7 @@ public class CsrfWebFilterTests {
 			.thenReturn(Mono.just(this.token));
 		when(this.repository.generateToken(any()))
 			.thenReturn(Mono.just(this.token));
-		this.post = MockServerWebExchange.from(MockServerHttpRequest.post("/")
+		this.post = from(MockServerHttpRequest.post("/")
 			.header(this.token.getHeaderName(), this.token.getToken()));
 
 		Mono<Void> result = this.csrfFilter.filter(this.post, this.chain);
@@ -173,4 +178,19 @@ public class CsrfWebFilterTests {
 
 		chainResult.assertWasSubscribed();
 	}
+
+	@Test
+	public void doFilterWhenSkipExchangeInvokedThenSkips() {
+		PublisherProbe<Void> chainResult = PublisherProbe.empty();
+		when(this.chain.filter(any())).thenReturn(chainResult.mono());
+
+		ServerWebExchangeMatcher matcher = mock(ServerWebExchangeMatcher.class);
+		this.csrfFilter.setRequireCsrfProtectionMatcher(matcher);
+
+		MockServerWebExchange exchange = from(MockServerHttpRequest.post("/post").build());
+		CsrfWebFilter.skipExchange(exchange);
+		this.csrfFilter.filter(exchange, this.chain).block();
+
+		verifyZeroInteractions(matcher);
+	}
 }