Browse Source

Fix test .standaloneSetup

Previously, Spring Security's test support did not work well with the
standalone setup. This was because the springSecurityFilterChain was not
found by the WebTestUtils.

This commit ensures that the springSecurityFilterChain is added as a
servlet attribute if it is explicitly defined. WebTestUtils can then
find the springSecurityFilterChain in the ServletContext.

Fixes gh-3881
Rob Winch 9 years ago
parent
commit
7b61a44929

+ 2 - 0
test/src/main/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurer.java

@@ -68,6 +68,8 @@ final class SecurityMockMvcConfigurer extends MockMvcConfigurerAdapter {
 		}
 		}
 
 
 		builder.addFilters(this.springSecurityFilterChain);
 		builder.addFilters(this.springSecurityFilterChain);
+		context.getServletContext().setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN,
+				this.springSecurityFilterChain);
 
 
 		return testSecurityContext();
 		return testSecurityContext();
 	}
 	}

+ 25 - 12
test/src/main/java/org/springframework/security/test/web/support/WebTestUtils.java

@@ -18,9 +18,11 @@ package org.springframework.security.test.web.support;
 import java.util.List;
 import java.util.List;
 
 
 import javax.servlet.Filter;
 import javax.servlet.Filter;
+import javax.servlet.ServletContext;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletRequest;
 
 
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
+import org.springframework.security.config.BeanIds;
 import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer;
 import org.springframework.security.web.context.AbstractSecurityWebApplicationInitializer;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextPersistenceFilter;
 import org.springframework.security.web.context.SecurityContextPersistenceFilter;
@@ -115,18 +117,9 @@ public abstract class WebTestUtils {
 	@SuppressWarnings("unchecked")
 	@SuppressWarnings("unchecked")
 	static <T extends Filter> T findFilter(HttpServletRequest request,
 	static <T extends Filter> T findFilter(HttpServletRequest request,
 			Class<T> filterClass) {
 			Class<T> filterClass) {
-		WebApplicationContext webApplicationContext = WebApplicationContextUtils
-				.getWebApplicationContext(request.getServletContext());
-		if (webApplicationContext == null) {
-			return null;
-		}
-		Filter springSecurityFilterChain = null;
-		try {
-			springSecurityFilterChain = webApplicationContext.getBean(
-					AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME,
-					Filter.class);
-		}
-		catch (NoSuchBeanDefinitionException notFound) {
+		ServletContext servletContext = request.getServletContext();
+		Filter springSecurityFilterChain = getSpringSecurityFilterChain(servletContext);
+		if (springSecurityFilterChain == null) {
 			return null;
 			return null;
 		}
 		}
 		List<Filter> filters = (List<Filter>) ReflectionTestUtils
 		List<Filter> filters = (List<Filter>) ReflectionTestUtils
@@ -142,6 +135,26 @@ public abstract class WebTestUtils {
 		return null;
 		return null;
 	}
 	}
 
 
+	private static Filter getSpringSecurityFilterChain(ServletContext servletContext) {
+		Filter result = (Filter) servletContext
+				.getAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
+		if (result != null) {
+			return result;
+		}
+		WebApplicationContext webApplicationContext = WebApplicationContextUtils
+				.getWebApplicationContext(servletContext);
+		if (webApplicationContext != null) {
+			try {
+				return webApplicationContext.getBean(
+						AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME,
+						Filter.class);
+			}
+			catch (NoSuchBeanDefinitionException notFound) {
+			}
+		}
+		return null;
+	}
+
 	private WebTestUtils() {
 	private WebTestUtils() {
 	}
 	}
 }
 }

+ 20 - 0
test/src/test/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessorsCsrfTests.java

@@ -27,6 +27,7 @@ import javax.servlet.http.HttpSession;
 import org.junit.Before;
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runner.RunWith;
+
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Bean;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -34,6 +35,8 @@ import org.springframework.mock.web.MockHttpSession;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessorsCsrfTests.Config.TheController;
+import org.springframework.security.web.FilterChainProxy;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
 import org.springframework.test.context.web.WebAppConfiguration;
 import org.springframework.test.context.web.WebAppConfiguration;
@@ -58,6 +61,10 @@ import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.
 public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 	@Autowired
 	@Autowired
 	WebApplicationContext wac;
 	WebApplicationContext wac;
+	@Autowired
+	TheController controller;
+	@Autowired
+	FilterChainProxy springSecurityFilterChain;
 
 
 	MockMvc mockMvc;
 	MockMvc mockMvc;
 
 
@@ -69,7 +76,20 @@ public class SecurityMockMvcRequestPostProcessorsCsrfTests {
 			.apply(springSecurity())
 			.apply(springSecurity())
 			.build();
 			.build();
 		// @formatter:on
 		// @formatter:on
+	}
+
+	// gh-3881
+	@Test
+	public void csrfWithStandalone() throws Exception {
+		// @formatter:off
+		this.mockMvc = MockMvcBuilders
+				.standaloneSetup(this.controller)
+				.apply(springSecurity(this.springSecurityFilterChain))
 				.build();
 				.build();
+		this.mockMvc.perform(post("/").with(csrf()))
+			.andExpect(status().is2xxSuccessful())
+			.andExpect(csrfAsParam());
+		// @formatter:on
 	}
 	}
 
 
 	@Test
 	@Test

