Browse Source

Ignore Unmappable Servlets

Closes gh-13666
Josh Cummings 2 years ago
parent
commit
ed96e2cddf

+ 13 - 2
config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

@@ -18,6 +18,7 @@ package org.springframework.security.config.annotation.web;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -312,8 +313,8 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 		if (servletContext == null) {
 			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
 		}
-		Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations();
-		if (registrations == null) {
+		Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
+		if (registrations.isEmpty()) {
 			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
 		}
 		if (!hasDispatcherServlet(registrations)) {
@@ -324,6 +325,16 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 		return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
 	}
 
+	private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
+		Map<String, ServletRegistration> mappable = new LinkedHashMap<>();
+		for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations().entrySet()) {
+			if (!entry.getValue().getMappings().isEmpty()) {
+				mappable.put(entry.getKey(), entry.getValue());
+			}
+		}
+		return mappable;
+	}
+
 	private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
 		if (registrations == null) {
 			return false;

+ 8 - 3
config/src/test/java/org/springframework/security/config/MockServletContext.java

@@ -16,8 +16,10 @@
 
 package org.springframework.security.config;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 
@@ -35,7 +37,7 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
 
 	public static MockServletContext mvc() {
 		MockServletContext servletContext = new MockServletContext();
-		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
 		return servletContext;
 	}
 
@@ -59,6 +61,8 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
 
 		private final Class<?> clazz;
 
+		private final Set<String> mappings = new LinkedHashSet<>();
+
 		MockServletRegistration(String name, Class<?> clazz) {
 			this.name = name;
 			this.clazz = clazz;
@@ -91,12 +95,13 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
 
 		@Override
 		public Set<String> addMapping(String... urlPatterns) {
-			return null;
+			this.mappings.addAll(Arrays.asList(urlPatterns));
+			return this.mappings;
 		}
 
 		@Override
 		public Collection<String> getMappings() {
-			return null;
+			return this.mappings;
 		}
 
 		@Override

+ 13 - 2
config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java

@@ -211,12 +211,23 @@ public class AbstractRequestMatcherRegistryTests {
 	public void requestMatchersWhenAmbiguousServletsThenException() {
 		MockServletContext servletContext = new MockServletContext();
 		given(this.context.getServletContext()).willReturn(servletContext);
-		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
-		servletContext.addServlet("servletTwo", Servlet.class);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
+		servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**");
 		assertThatExceptionOfType(IllegalArgumentException.class)
 				.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
 	}
 
+	@Test
+	public void requestMatchersWhenUnmappableServletsThenSkips() {
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
+		servletContext.addServlet("servletTwo", Servlet.class);
+		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
+	}
+
 	private void mockMvcIntrospector(boolean isPresent) {
 		ApplicationContext context = this.matcherRegistry.getApplicationContext();
 		given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);