ソースを参照

Pick MvcRequestMatcher for MockMvc requests

Closes gh-13849
Josh Cummings 1 年間 前
コミット
6aabd768a8

+ 54 - 24
config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

@@ -44,11 +44,13 @@ import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
+import org.springframework.security.web.util.matcher.OrRequestMatcher;
 import org.springframework.security.web.util.matcher.RegexRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.util.ClassUtils;
 import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.servlet.DispatcherServlet;
 import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
 
 /**
@@ -335,10 +337,10 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 	private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) {
 		Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
 		if (registrations.isEmpty()) {
-			return ant;
+			return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
 		}
 		if (!hasDispatcherServlet(registrations)) {
-			return ant;
+			return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
 		}
 		ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
 		if (dispatcherServlet != null) {
@@ -605,27 +607,70 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 
 	}
 
+	static class MockMvcRequestMatcher implements RequestMatcher {
+
+		@Override
+		public boolean matches(HttpServletRequest request) {
+			return request.getAttribute("org.springframework.test.web.servlet.MockMvc.MVC_RESULT_ATTRIBUTE") != null;
+		}
+
+	}
+
+	static class DispatcherServletRequestMatcher implements RequestMatcher {
+
+		private final ServletContext servletContext;
+
+		DispatcherServletRequestMatcher(ServletContext servletContext) {
+			this.servletContext = servletContext;
+		}
+
+		@Override
+		public boolean matches(HttpServletRequest request) {
+			String name = request.getHttpServletMapping().getServletName();
+			ServletRegistration registration = this.servletContext.getServletRegistration(name);
+			Assert.notNull(name, "Failed to find servlet [" + name + "] in the servlet context");
+			try {
+				Class<?> clazz = Class.forName(registration.getClassName());
+				return DispatcherServlet.class.isAssignableFrom(clazz);
+			}
+			catch (ClassNotFoundException ex) {
+				return false;
+			}
+		}
+
+	}
+
 	static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
 
 		private final AntPathRequestMatcher ant;
 
 		private final MvcRequestMatcher mvc;
 
-		private final ServletContext servletContext;
+		private final RequestMatcher dispatcherServlet;
 
 		DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
 				ServletContext servletContext) {
+			this(ant, mvc, new OrRequestMatcher(new MockMvcRequestMatcher(),
+					new DispatcherServletRequestMatcher(servletContext)));
+		}
+
+		DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
+				RequestMatcher dispatcherServlet) {
 			this.ant = ant;
 			this.mvc = mvc;
-			this.servletContext = servletContext;
+			this.dispatcherServlet = dispatcherServlet;
+		}
+
+		RequestMatcher requestMatcher(HttpServletRequest request) {
+			if (this.dispatcherServlet.matches(request)) {
+				return this.mvc;
+			}
+			return this.ant;
 		}
 
 		@Override
 		public boolean matches(HttpServletRequest request) {
-			String name = request.getHttpServletMapping().getServletName();
-			ServletRegistration registration = this.servletContext.getServletRegistration(name);
-			Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
-			if (isDispatcherServlet(registration)) {
+			if (this.dispatcherServlet.matches(request)) {
 				return this.mvc.matches(request);
 			}
 			return this.ant.matches(request);
@@ -633,27 +678,12 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 
 		@Override
 		public MatchResult matcher(HttpServletRequest request) {
-			String name = request.getHttpServletMapping().getServletName();
-			ServletRegistration registration = this.servletContext.getServletRegistration(name);
-			Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
-			if (isDispatcherServlet(registration)) {
+			if (this.dispatcherServlet.matches(request)) {
 				return this.mvc.matcher(request);
 			}
 			return this.ant.matcher(request);
 		}
 
-		private boolean isDispatcherServlet(ServletRegistration registration) {
-			Class<?> dispatcherServlet = ClassUtils
-				.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
-			try {
-				Class<?> clazz = Class.forName(registration.getClassName());
-				return dispatcherServlet.isAssignableFrom(clazz);
-			}
-			catch (ClassNotFoundException ex) {
-				return false;
-			}
-		}
-
 		@Override
 		public String toString() {
 			return "DispatcherServletDelegating [" + "ant = " + this.ant + ", mvc = " + this.mvc + "]";

+ 65 - 3
config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java

@@ -30,27 +30,35 @@ import org.junit.jupiter.api.Test;
 
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.context.ApplicationContext;
+import org.springframework.context.annotation.Configuration;
 import org.springframework.http.HttpMethod;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.config.MockServletContext;
 import org.springframework.security.config.TestMockHttpServletMappings;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
 import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher;
+import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
+import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
 import org.springframework.security.web.util.matcher.RegexRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.setup.MockMvcBuilders;
 import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.servlet.DispatcherServlet;
+import org.springframework.web.servlet.config.annotation.EnableWebMvc;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.InstanceOfAssertFactories.type;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoInteractions;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get;
 
 /**
  * Tests for {@link AbstractRequestMatcherRegistry}.
@@ -206,18 +214,65 @@ public class AbstractRequestMatcherRegistryTests {
 		mockMvcIntrospector(true);
 		MockServletContext servletContext = new MockServletContext();
 		given(this.context.getServletContext()).willReturn(servletContext);
+		MockHttpServletRequest request = new MockHttpServletRequest();
+		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		assertThat(requestMatchers).isNotEmpty();
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class))
+			.extracting((matcher) -> matcher.requestMatcher(request))
+			.isInstanceOf(AntPathRequestMatcher.class);
 		servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
 		servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two");
-		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		requestMatchers = this.matcherRegistry.requestMatchers("/**");
 		assertThat(requestMatchers).isNotEmpty();
 		assertThat(requestMatchers).hasSize(1);
-		assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
+		assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class))
+			.extracting((matcher) -> matcher.requestMatcher(request))
+			.isInstanceOf(AntPathRequestMatcher.class);
 		servletContext.addServlet("servletOne", Servlet.class);
 		servletContext.addServlet("servletTwo", Servlet.class);
 		requestMatchers = this.matcherRegistry.requestMatchers("/**");
 		assertThat(requestMatchers).isNotEmpty();
 		assertThat(requestMatchers).hasSize(1);
-		assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
+		assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class))
+			.extracting((matcher) -> matcher.requestMatcher(request))
+			.isInstanceOf(AntPathRequestMatcher.class);
+	}
+
+	// gh-14418
+	@Test
+	public void requestMatchersWhenNoDispatcherServletMockMvcThenMvcRequestMatcherType() throws Exception {
+		MockServletContext servletContext = new MockServletContext();
+		try (SpringTestContext spring = new SpringTestContext(this)) {
+			spring.register(MockMvcConfiguration.class)
+				.postProcessor((context) -> context.setServletContext(servletContext))
+				.autowire();
+			this.matcherRegistry.setApplicationContext(spring.getContext());
+			MockMvc mvc = MockMvcBuilders.webAppContextSetup(spring.getContext()).build();
+			MockHttpServletRequest request = mvc.perform(get("/")).andReturn().getRequest();
+			List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+			assertThat(requestMatchers).isNotEmpty();
+			assertThat(requestMatchers).hasSize(1);
+			assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class))
+				.extracting((matcher) -> matcher.requestMatcher(request))
+				.isInstanceOf(MvcRequestMatcher.class);
+			servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
+			servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two");
+			requestMatchers = this.matcherRegistry.requestMatchers("/**");
+			assertThat(requestMatchers).isNotEmpty();
+			assertThat(requestMatchers).hasSize(1);
+			assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class))
+				.extracting((matcher) -> matcher.requestMatcher(request))
+				.isInstanceOf(MvcRequestMatcher.class);
+			servletContext.addServlet("servletOne", Servlet.class);
+			servletContext.addServlet("servletTwo", Servlet.class);
+			requestMatchers = this.matcherRegistry.requestMatchers("/**");
+			assertThat(requestMatchers).isNotEmpty();
+			assertThat(requestMatchers).hasSize(1);
+			assertThat(requestMatchers.get(0)).asInstanceOf(type(DispatcherServletDelegatingRequestMatcher.class))
+				.extracting((matcher) -> matcher.requestMatcher(request))
+				.isInstanceOf(MvcRequestMatcher.class);
+		}
 	}
 
 	@Test
@@ -398,4 +453,11 @@ public class AbstractRequestMatcherRegistryTests {
 
 	}
 
+	@Configuration
+	@EnableWebSecurity
+	@EnableWebMvc
+	static class MockMvcConfiguration {
+
+	}
+
 }

+ 1 - 0
etc/checkstyle/checkstyle.xml

@@ -20,6 +20,7 @@
 		<property name="avoidStaticImportExcludes" value="org.springframework.security.web.csrf.CsrfTokenAssert.*" />
 		<property name="avoidStaticImportExcludes" value="org.springframework.security.web.util.matcher.AntPathRequestMatcher.*" />
 		<property name="avoidStaticImportExcludes" value="org.springframework.security.web.util.matcher.RegexRequestMatcher.*" />
+		<property name="avoidStaticImportExcludes" value="org.assertj.core.api.InstanceOfAssertFactories.*"/>
 	</module>
 	<module name="com.puppycrawl.tools.checkstyle.TreeWalker">
  		<module name="com.puppycrawl.tools.checkstyle.checks.regexp.RegexpSinglelineJavaCheck">