+ 12 - 0
test/src/test/java/org/springframework/security/test/web/servlet/setup/SecurityMockMvcConfigurerTests.java

@@ -16,12 +16,15 @@
 package org.springframework.security.test.web.servlet.setup;
 package org.springframework.security.test.web.servlet.setup;
 
 
 import javax.servlet.Filter;
 import javax.servlet.Filter;
+import javax.servlet.ServletContext;
 
 
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.Mock;
 import org.mockito.runners.MockitoJUnitRunner;
 import org.mockito.runners.MockitoJUnitRunner;
 
 
+import org.springframework.security.config.BeanIds;
 import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder;
 import org.springframework.test.web.servlet.setup.ConfigurableMockMvcBuilder;
 import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.context.WebApplicationContext;
 
 
@@ -40,6 +43,13 @@ public class SecurityMockMvcConfigurerTests {
 	private ConfigurableMockMvcBuilder<?> builder;
 	private ConfigurableMockMvcBuilder<?> builder;
 	@Mock
 	@Mock
 	private WebApplicationContext context;
 	private WebApplicationContext context;
+	@Mock
+	private ServletContext servletContext;
+
+	@Before
+	public void setup() {
+		when(this.context.getServletContext()).thenReturn(this.servletContext);
+	}
 
 
 	@Test
 	@Test
 	public void beforeMockMvcCreatedOverrideBean() throws Exception {
 	public void beforeMockMvcCreatedOverrideBean() throws Exception {
@@ -49,6 +59,8 @@ public class SecurityMockMvcConfigurerTests {
 		configurer.beforeMockMvcCreated(this.builder, this.context);
 		configurer.beforeMockMvcCreated(this.builder, this.context);
 
 
 		verify(this.builder).addFilters(this.filter);
 		verify(this.builder).addFilters(this.filter);
+		verify(this.servletContext).setAttribute(BeanIds.SPRING_SECURITY_FILTER_CHAIN,
+				this.filter);
 	}
 	}
 
 
 	@Test
 	@Test

+ 42 - 0
test/src/test/java/org/springframework/security/test/web/support/WebTestUtilsTests.java

@@ -25,14 +25,19 @@ import org.mockito.runners.MockitoJUnitRunner;
 import org.springframework.context.ConfigurableApplicationContext;
 import org.springframework.context.ConfigurableApplicationContext;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.security.config.BeanIds;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
 import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter;
+import org.springframework.security.web.DefaultSecurityFilterChain;
+import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.HttpSessionSecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextPersistenceFilter;
 import org.springframework.security.web.context.SecurityContextPersistenceFilter;
 import org.springframework.security.web.context.SecurityContextRepository;
 import org.springframework.security.web.context.SecurityContextRepository;
+import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.CsrfTokenRepository;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
 import org.springframework.security.web.csrf.HttpSessionCsrfTokenRepository;
+import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
 import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
 
 
@@ -129,6 +134,34 @@ public class WebTestUtilsTests {
 				SecurityContextPersistenceFilter.class)).isNull();
 				SecurityContextPersistenceFilter.class)).isNull();
 	}
 	}
 
 
+	@Test
+	public void findFilterNoSpringSecurityFilterChainInContext() {
+		loadConfig(NoSecurityConfig.class);
+
+		CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository());
+		FilterChainProxy springSecurityFilterChain = new FilterChainProxy(
+				new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind));
+		this.request.getServletContext().setAttribute(
+				BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain);
+
+		assertThat(WebTestUtils.findFilter(this.request, toFind.getClass()))
+				.isEqualTo(toFind);
+	}
+
+	@Test
+	public void findFilterExplicitWithSecurityFilterInContext() {
+		loadConfig(SecurityConfigWithDefaults.class);
+
+		CsrfFilter toFind = new CsrfFilter(new HttpSessionCsrfTokenRepository());
+		FilterChainProxy springSecurityFilterChain = new FilterChainProxy(
+				new DefaultSecurityFilterChain(AnyRequestMatcher.INSTANCE, toFind));
+		this.request.getServletContext().setAttribute(
+				BeanIds.SPRING_SECURITY_FILTER_CHAIN, springSecurityFilterChain);
+
+		assertThat(WebTestUtils.findFilter(this.request, toFind.getClass()))
+				.isSameAs(toFind);
+	}
+
 	private void loadConfig(Class<?> config) {
 	private void loadConfig(Class<?> config) {
 		AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
 		AnnotationConfigWebApplicationContext context = new AnnotationConfigWebApplicationContext();
 		context.register(config);
 		context.register(config);
@@ -180,4 +213,13 @@ public class WebTestUtilsTests {
 		}
 		}
 		// @formatter:on
 		// @formatter:on
 	}
 	}
+
+	@Configuration
+	static class NoSecurityConfig {
+	}
+
+	@EnableWebSecurity
+	static class SecurityConfigWithDefaults extends WebSecurityConfigurerAdapter {
+
+	}
 }
 }