Browse Source

Fix test .standaloneSetup

Previously, Spring Security's test support did not work well with the
standalone setup. This was because the springSecurityFilterChain was not
found by the WebTestUtils.

This commit ensures that the springSecurityFilterChain is added as a
servlet attribute if it is explicitly defined. WebTestUtils can then
find the springSecurityFilterChain in the ServletContext.

Fixes gh-3881
Rob Winch 9 years ago
parent
commit
7b61a44929

+ 2 - 0
test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java

@@ -68,6 +68,8 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter {
 		}
 
 		builder.addFilters(this.springSecurityFilterChain);
+		context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN,
+				this.springSecurityFilterChain);
 
 		return testSecurityContext();
 	}

+ 25 - 12
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@@ -18,9 +18,11 @@ package org.springframework.security.test.web.support;
 import java.util.List;
 
 import javax.servlet.Filter;
+import javax.servlet.ServletContext;
 import javax.servlet.http.HttpServletRequest;
 
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
+import org.springframework.security.config.BeanIds;
 import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextPersistenceFilter;
@@ -115,18 +117,9 @@ public abstract class WebTestUtils {
 	@SuppressWarnings("unchecked")
 	static <T extends Filter> T findFilter(HttpServletRequest request,
 			Class<T> filterClass) {
-		WebApplicationContext webApplicationContext = WebApplicationContextUtils
-				.getWebApplicationContext(request.getServletContext());
-		if (webApplicationContext == null) {
-			return null;
-		}
-		Filter springSecurityFilterChain = null;
-		try {
-			springSecurityFilterChain = webApplicationContext.getBean(
-					AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME,
-					Filter.class);
-		}
-		catch (NoSuchBeanDefinitionException notFound) {
+		ServletContext servletContext = request.getServletContext();
+		Filter springSecurityFilterChain = getSpringSecurityFilterChain(servletContext);
+		if (springSecurityFilterChain == null) {
 			return null;
 		}
 		List<Filter> filters = (List<Filter>) ReflectionTestUtils
@@ -142,6 +135,26 @@ public abstract class WebTestUtils {
 		return null;
 	}
 
+	private static Filter getSpringSecurityFilterChain(ServletContext servletContext) {
+		Filter result = (Filter) servletContext
+				.getAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
+		if (result != null) {
+			return result;
+		}
+		WebApplicationContext webApplicationContext = WebApplicationContextUtils
+				.getWebApplicationContext(servletContext);
+		if (webApplicationContext != null) {
+			try {
+				return webApplicationContext.getBean(
+						AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME,
+						Filter.class);
+			}
+			catch (NoSuchBeanDefinitionException notFound) {
+			}
+		}
+		return null;
+	}
+
 	private WebTestUtils() {
 	}
 }

+ 20 - 0
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java

@@ -27,6 +27,7 @@ import javax.servlet.http.HttpSession;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -34,6 +35,8 @@ import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController;
+import org.springframework.security.web.FilterChainProxy;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import org.springframework.test.context.web.WebAppConfiguration;
@@ -58,6 +61,10 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
 public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 	@Autowired
 	WebApplicationContext wac;
+	@Autowired
+	TheController controller;
+	@Autowired
+	FilterChainProxy springSecurityFilterChain;
 
 	MockMvc mockMvc;
 
@@ -69,7 +76,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 			.apply(springSecurity())
 			.build();
 		// @formatter:on
+	}
+
+	// gh-3881
+	@Test
+	public void csrfWithStandalone() throws Exception {
+		// @formatter:off
+		this.mockMvc = MockMvcBuilders
+				.standaloneSetup(this.controller)
+				.apply(springSecurity(this.springSecurityFilterChain))
 				.build();
+		this.mockMvc.perform(post("/").with(csrf()))
+			.andExpect(status().is2xxSuccessful())
+			.andExpect(csrfAsParam());
+		// @formatter:on
 	}
 
 	@Test

+ 12 - 0
test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java

@@ -16,12 +16,15 @@
 package org.springframework.security.test.web.servlet.setup;
 
 import javax.servlet.Filter;
+import javax.servlet.ServletContext;
 
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.runners.MockitoJUnitRunner;
 
+import org.springframework.security.config.BeanIds;
 import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder;
 import org.springframework.web.context.WebApplicationContext;
 
@@ -40,6 +43,13 @@ public class SecurityMockMvcConfigurerTests {
 	private ConfigurableMockMvcBuilder<?> builder;
 	@Mock
 	private WebApplicationContext context;
+	@Mock
+	private ServletContext servletContext;
+
+	@Before
+	public void setup() {
+		when(this.context.getServletContext()).thenReturn(this.servletContext);
+	}
 
 	@Test
 	public void beforeMockMvcCreatedOverrideBean() throws Exception {
@@ -49,6 +59,8 @@ public class SecurityMockMvcConfigurerTests {
 		configurer.beforeMockMvcCreated(this.builder, this.context);
 
 		verify(this.builder).addFilters(this.filter);
+		verify(this.servletContext).setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN,
+				this.filter);
 	}
 
 	@Test

+ 42 - 0
test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java

@@ -25,14 +25,19 @@ import org.mockito.runners.MockitoJUnitRunner;
 import org.springframework.context.ConfigurableApplicationContext;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.config.BeanIds;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.web.DefaultSecurityFilterChain;
+import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextPersistenceFilter;
 import org.springframework.security.web.context.SecurityContextRepository;
+import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
+import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
 
@@ -129,6 +134,34 @@ public class WebTestUtilsTests {
 				SecurityContextPersistenceFilter.class)).isNull();
 	}
 
+	@Test
+	public void findFilterNoSpringSecurityFilterChainInContext() {
+		loadConfig(NoSecurityConfig.class);
+
+		CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository());
+		FilterChainProxy springSecurityFilterChain = new FilterChainProxy(
+				new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind));
+		this.request.getServletContext().setAttribute(
+				BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain);
+
+		assertThat(WebTestUtils.findFilter(this.request, toFind.getClass()))
+				.isEqualTo(toFind);
+	}
+
+	@Test
+	public void findFilterExplicitWithSecurityFilterInContext() {
+		loadConfig(SecurityConfigWithDefaults.class);
+
+		CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository());
+		FilterChainProxy springSecurityFilterChain = new FilterChainProxy(
+				new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind));
+		this.request.getServletContext().setAttribute(
+				BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain);
+
+		assertThat(WebTestUtils.findFilter(this.request, toFind.getClass()))
+				.isSameAs(toFind);
+	}
+
 	private void loadConfig(Class<?> config) {
 		AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
 		context.register(config);
@@ -180,4 +213,13 @@ public class WebTestUtilsTests {
 		}
 		// @formatter:on
 	}
+
+	@Configuration
+	static class NoSecurityConfig {
+	}
+
+	@EnableWebSecurity
+	static class SecurityConfigWithDefaults extends WebSecurityConfigurerAdapter {
+
+	}
 }