Parcourir la source

Improve RequestMatcher Validation

Closes gh-13551
Josh Cummings il y a 2 ans
Parent
commit
df239b6448

+ 53 - 12
config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

@@ -19,8 +19,11 @@ package org.springframework.security.config.annotation.web;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 
 import javax.servlet.DispatcherType;
+import javax.servlet.ServletContext;
+import javax.servlet.ServletRegistration;
 
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.context.ApplicationContext;
@@ -36,6 +39,7 @@ import org.springframework.security.web.util.matcher.RegexRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
 import org.springframework.util.ClassUtils;
+import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
 
 /**
@@ -297,14 +301,47 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 	 * @since 5.8
 	 */
 	public C requestMatchers(HttpMethod method, String... patterns) {
-		List<RequestMatcher> matchers = new ArrayList<>();
-		if (mvcPresent) {
-			matchers.addAll(createMvcMatchers(method, patterns));
+		if (!mvcPresent) {
+			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
+		}
+		if (!(this.context instanceof WebApplicationContext)) {
+			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
+		}
+		WebApplicationContext context = (WebApplicationContext) this.context;
+		ServletContext servletContext = context.getServletContext();
+		if (servletContext == null) {
+			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
+		}
+		Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations();
+		if (registrations == null) {
+			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
+		}
+		if (!hasDispatcherServlet(registrations)) {
+			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
 		}
-		else {
-			matchers.addAll(RequestMatchers.antMatchers(method, patterns));
+		Assert.isTrue(registrations.size() == 1,
+				"This method cannot decide whether these patterns are Spring MVC patterns or not. If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); otherwise, please use requestMatchers(AntPathRequestMatcher).");
+		return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
+	}
+
+	private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
+		if (registrations == null) {
+			return false;
+		}
+		Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
+				null);
+		for (ServletRegistration registration : registrations.values()) {
+			try {
+				Class<?> clazz = Class.forName(registration.getClassName());
+				if (dispatcherServlet.isAssignableFrom(clazz)) {
+					return true;
+				}
+			}
+			catch (ClassNotFoundException ex) {
+				return false;
+			}
 		}
-		return requestMatchers(matchers.toArray(new RequestMatcher[0]));
+		return false;
 	}
 
 	/**
@@ -380,12 +417,7 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 		 * @return a {@link List} of {@link AntPathRequestMatcher} instances
 		 */
 		static List<RequestMatcher> antMatchers(HttpMethod httpMethod, String... antPatterns) {
-			String method = (httpMethod != null) ? httpMethod.toString() : null;
-			List<RequestMatcher> matchers = new ArrayList<>();
-			for (String pattern : antPatterns) {
-				matchers.add(new AntPathRequestMatcher(pattern, method));
-			}
-			return matchers;
+			return Arrays.asList(antMatchersAsArray(httpMethod, antPatterns));
 		}
 
 		/**
@@ -399,6 +431,15 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 			return antMatchers(null, antPatterns);
 		}
 
+		static RequestMatcher[] antMatchersAsArray(HttpMethod httpMethod, String... antPatterns) {
+			String method = (httpMethod != null) ? httpMethod.toString() : null;
+			RequestMatcher[] matchers = new RequestMatcher[antPatterns.length];
+			for (int index = 0; index < antPatterns.length; index++) {
+				matchers[index] = new AntPathRequestMatcher(antPatterns[index], method);
+			}
+			return matchers;
+		}
+
 		/**
 		 * Create a {@link List} of {@link RegexRequestMatcher} instances.
 		 * @param httpMethod the {@link HttpMethod} to use or {@code null} for any

+ 63 - 3
config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java

@@ -18,10 +18,16 @@ package org.springframework.security.config.annotation.web;
 
 import java.lang.reflect.Field;
 import java.lang.reflect.Modifier;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 
 import javax.servlet.DispatcherType;
+import javax.servlet.Servlet;
+import javax.servlet.ServletContext;
+import javax.servlet.ServletRegistration;
 
+import org.jetbrains.annotations.NotNull;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
@@ -34,6 +40,8 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
 import org.springframework.security.web.util.matcher.RegexRequestMatcher;
 import org.springframework.security.web.util.matcher.RequestMatcher;
+import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.servlet.DispatcherServlet;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@@ -56,12 +64,17 @@ public class AbstractRequestMatcherRegistryTests {
 
 	private TestRequestMatcherRegistry matcherRegistry;
 
+	private WebApplicationContext context;
+
 	@BeforeEach
 	public void setUp() {
 		this.matcherRegistry = new TestRequestMatcherRegistry();
-		ApplicationContext context = mock(ApplicationContext.class);
-		given(context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR);
-		this.matcherRegistry.setApplicationContext(context);
+		this.context = mock(WebApplicationContext.class);
+		ServletContext servletContext = new MockServletContext();
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
+		given(this.context.getBean(ObjectPostProcessor.class)).willReturn(NO_OP_OBJECT_POST_PROCESSOR);
+		given(this.context.getServletContext()).willReturn(servletContext);
+		this.matcherRegistry.setApplicationContext(this.context);
 	}
 
 	@Test
@@ -184,6 +197,32 @@ public class AbstractRequestMatcherRegistryTests {
 						"Please ensure Spring Security & Spring MVC are configured in a shared ApplicationContext");
 	}
 
+	@Test
+	public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		assertThat(requestMatchers).isNotEmpty();
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
+		servletContext.addServlet("servletOne", Servlet.class);
+		servletContext.addServlet("servletTwo", Servlet.class);
+		requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		assertThat(requestMatchers).isNotEmpty();
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).isExactlyInstanceOf(AntPathRequestMatcher.class);
+	}
+
+	@Test
+	public void requestMatchersWhenAmbiguousServletsThenException() {
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
+		servletContext.addServlet("servletTwo", Servlet.class);
+		assertThatExceptionOfType(IllegalArgumentException.class)
+				.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
+	}
+
 	private void mockMvcIntrospector(boolean isPresent) {
 		ApplicationContext context = this.matcherRegistry.getApplicationContext();
 		given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);
@@ -217,4 +256,25 @@ public class AbstractRequestMatcherRegistryTests {
 
 	}
 
+	private static class MockServletContext extends org.springframework.mock.web.MockServletContext {
+
+		private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();
+
+		@NotNull
+		@Override
+		public ServletRegistration.Dynamic addServlet(@NotNull String servletName, Class<? extends Servlet> clazz) {
+			ServletRegistration.Dynamic dynamic = mock(ServletRegistration.Dynamic.class);
+			given(dynamic.getClassName()).willReturn(clazz.getName());
+			this.registrations.put(servletName, dynamic);
+			return dynamic;
+		}
+
+		@NotNull
+		@Override
+		public Map<String, ? extends ServletRegistration> getServletRegistrations() {
+			return this.registrations;
+		}
+
+	}
+
 }

+ 32 - 2
config/src/test/java/org/springframework/security/config/annotation/web/configurers/HttpSecuritySecurityMatchersTests.java

@@ -18,9 +18,14 @@ package org.springframework.security.config.annotation.web.configurers;
 
 import java.lang.reflect.Field;
 import java.lang.reflect.Modifier;
+import java.util.LinkedHashMap;
+import java.util.Map;
 
+import javax.servlet.Servlet;
+import javax.servlet.ServletRegistration;
 import javax.servlet.http.HttpServletResponse;
 
+import org.jetbrains.annotations.NotNull;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
@@ -34,7 +39,6 @@ import org.springframework.core.annotation.Order;
 import org.springframework.mock.web.MockFilterChain;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
-import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity;
@@ -48,12 +52,15 @@ import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
 import org.springframework.web.bind.annotation.RequestMapping;
 import org.springframework.web.bind.annotation.RestController;
 import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
+import org.springframework.web.servlet.DispatcherServlet;
 import org.springframework.web.servlet.config.annotation.EnableWebMvc;
 import org.springframework.web.servlet.config.annotation.PathMatchConfigurer;
 import org.springframework.web.servlet.config.annotation.WebMvcConfigurer;
 import org.springframework.web.servlet.handler.HandlerMappingIntrospector;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
 import static org.springframework.security.config.Customizer.withDefaults;
 
 /**
@@ -233,7 +240,9 @@ public class HttpSecuritySecurityMatchersTests {
 	public void loadConfig(Class<?>... configs) {
 		this.context = new AnnotationConfigWebApplicationContext();
 		this.context.register(configs);
-		this.context.setServletContext(new MockServletContext());
+		MockServletContext servletContext = new MockServletContext();
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
+		this.context.setServletContext(servletContext);
 		this.context.refresh();
 		this.context.getAutowireCapableBeanFactory().autowireBean(this);
 	}
@@ -564,4 +573,25 @@ public class HttpSecuritySecurityMatchersTests {
 
 	}
 
+	private static class MockServletContext extends org.springframework.mock.web.MockServletContext {
+
+		private final Map<String, ServletRegistration> registrations = new LinkedHashMap<>();
+
+		@NotNull
+		@Override
+		public ServletRegistration.Dynamic addServlet(@NotNull String servletName, Class<? extends Servlet> clazz) {
+			ServletRegistration.Dynamic dynamic = mock(ServletRegistration.Dynamic.class);
+			given(dynamic.getClassName()).willReturn(clazz.getName());
+			this.registrations.put(servletName, dynamic);
+			return dynamic;
+		}
+
+		@NotNull
+		@Override
+		public Map<String, ? extends ServletRegistration> getServletRegistrations() {
+			return this.registrations;
+		}
+
+	}
+
 }