2
0
Эх сурвалжийг харах

Merge branch '6.0.x' into 6.1.x

Closes gh-13722
Josh Cummings 2 жил өмнө
parent
commit
0df1884372

+ 32 - 4
config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

@@ -18,6 +18,8 @@ package org.springframework.security.config.annotation.web;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
+import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 
@@ -194,18 +196,31 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 		if (servletContext == null) {
 			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
 		}
-		Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations();
-		if (registrations == null) {
+		Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
+		if (registrations.isEmpty()) {
 			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
 		}
 		if (!hasDispatcherServlet(registrations)) {
 			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
 		}
-		Assert.isTrue(registrations.size() == 1,
-				"This method cannot decide whether these patterns are Spring MVC patterns or not. If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); otherwise, please use requestMatchers(AntPathRequestMatcher).");
+		if (registrations.size() > 1) {
+			String errorMessage = computeErrorMessage(registrations.values());
+			throw new IllegalArgumentException(errorMessage);
+		}
 		return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
 	}
 
+	private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
+		Map<String, ServletRegistration> mappable = new LinkedHashMap<>();
+		for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations()
+				.entrySet()) {
+			if (!entry.getValue().getMappings().isEmpty()) {
+				mappable.put(entry.getKey(), entry.getValue());
+			}
+		}
+		return mappable;
+	}
+
 	private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
 		if (registrations == null) {
 			return false;
@@ -226,6 +241,19 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 		return false;
 	}
 
+	private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
+		String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
+				+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
+				+ "otherwise, please use requestMatchers(AntPathRequestMatcher).\n\n"
+				+ "This is because there is more than one mappable servlet in your servlet context: %s.\n\n"
+				+ "For each MvcRequestMatcher, call MvcRequestMatcher#setServletPath to indicate the servlet path.";
+		Map<String, Collection<String>> mappings = new LinkedHashMap<>();
+		for (ServletRegistration registration : registrations) {
+			mappings.put(registration.getClassName(), registration.getMappings());
+		}
+		return String.format(template, mappings);
+	}
+
 	/**
 	 * <p>
 	 * If the {@link HandlerMappingIntrospector} is available in the classpath, maps to an

+ 8 - 3
config/src/test/java/org/springframework/security/config/MockServletContext.java

@@ -16,8 +16,10 @@
 
 package org.springframework.security.config;
 
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.Map;
 import java.util.Set;
 
@@ -35,7 +37,7 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
 
 	public static MockServletContext mvc() {
 		MockServletContext servletContext = new MockServletContext();
-		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
 		return servletContext;
 	}
 
@@ -59,6 +61,8 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
 
 		private final Class<?> clazz;
 
+		private final Set<String> mappings = new LinkedHashSet<>();
+
 		MockServletRegistration(String name, Class<?> clazz) {
 			this.name = name;
 			this.clazz = clazz;
@@ -91,12 +95,13 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
 
 		@Override
 		public Set<String> addMapping(String... urlPatterns) {
-			return null;
+			this.mappings.addAll(Arrays.asList(urlPatterns));
+			return this.mappings;
 		}
 
 		@Override
 		public Collection<String> getMappings() {
-			return null;
+			return this.mappings;
 		}
 
 		@Override

+ 14 - 2
config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java

@@ -174,12 +174,24 @@ public class AbstractRequestMatcherRegistryTests {
 	public void requestMatchersWhenAmbiguousServletsThenException() {
 		MockServletContext servletContext = new MockServletContext();
 		given(this.context.getServletContext()).willReturn(servletContext);
-		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class);
-		servletContext.addServlet("servletTwo", Servlet.class);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
+		servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**");
 		assertThatExceptionOfType(IllegalArgumentException.class)
 				.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
 	}
 
+	@Test
+	public void requestMatchersWhenUnmappableServletsThenSkips() {
+		mockMvcIntrospector(true);
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
+		servletContext.addServlet("servletTwo", Servlet.class);
+		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
+	}
+
 	private void mockMvcIntrospector(boolean isPresent) {
 		ApplicationContext context = this.matcherRegistry.getApplicationContext();
 		given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);