浏览代码

Fix csrf() when used then not used

Previously if csrf() was used and subsequently not used, the
TestCsrfTokenRepository was still used. This makes it difficult to test
the actual CsrfTokenRepository implementation.

Now the TestCsrfTokenRepository is only used if explicitly enabled.

Fixes gh-4016
Rob Winch 9 年之前
父节点
当前提交
a93fb1e0e7

+ 25 - 4
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -318,13 +318,13 @@ public final class SecurityMockMvcRequestPostProcessors {
 		 */
 		 */
 		@Override
 		@Override
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
-
 			CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
 			CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
 			if (!(repository instanceof TestCsrfTokenRepository)) {
 			if (!(repository instanceof TestCsrfTokenRepository)) {
 				repository = new TestCsrfTokenRepository(
 				repository = new TestCsrfTokenRepository(
 						new HttpSessionCsrfTokenRepository());
 						new HttpSessionCsrfTokenRepository());
 				WebTestUtils.setCsrfTokenRepository(request, repository);
 				WebTestUtils.setCsrfTokenRepository(request, repository);
 			}
 			}
+			TestCsrfTokenRepository.enable(request);
 			CsrfToken token = repository.generateToken(request);
 			CsrfToken token = repository.generateToken(request);
 			repository.saveToken(token, request, new MockHttpServletResponse());
 			repository.saveToken(token, request, new MockHttpServletResponse());
 			String tokenValue = this.useInvalidToken ? "invalid" + token.getToken()
 			String tokenValue = this.useInvalidToken ? "invalid" + token.getToken()
@@ -367,9 +367,12 @@ public final class SecurityMockMvcRequestPostProcessors {
 		 * request is wrapped (i.e. Spring Session is in use).
 		 * request is wrapped (i.e. Spring Session is in use).
 		 */
 		 */
 		static class TestCsrfTokenRepository implements CsrfTokenRepository {
 		static class TestCsrfTokenRepository implements CsrfTokenRepository {
-			final static String ATTR_NAME = TestCsrfTokenRepository.class.getName()
+			final static String TOKEN_ATTR_NAME = TestCsrfTokenRepository.class.getName()
 					.concat(".TOKEN");
 					.concat(".TOKEN");
 
 
+			final static String ENABLED_ATTR_NAME = TestCsrfTokenRepository.class
+					.getName().concat(".ENABLED");
+
 			private final CsrfTokenRepository delegate;
 			private final CsrfTokenRepository delegate;
 
 
 			private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
 			private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
@@ -384,12 +387,30 @@ public final class SecurityMockMvcRequestPostProcessors {
 			@Override
 			@Override
 			public void saveToken(CsrfToken token, HttpServletRequest request,
 			public void saveToken(CsrfToken token, HttpServletRequest request,
 					HttpServletResponse response) {
 					HttpServletResponse response) {
-				request.setAttribute(ATTR_NAME, token);
+				if (isEnabled(request)) {
+					request.setAttribute(TOKEN_ATTR_NAME, token);
+				}
+				else {
+					this.delegate.saveToken(token, request, response);
+				}
 			}
 			}
 
 
 			@Override
 			@Override
 			public CsrfToken loadToken(HttpServletRequest request) {
 			public CsrfToken loadToken(HttpServletRequest request) {
-				return (CsrfToken) request.getAttribute(ATTR_NAME);
+				if (isEnabled(request)) {
+					return (CsrfToken) request.getAttribute(TOKEN_ATTR_NAME);
+				}
+				else {
+					return this.delegate.loadToken(request);
+				}
+			}
+
+			public static void enable(HttpServletRequest request) {
+				request.setAttribute(ENABLED_ATTR_NAME, Boolean.TRUE);
+			}
+
+			public boolean isEnabled(HttpServletRequest request) {
+				return Boolean.TRUE.equals(request.getAttribute(ENABLED_ATTR_NAME));
 			}
 			}
 		}
 		}
 	}
 	}

+ 2 - 2
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

@@ -39,7 +39,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 	public void defaults() throws Exception {
 	public void defaults() throws Exception {
 		MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
 		MockHttpServletRequest request = formLogin().buildRequest(this.servletContext);
 		CsrfToken token = (CsrfToken) request
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 
 
 		assertThat(request.getParameter("username")).isEqualTo("user");
 		assertThat(request.getParameter("username")).isEqualTo("user");
 		assertThat(request.getParameter("password")).isEqualTo("password");
 		assertThat(request.getParameter("password")).isEqualTo("password");
@@ -56,7 +56,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 				.password("password", "secret").buildRequest(this.servletContext);
 				.password("password", "secret").buildRequest(this.servletContext);
 
 
 		CsrfToken token = (CsrfToken) request
 		CsrfToken token = (CsrfToken) request
-				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
+				.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 
 
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getParameter("password")).isEqualTo("secret");

+ 2 - 2
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

@@ -37,7 +37,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 	public void defaults() throws Exception {
 	public void defaults() throws Exception {
 		MockHttpServletRequest request = logout().buildRequest(servletContext);
 		MockHttpServletRequest request = logout().buildRequest(servletContext);
 
 
-		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 
 
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(
@@ -50,7 +50,7 @@ public class SecurityMockMvcRequestBuildersFormLogoutTests {
 		MockHttpServletRequest request = logout("/admin/logout").buildRequest(
 		MockHttpServletRequest request = logout("/admin/logout").buildRequest(
 				servletContext);
 				servletContext);
 
 
-		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.TOKEN_ATTR_NAME);
 
 
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(

+ 23 - 0
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java

@@ -31,18 +31,22 @@ import org.junit.runner.RunWith;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Bean;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpSession;
 import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController;
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.FilterChainProxy;
+import org.springframework.security.web.csrf.CsrfToken;
+import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import org.springframework.test.context.web.WebAppConfiguration;
 import org.springframework.test.context.web.WebAppConfiguration;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MockMvc;
 import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.test.web.servlet.MvcResult;
 import org.springframework.test.web.servlet.ResultMatcher;
 import org.springframework.test.web.servlet.ResultMatcher;
+import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
 import org.springframework.test.web.servlet.setup.MockMvcBuilders;
 import org.springframework.test.web.servlet.setup.MockMvcBuilders;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.bind.annotation.RestController;
@@ -143,6 +147,25 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 		// @formatter:on
 		// @formatter:on
 	}
 	}
 
 
+	// gh-4016
+	@Test
+	public void csrfWhenUsedThenDoesNotImpactOriginalRepository() throws Exception {
+		// @formatter:off
+		this.mockMvc.perform(post("/").with(csrf()));
+
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		HttpSessionCsrfTokenRepository repo = new HttpSessionCsrfTokenRepository();
+		CsrfToken token = repo.generateToken(request);
+		repo.saveToken(token, request, new MockHttpServletResponse());
+
+		MockHttpServletRequestBuilder requestWithCsrf = post("/")
+			.param(token.getParameterName(), token.getToken())
+			.session((MockHttpSession)request.getSession());
+		this.mockMvc.perform(requestWithCsrf)
+			.andExpect(status().isOk());
+		// @formatter:on
+	}
+
 	public static ResultMatcher csrfAsParam() {
 	public static ResultMatcher csrfAsParam() {
 		return new CsrfParamResultMatcher();
 		return new CsrfParamResultMatcher();
 	}
 	}