Sfoglia il codice sorgente

Merge branch '5.8.x'

Steve Riesenberg 2 anni fa
parent
commit
7c872cf7fd

+ 11 - 8
config/src/test/java/org/springframework/security/config/http/CsrfConfigTests.java

@@ -41,6 +41,7 @@ import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.context.junit.jupiter.SpringExtension;
@@ -301,7 +302,7 @@ public class CsrfConfigTests {
 	}
 
 	@Test
-	public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorThenOk() throws Exception {
+	public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerThenOk() throws Exception {
 		this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
 				.autowire();
 		// @formatter:off
@@ -309,25 +310,27 @@ public class CsrfConfigTests {
 				.andExpect(status().isOk())
 				.andReturn();
 		MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
-		CsrfToken csrfToken = (CsrfToken) mvcResult.getRequest().getAttribute("_csrf");
 		MockHttpServletRequestBuilder ok = post("/ok")
-				.header(csrfToken.getHeaderName(), csrfToken.getToken())
+				.with(csrf())
 				.session(session);
 		this.mvc.perform(ok).andExpect(status().isOk());
 		// @formatter:on
 	}
 
 	@Test
-	public void postWhenUsingCsrfAndXorCsrfTokenRequestProcessorWithRawTokenThenForbidden() throws Exception {
+	public void postWhenUsingCsrfAndXorCsrfTokenRequestAttributeHandlerWithRawTokenThenForbidden() throws Exception {
 		this.spring.configLocations(this.xml("WithXorCsrfTokenRequestAttributeHandler"), this.xml("shared-controllers"))
 				.autowire();
 		// @formatter:off
-		MvcResult mvcResult = this.mvc.perform(get("/ok"))
+		MvcResult mvcResult = this.mvc.perform(get("/csrf"))
 				.andExpect(status().isOk())
 				.andReturn();
-		MockHttpSession session = (MockHttpSession) mvcResult.getRequest().getSession();
+		MockHttpServletRequest request = mvcResult.getRequest();
+		MockHttpSession session = (MockHttpSession) request.getSession();
+		CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
+		CsrfToken csrfToken = repository.loadToken(request);
 		MockHttpServletRequestBuilder ok = post("/ok")
-				.with(csrf())
+				.header(csrfToken.getHeaderName(), csrfToken.getToken())
 				.session(session);
 		this.mvc.perform(ok).andExpect(status().isForbidden());
 		// @formatter:on
@@ -594,7 +597,7 @@ public class CsrfConfigTests {
 		@Override
 		public void match(MvcResult result) throws Exception {
 			MockHttpServletRequest request = result.getRequest();
-			CsrfToken token = WebTestUtils.getCsrfTokenRepository(request).loadToken(request);
+			CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 			assertThat(token).isNotNull();
 			assertThat(token.getToken()).isEqualTo(this.token.apply(result));
 		}

+ 12 - 3
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -95,6 +95,8 @@ import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
+import org.springframework.security.web.csrf.DeferredCsrfToken;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.test.web.servlet.MockMvc;
@@ -499,6 +501,10 @@ public final class SecurityMockMvcRequestPostProcessors {
 	 */
 	public static final class CsrfRequestPostProcessor implements RequestPostProcessor {
 
+		private static final byte[] INVALID_TOKEN_BYTES = new byte[] { 1, 1, 1, 96, 99, 98 };
+
+		private static final String INVALID_TOKEN_VALUE = Base64.getEncoder().encodeToString(INVALID_TOKEN_BYTES);
+
 		private boolean asHeader;
 
 		private boolean useInvalidToken;
@@ -509,14 +515,17 @@ public final class SecurityMockMvcRequestPostProcessors {
 		@Override
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
 			CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
+			CsrfTokenRequestHandler handler = WebTestUtils.getCsrfTokenRequestHandler(request);
 			if (!(repository instanceof TestCsrfTokenRepository)) {
 				repository = new TestCsrfTokenRepository(new HttpSessionCsrfTokenRepository());
 				WebTestUtils.setCsrfTokenRepository(request, repository);
 			}
 			TestCsrfTokenRepository.enable(request);
-			CsrfToken token = repository.generateToken(request);
-			repository.saveToken(token, request, new MockHttpServletResponse());
-			String tokenValue = this.useInvalidToken ? "invalid" + token.getToken() : token.getToken();
+			MockHttpServletResponse response = new MockHttpServletResponse();
+			DeferredCsrfToken deferredCsrfToken = repository.loadDeferredToken(request, response);
+			handler.handle(request, response, deferredCsrfToken::get);
+			CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
+			String tokenValue = this.useInvalidToken ? INVALID_TOKEN_VALUE : token.getToken();
 			if (this.asHeader) {
 				request.addHeader(token.getHeaderName(), tokenValue);
 			}

+ 21 - 0
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@@ -31,7 +31,9 @@ import org.springframework.security.web.context.SecurityContextPersistenceFilter
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
+import org.springframework.security.web.csrf.CsrfTokenRequestHandler;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
+import org.springframework.security.web.csrf.XorCsrfTokenRequestAttributeHandler;
 import org.springframework.test.util.ReflectionTestUtils;
 import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.context.support.WebApplicationContextUtils;
@@ -48,6 +50,8 @@ public abstract class WebTestUtils {
 
 	private static final CsrfTokenRepository DEFAULT_TOKEN_REPO = new HttpSessionCsrfTokenRepository();
 
+	private static final CsrfTokenRequestHandler DEFAULT_CSRF_HANDLER = new XorCsrfTokenRequestAttributeHandler();
+
 	private WebTestUtils() {
 	}
 
@@ -107,6 +111,23 @@ public abstract class WebTestUtils {
 		return (CsrfTokenRepository) ReflectionTestUtils.getField(filter, "tokenRepository");
 	}
 
+	/**
+	 * Gets the {@link CsrfTokenRequestHandler} for the specified
+	 * {@link HttpServletRequest}. If one is not found, the default
+	 * {@link XorCsrfTokenRequestAttributeHandler} is used.
+	 * @param request the {@link HttpServletRequest} to obtain the
+	 * {@link CsrfTokenRequestHandler}
+	 * @return the {@link CsrfTokenRequestHandler} for the specified
+	 * {@link HttpServletRequest}
+	 */
+	public static CsrfTokenRequestHandler getCsrfTokenRequestHandler(HttpServletRequest request) {
+		CsrfFilter filter = findFilter(request, CsrfFilter.class);
+		if (filter == null) {
+			return DEFAULT_CSRF_HANDLER;
+		}
+		return (CsrfTokenRequestHandler) ReflectionTestUtils.getField(filter, "requestHandler");
+	}
+
 	/**
 	 * Sets the {@link CsrfTokenRepository} for the specified {@link HttpServletRequest}.
 	 * @param request the {@link HttpServletRequest} to obtain the

+ 3 - 7
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

@@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockServletContext;
-import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
@@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 	@Test
 	public void defaults() {
 		MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
-		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 		assertThat(request.getParameter("username")).isEqualTo("user");
 		assertThat(request.getParameter("password")).isEqualTo("password");
 		assertThat(request.getMethod()).isEqualTo("POST");
@@ -66,8 +64,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 	public void custom() {
 		MockHttpServletRequest request = formLogin("/login").user("username", "admin").password("password", "secret")
 				.buildRequest(this.servletContext);
-		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getMethod()).isEqualTo("POST");
@@ -79,8 +76,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 	public void customWithUriVars() {
 		MockHttpServletRequest request = formLogin().loginProcessingUrl("/uri-login/{var1}/{var2}", "val1", "val2")
 				.user("username", "admin").password("password", "secret").buildRequest(this.servletContext);
-		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getMethod()).isEqualTo("POST");

+ 3 - 7
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

@@ -25,7 +25,6 @@ import org.springframework.http.HttpMethod;
 import org.springframework.http.MediaType;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockServletContext;
-import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
 import org.springframework.security.web.csrf.CsrfToken;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
@@ -52,8 +51,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	@Test
 	public void defaults() {
 		MockHttpServletRequest request = logout().buildRequest(this.servletContext);
-		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/logout");
@@ -62,8 +60,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	@Test
 	public void custom() {
 		MockHttpServletRequest request = logout("/admin/logout").buildRequest(this.servletContext);
-		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
@@ -73,8 +70,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	public void customWithUriVars() {
 		MockHttpServletRequest request = logout().logoutUrl("/uri-logout/{var1}/{var2}", "val1", "val2")
 				.buildRequest(this.servletContext);
-		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfToken.class.getName());
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/uri-logout/val1/val2");