浏览代码

Polish Csrf Tests

Issue gh-9561
Josh Cummings 4 年之前
父节点
当前提交
4f7d529c5d

+ 14 - 13
web/src/test/java/org/springframework/security/web/csrf/CsrfFilterTests.java

@@ -17,7 +17,6 @@
 package org.springframework.security.web.csrf;
 
 import java.io.IOException;
-import java.lang.reflect.Method;
 import java.util.Arrays;
 
 import javax.servlet.FilterChain;
@@ -97,18 +96,6 @@ public class CsrfFilterTests {
 		this.response = new MockHttpServletResponse();
 	}
 
-	@Test
-	public void nullConstantTimeEquals() throws Exception {
-		Method method = CsrfFilter.class.getDeclaredMethod("equalsConstantTime", String.class, String.class);
-		method.setAccessible(true);
-		assertThat(method.invoke(CsrfFilter.class, null, null)).isEqualTo(true);
-		String expectedToken = "Hello—World";
-		String actualToken = new String("Hello—World");
-		assertThat(method.invoke(CsrfFilter.class, expectedToken, null)).isEqualTo(false);
-		assertThat(method.invoke(CsrfFilter.class, expectedToken, "hello-world")).isEqualTo(false);
-		assertThat(method.invoke(CsrfFilter.class, expectedToken, actualToken)).isEqualTo(true);
-	}
-
 	@Test
 	public void constructorNullRepository() {
 		assertThatIllegalArgumentException().isThrownBy(() -> new CsrfFilter(null));
@@ -333,6 +320,20 @@ public class CsrfFilterTests {
 		verifyZeroInteractions(repository);
 	}
 
+	// gh-9561
+	@Test
+	public void doFilterWhenTokenIsNullThenNoNullPointer() throws Exception {
+		CsrfFilter filter = createCsrfFilter(this.tokenRepository);
+		CsrfToken token = mock(CsrfToken.class);
+		given(token.getToken()).willReturn(null);
+		given(token.getHeaderName()).willReturn(this.token.getHeaderName());
+		given(token.getParameterName()).willReturn(this.token.getParameterName());
+		given(this.tokenRepository.loadToken(this.request)).willReturn(token);
+		given(this.requestMatcher.matches(this.request)).willReturn(true);
+		filter.doFilterInternal(this.request, this.response, this.filterChain);
+		assertThat(this.response.getStatus()).isEqualTo(HttpServletResponse.SC_OK);
+	}
+
 	@Test
 	public void setRequireCsrfProtectionMatcherNull() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.filter.setRequireCsrfProtectionMatcher(null));

+ 15 - 14
web/src/test/java/org/springframework/security/web/server/csrf/CsrfWebFilterTests.java

@@ -16,8 +16,6 @@
 
 package org.springframework.security.web.server.csrf;
 
-import java.lang.reflect.Method;
-
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
@@ -67,18 +65,6 @@ public class CsrfWebFilterTests {
 
 	private MockServerWebExchange post = MockServerWebExchange.from(MockServerHttpRequest.post("/"));
 
-	@Test
-	public void nullConstantTimeEquals() throws Exception {
-		Method method = CsrfWebFilter.class.getDeclaredMethod("equalsConstantTime", String.class, String.class);
-		method.setAccessible(true);
-		assertThat(method.invoke(CsrfWebFilter.class, null, null)).isEqualTo(true);
-		String expectedToken = "Hello—World";
-		String actualToken = new String("Hello—World");
-		assertThat(method.invoke(CsrfWebFilter.class, expectedToken, null)).isEqualTo(false);
-		assertThat(method.invoke(CsrfWebFilter.class, expectedToken, "hello-world")).isEqualTo(false);
-		assertThat(method.invoke(CsrfWebFilter.class, expectedToken, actualToken)).isEqualTo(true);
-	}
-
 	@Test
 	public void filterWhenGetThenSessionNotCreatedAndChainContinues() {
 		PublisherProbe<Void> chainResult = PublisherProbe.empty();
@@ -226,6 +212,21 @@ public class CsrfWebFilterTests {
 				.isForbidden();
 	}
 
+	// gh-9561
+	@Test
+	public void doFilterWhenTokenIsNullThenNoNullPointer() {
+		this.csrfFilter.setCsrfTokenRepository(this.repository);
+		CsrfToken token = mock(CsrfToken.class);
+		given(token.getToken()).willReturn(null);
+		given(token.getHeaderName()).willReturn(this.token.getHeaderName());
+		given(token.getParameterName()).willReturn(this.token.getParameterName());
+		given(this.repository.loadToken(any())).willReturn(Mono.just(token));
+		WebTestClient client = WebTestClient.bindToController(new OkController()).webFilter(this.csrfFilter).build();
+		client.post().uri("/").contentType(MediaType.APPLICATION_FORM_URLENCODED)
+				.bodyValue(this.token.getParameterName() + "=" + this.token.getToken()).exchange().expectStatus()
+				.isForbidden();
+	}
+
 	@RestController
 	static class OkController {