瀏覽代碼

SEC-3097: Change CsrfRequestPostProcessor to use TestCsrfTokenRepository

This ensures that when using a wrapped HttpServletRequest (i.e. Spring
Session) that the CSRF token test support still works.
Rob Winch 10 年之前
父節點
當前提交
81e2778106

+ 34 - 0
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -316,6 +316,10 @@ public final class SecurityMockMvcRequestPostProcessors {
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
 
 			CsrfTokenRepository repository = WebTestUtils.getCsrfTokenRepository(request);
+			if(!(repository instanceof TestCsrfTokenRepository)) {
+				repository = new TestCsrfTokenRepository(repository);
+				WebTestUtils.setCsrfTokenRepository(request, repository);
+			}
 			CsrfToken token = repository.generateToken(request);
 			repository.saveToken(token, request, new MockHttpServletResponse());
 			String tokenValue = useInvalidToken ? "invalid" + token.getToken() : token
@@ -352,6 +356,36 @@ public final class SecurityMockMvcRequestPostProcessors {
 
 		private CsrfRequestPostProcessor() {
 		}
+
+
+
+		/**
+		 * Used to wrap the CsrfTokenRepository to provide support for testing
+		 * when the request is wrapped (i.e. Spring Session is in use).
+		 */
+		static class TestCsrfTokenRepository implements
+				CsrfTokenRepository {
+			final static String ATTR_NAME = TestCsrfTokenRepository.class
+					.getName().concat(".TOKEN");
+
+			private final CsrfTokenRepository delegate;
+
+			private TestCsrfTokenRepository(CsrfTokenRepository delegate) {
+				this.delegate = delegate;
+			}
+
+			public CsrfToken generateToken(HttpServletRequest request) {
+				return delegate.generateToken(request);
+			}
+
+			public void saveToken(CsrfToken token, HttpServletRequest request, HttpServletResponse response) {
+				request.setAttribute(ATTR_NAME, token);
+			}
+
+			public CsrfToken loadToken(HttpServletRequest request) {
+				return (CsrfToken) request.getAttribute(ATTR_NAME);
+			}
+		}
 	}
 
 	public static class DigestRequestPostProcessor implements RequestPostProcessor {

+ 16 - 0
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@@ -97,6 +97,22 @@ public abstract class WebTestUtils {
 				"tokenRepository");
 	}
 
+	/**
+	 * Sets the {@link CsrfTokenRepository} for the specified
+	 * {@link HttpServletRequest}.
+	 *
+	 * @param request the {@link HttpServletRequest} to obtain the
+	 * {@link CsrfTokenRepository}
+	 * @param repository the {@link CsrfTokenRepository} to set
+	 */
+	public static void setCsrfTokenRepository(HttpServletRequest request,
+			CsrfTokenRepository repository) {
+		CsrfFilter filter = findFilter(request, CsrfFilter.class);
+		if (filter != null) {
+			ReflectionTestUtils.setField(filter, "tokenRepository", repository);
+		}
+	}
+
 	@SuppressWarnings("unchecked")
 	private static <T extends Filter> T findFilter(HttpServletRequest request,
 			Class<T> filterClass) {

+ 6 - 34
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLoginTests.java

@@ -15,48 +15,28 @@
  */
 package org.springframework.security.test.web.servlet.request;
 
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.verify;
 import static org.fest.assertions.Assertions.assertThat;
-import static org.powermock.api.mockito.PowerMockito.spy;
-import static org.powermock.api.mockito.PowerMockito.when;
-import static org.powermock.api.mockito.PowerMockito.doReturn;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.formLogin;
 
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
 import org.junit.Before;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockServletContext;
-import org.springframework.security.test.web.support.WebTestUtils;
-import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
+import org.springframework.security.web.csrf.CsrfToken;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLoginTests.class })
 public class SecurityMockMvcRequestBuildersFormLoginTests {
-	@Mock
-	private CsrfTokenRepository repository;
-	private DefaultCsrfToken token;
 	private MockServletContext servletContext;
 
 	@Before
 	public void setup() throws Exception {
-		token = new DefaultCsrfToken("header", "param", "token");
 		servletContext = new MockServletContext();
-		mockWebTestUtils();
 	}
 
 	@Test
 	public void defaults() throws Exception {
 		MockHttpServletRequest request = formLogin().buildRequest(servletContext);
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
 
 		assertThat(request.getParameter("username")).isEqualTo("user");
 		assertThat(request.getParameter("password")).isEqualTo("password");
@@ -64,8 +44,7 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(
 				token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/login");
-		verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
+		assertThat(request.getParameter("_csrf")).isNotNull();
 	}
 
 	@Test
@@ -73,20 +52,13 @@ public class SecurityMockMvcRequestBuildersFormLoginTests {
 		MockHttpServletRequest request = formLogin("/login").user("username", "admin")
 				.password("password", "secret").buildRequest(servletContext);
 
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
+
 		assertThat(request.getParameter("username")).isEqualTo("admin");
 		assertThat(request.getParameter("password")).isEqualTo("secret");
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(
 				token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/login");
-		verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
-	}
-
-	private void mockWebTestUtils() throws Exception {
-		spy(WebTestUtils.class);
-		doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
-				any(HttpServletRequest.class));
-		when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
 	}
 }

+ 6 - 34
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestBuildersFormLogoutTests.java

@@ -15,75 +15,47 @@
  */
 package org.springframework.security.test.web.servlet.request;
 
-import static org.mockito.Matchers.any;
-import static org.mockito.Matchers.eq;
-import static org.mockito.Mockito.verify;
 import static org.fest.assertions.Assertions.assertThat;
-import static org.powermock.api.mockito.PowerMockito.spy;
-import static org.powermock.api.mockito.PowerMockito.when;
-import static org.powermock.api.mockito.PowerMockito.doReturn;
 import static org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestBuilders.logout;
 
-import javax.servlet.http.HttpServletRequest;
-import javax.servlet.http.HttpServletResponse;
-
 import org.junit.Before;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mock;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockServletContext;
-import org.springframework.security.test.web.support.WebTestUtils;
-import org.springframework.security.web.csrf.CsrfTokenRepository;
-import org.springframework.security.web.csrf.DefaultCsrfToken;
+import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.CsrfRequestPostProcessor;
+import org.springframework.security.web.csrf.CsrfToken;
 
-@RunWith(PowerMockRunner.class)
-@PrepareForTest({ WebTestUtils.class, SecurityMockMvcRequestBuildersFormLogoutTests.class })
 public class SecurityMockMvcRequestBuildersFormLogoutTests {
-	@Mock
-	private CsrfTokenRepository repository;
-	private DefaultCsrfToken token;
 	private MockServletContext servletContext;
 
 	@Before
 	public void setup() {
-		token = new DefaultCsrfToken("header", "param", "token");
 		servletContext = new MockServletContext();
 	}
 
 	@Test
 	public void defaults() throws Exception {
-		mockWebTestUtils();
 		MockHttpServletRequest request = logout().buildRequest(servletContext);
 
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
+
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(
 				token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/logout");
-		verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
 	}
 
 	@Test
 	public void custom() throws Exception {
-		mockWebTestUtils();
 		MockHttpServletRequest request = logout("/admin/logout").buildRequest(
 				servletContext);
 
+		CsrfToken token = (CsrfToken) request.getAttribute(CsrfRequestPostProcessor.TestCsrfTokenRepository.ATTR_NAME);
+
 		assertThat(request.getMethod()).isEqualTo("POST");
 		assertThat(request.getParameter(token.getParameterName())).isEqualTo(
 				token.getToken());
 		assertThat(request.getRequestURI()).isEqualTo("/admin/logout");
-		verify(repository).saveToken(eq(token), any(HttpServletRequest.class),
-				any(HttpServletResponse.class));
 	}
 
-	private void mockWebTestUtils() throws Exception {
-		spy(WebTestUtils.class);
-		doReturn(repository).when(WebTestUtils.class, "getCsrfTokenRepository",
-				any(HttpServletRequest.class));
-		when(repository.generateToken(any(HttpServletRequest.class))).thenReturn(token);
-	}
 }

+ 52 - 0
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java

@@ -21,12 +21,22 @@ import static org.springframework.security.test.web.servlet.setup.SecurityMockMv
 import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post;
 import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
 
+import java.io.IOException;
+
+import javax.servlet.FilterChain;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletRequestWrapper;
+import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
@@ -40,6 +50,7 @@ import org.springframework.test.web.servlet.setup.MockMvcBuilders;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.filter.OncePerRequestFilter;
 
 @RunWith(SpringJUnit4ClassRunner.class)
 @ContextConfiguration
@@ -86,6 +97,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 			.andExpect(csrfAsHeader());
 	}
 
+	// SEC-3097
+	@Test
+	public void csrfWithWrappedRequest() throws Exception {
+		mockMvc = MockMvcBuilders
+				.webAppContextSetup(wac)
+				.addFilter(new SessionRepositoryFilter())
+				.apply(springSecurity())
+				.build();
+
+		mockMvc.perform(post("/").with(csrf()))
+				.andExpect(status().is2xxSuccessful())
+				.andExpect(csrfAsParam());
+	}
+
 	public static ResultMatcher csrfAsParam() {
 		return new CsrfParamResultMatcher();
 	}
@@ -112,6 +137,33 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 		}
 	}
 
+	static class SessionRepositoryFilter extends OncePerRequestFilter {
+
+		@Override
+		protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
+				throws ServletException, IOException {
+			filterChain.doFilter(new SessionRequestWrapper(request) , response);
+		}
+
+		static class SessionRequestWrapper extends HttpServletRequestWrapper {
+			HttpSession session = new MockHttpSession();
+
+			public SessionRequestWrapper(HttpServletRequest request) {
+				super(request);
+			}
+
+			@Override
+			public HttpSession getSession(boolean create) {
+				return session;
+			}
+
+			@Override
+			public HttpSession getSession() {
+				return session;
+			}
+		}
+	}
+
 	@EnableWebSecurity
 	static class Config extends WebSecurityConfigurerAdapter {
 		@Override