Explorar o código

Merge branch '6.0.x' into 6.1.x

Closes gh-14085
Josh Cummings hai 1 ano
pai
achega
624dcafcf2

+ 130 - 15
config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@ import java.util.Map;
 import jakarta.servlet.DispatcherType;
 import jakarta.servlet.ServletContext;
 import jakarta.servlet.ServletRegistration;
+import jakarta.servlet.http.HttpServletRequest;
 
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.context.ApplicationContext;
@@ -203,11 +204,30 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 		if (!hasDispatcherServlet(registrations)) {
 			return requestMatchers(RequestMatchers.antMatchersAsArray(method, patterns));
 		}
-		if (registrations.size() > 1) {
-			String errorMessage = computeErrorMessage(registrations.values());
-			throw new IllegalArgumentException(errorMessage);
+		ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
+		if (dispatcherServlet != null) {
+			if (registrations.size() == 1) {
+				return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
+			}
+			List<RequestMatcher> matchers = new ArrayList<>();
+			for (String pattern : patterns) {
+				AntPathRequestMatcher ant = new AntPathRequestMatcher(pattern, (method != null) ? method.name() : null);
+				MvcRequestMatcher mvc = createMvcMatchers(method, pattern).get(0);
+				matchers.add(new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext));
+			}
+			return requestMatchers(matchers.toArray(new RequestMatcher[0]));
 		}
-		return requestMatchers(createMvcMatchers(method, patterns).toArray(new RequestMatcher[0]));
+		dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
+		if (dispatcherServlet != null) {
+			String mapping = dispatcherServlet.getMappings().iterator().next();
+			List<MvcRequestMatcher> matchers = createMvcMatchers(method, patterns);
+			for (MvcRequestMatcher matcher : matchers) {
+				matcher.setServletPath(mapping.substring(0, mapping.length() - 2));
+			}
+			return requestMatchers(matchers.toArray(new RequestMatcher[0]));
+		}
+		String errorMessage = computeErrorMessage(registrations.values());
+		throw new IllegalArgumentException(errorMessage);
 	}
 
 	private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
@@ -225,22 +245,66 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 		if (registrations == null) {
 			return false;
 		}
-		Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
-				null);
 		for (ServletRegistration registration : registrations.values()) {
-			try {
-				Class<?> clazz = Class.forName(registration.getClassName());
-				if (dispatcherServlet.isAssignableFrom(clazz)) {
-					return true;
-				}
-			}
-			catch (ClassNotFoundException ex) {
-				return false;
+			if (isDispatcherServlet(registration)) {
+				return true;
 			}
 		}
 		return false;
 	}
 
+	private ServletRegistration requireOneRootDispatcherServlet(
+			Map<String, ? extends ServletRegistration> registrations) {
+		ServletRegistration rootDispatcherServlet = null;
+		for (ServletRegistration registration : registrations.values()) {
+			if (!isDispatcherServlet(registration)) {
+				continue;
+			}
+			if (registration.getMappings().size() > 1) {
+				return null;
+			}
+			if (!"/".equals(registration.getMappings().iterator().next())) {
+				return null;
+			}
+			rootDispatcherServlet = registration;
+		}
+		return rootDispatcherServlet;
+	}
+
+	private ServletRegistration requireOnlyPathMappedDispatcherServlet(
+			Map<String, ? extends ServletRegistration> registrations) {
+		ServletRegistration pathDispatcherServlet = null;
+		for (ServletRegistration registration : registrations.values()) {
+			if (!isDispatcherServlet(registration)) {
+				return null;
+			}
+			if (registration.getMappings().size() > 1) {
+				return null;
+			}
+			String mapping = registration.getMappings().iterator().next();
+			if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
+				return null;
+			}
+			if (pathDispatcherServlet != null) {
+				return null;
+			}
+			pathDispatcherServlet = registration;
+		}
+		return pathDispatcherServlet;
+	}
+
+	private boolean isDispatcherServlet(ServletRegistration registration) {
+		Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
+				null);
+		try {
+			Class<?> clazz = Class.forName(registration.getClassName());
+			return dispatcherServlet.isAssignableFrom(clazz);
+		}
+		catch (ClassNotFoundException ex) {
+			return false;
+		}
+	}
+
 	private String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
 		String template = "This method cannot decide whether these patterns are Spring MVC patterns or not. "
 				+ "If this endpoint is a Spring MVC endpoint, please use requestMatchers(MvcRequestMatcher); "
@@ -380,4 +444,55 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 
 	}
 
+	static class DispatcherServletDelegatingRequestMatcher implements RequestMatcher {
+
+		private final AntPathRequestMatcher ant;
+
+		private final MvcRequestMatcher mvc;
+
+		private final ServletContext servletContext;
+
+		DispatcherServletDelegatingRequestMatcher(AntPathRequestMatcher ant, MvcRequestMatcher mvc,
+				ServletContext servletContext) {
+			this.ant = ant;
+			this.mvc = mvc;
+			this.servletContext = servletContext;
+		}
+
+		@Override
+		public boolean matches(HttpServletRequest request) {
+			String name = request.getHttpServletMapping().getServletName();
+			ServletRegistration registration = this.servletContext.getServletRegistration(name);
+			Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
+			if (isDispatcherServlet(registration)) {
+				return this.mvc.matches(request);
+			}
+			return this.ant.matches(request);
+		}
+
+		@Override
+		public MatchResult matcher(HttpServletRequest request) {
+			String name = request.getHttpServletMapping().getServletName();
+			ServletRegistration registration = this.servletContext.getServletRegistration(name);
+			Assert.notNull(registration, "Failed to find servlet [" + name + "] in the servlet context");
+			if (isDispatcherServlet(registration)) {
+				return this.mvc.matcher(request);
+			}
+			return this.ant.matcher(request);
+		}
+
+		private boolean isDispatcherServlet(ServletRegistration registration) {
+			Class<?> dispatcherServlet = ClassUtils
+				.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
+			try {
+				Class<?> clazz = Class.forName(registration.getClassName());
+				return dispatcherServlet.isAssignableFrom(clazz);
+			}
+			catch (ClassNotFoundException ex) {
+				return false;
+			}
+		}
+
+	}
+
 }

+ 5 - 0
config/src/test/java/org/springframework/security/config/MockServletContext.java

@@ -55,6 +55,11 @@ public class MockServletContext extends org.springframework.mock.web.MockServlet
 		return this.registrations;
 	}
 
+	@Override
+	public ServletRegistration getServletRegistration(String servletName) {
+		return this.registrations.get(servletName);
+	}
+
 	private static class MockServletRegistration implements ServletRegistration.Dynamic {
 
 		private final String name;

+ 46 - 0
config/src/test/java/org/springframework/security/config/TestMockHttpServletMappings.java

@@ -0,0 +1,46 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config;
+
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.MappingMatch;
+
+import org.springframework.mock.web.MockHttpServletMapping;
+
+public final class TestMockHttpServletMappings {
+
+	private TestMockHttpServletMappings() {
+
+	}
+
+	public static MockHttpServletMapping extension(HttpServletRequest request, String extension) {
+		String uri = request.getRequestURI();
+		String matchValue = uri.substring(0, uri.lastIndexOf(extension));
+		return new MockHttpServletMapping(matchValue, "*" + extension, "extension", MappingMatch.EXTENSION);
+	}
+
+	public static MockHttpServletMapping path(HttpServletRequest request, String path) {
+		String uri = request.getRequestURI();
+		String matchValue = uri.substring(path.length());
+		return new MockHttpServletMapping(matchValue, path + "/*", "path", MappingMatch.PATH);
+	}
+
+	public static MockHttpServletMapping defaultMapping() {
+		return new MockHttpServletMapping("", "/", "default", MappingMatch.DEFAULT);
+	}
+
+}

+ 111 - 2
config/src/test/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistryTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,14 +20,18 @@ import java.util.List;
 
 import jakarta.servlet.DispatcherType;
 import jakarta.servlet.Servlet;
+import jakarta.servlet.http.HttpServletMapping;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 
 import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.context.ApplicationContext;
 import org.springframework.http.HttpMethod;
+import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.security.config.MockServletContext;
+import org.springframework.security.config.TestMockHttpServletMappings;
 import org.springframework.security.config.annotation.ObjectPostProcessor;
+import org.springframework.security.config.annotation.web.AbstractRequestMatcherRegistry.DispatcherServletDelegatingRequestMatcher;
 import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
 import org.springframework.security.web.util.matcher.DispatcherTypeRequestMatcher;
@@ -40,6 +44,9 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
 
 /**
  * Tests for {@link AbstractRequestMatcherRegistry}.
@@ -159,6 +166,8 @@ public class AbstractRequestMatcherRegistryTests {
 	public void requestMatchersWhenNoDispatcherServletThenAntPathRequestMatcherType() {
 		MockServletContext servletContext = new MockServletContext();
 		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("servletOne", Servlet.class).addMapping("/one");
+		servletContext.addServlet("servletTwo", Servlet.class).addMapping("/two");
 		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
 		assertThat(requestMatchers).isNotEmpty();
 		assertThat(requestMatchers).hasSize(1);
@@ -176,7 +185,26 @@ public class AbstractRequestMatcherRegistryTests {
 		MockServletContext servletContext = new MockServletContext();
 		given(this.context.getServletContext()).willReturn(servletContext);
 		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
-		servletContext.addServlet("servletTwo", Servlet.class).addMapping("/servlet/**");
+		servletContext.addServlet("servletTwo", DispatcherServlet.class).addMapping("/servlet/*");
+		assertThatExceptionOfType(IllegalArgumentException.class)
+			.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
+	}
+
+	@Test
+	public void requestMatchersWhenMultipleDispatcherServletMappingsThenException() {
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/", "/mvc/*");
+		assertThatExceptionOfType(IllegalArgumentException.class)
+			.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
+	}
+
+	@Test
+	public void requestMatchersWhenPathDispatcherServletAndOtherServletsThenException() {
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
+		servletContext.addServlet("default", Servlet.class).addMapping("/");
 		assertThatExceptionOfType(IllegalArgumentException.class)
 			.isThrownBy(() -> this.matcherRegistry.requestMatchers("/**"));
 	}
@@ -193,6 +221,87 @@ public class AbstractRequestMatcherRegistryTests {
 		assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
 	}
 
+	@Test
+	public void requestMatchersWhenOnlyDispatcherServletThenAllows() {
+		mockMvcIntrospector(true);
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/mvc/*");
+		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).isInstanceOf(MvcRequestMatcher.class);
+	}
+
+	@Test
+	public void requestMatchersWhenImplicitServletsThenAllows() {
+		mockMvcIntrospector(true);
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("defaultServlet", Servlet.class);
+		servletContext.addServlet("jspServlet", Servlet.class).addMapping("*.jsp", "*.jspx");
+		servletContext.addServlet("dispatcherServlet", DispatcherServlet.class).addMapping("/");
+		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/**");
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
+	}
+
+	@Test
+	public void requestMatchersWhenPathBasedNonDispatcherServletThenAllows() {
+		mockMvcIntrospector(true);
+		MockServletContext servletContext = new MockServletContext();
+		given(this.context.getServletContext()).willReturn(servletContext);
+		servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
+		servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
+		List<RequestMatcher> requestMatchers = this.matcherRegistry.requestMatchers("/services/*");
+		assertThat(requestMatchers).hasSize(1);
+		assertThat(requestMatchers.get(0)).isInstanceOf(DispatcherServletDelegatingRequestMatcher.class);
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint") {
+			@Override
+			public HttpServletMapping getHttpServletMapping() {
+				return TestMockHttpServletMappings.defaultMapping();
+			}
+		};
+		assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
+		request = new MockHttpServletRequest("GET", "/services/endpoint") {
+			@Override
+			public HttpServletMapping getHttpServletMapping() {
+				return TestMockHttpServletMappings.path(this, "/services");
+			}
+		};
+		request.setServletPath("/services");
+		request.setPathInfo("/endpoint");
+		assertThat(requestMatchers.get(0).matcher(request).isMatch()).isTrue();
+	}
+
+	@Test
+	public void matchesWhenDispatcherServletThenMvc() {
+		MockServletContext servletContext = new MockServletContext();
+		servletContext.addServlet("default", DispatcherServlet.class).addMapping("/");
+		servletContext.addServlet("path", Servlet.class).addMapping("/services/*");
+		MvcRequestMatcher mvc = mock(MvcRequestMatcher.class);
+		AntPathRequestMatcher ant = mock(AntPathRequestMatcher.class);
+		DispatcherServletDelegatingRequestMatcher requestMatcher = new DispatcherServletDelegatingRequestMatcher(ant,
+				mvc, servletContext);
+		MockHttpServletRequest request = new MockHttpServletRequest("GET", "/services/endpoint") {
+			@Override
+			public HttpServletMapping getHttpServletMapping() {
+				return TestMockHttpServletMappings.defaultMapping();
+			}
+		};
+		assertThat(requestMatcher.matches(request)).isFalse();
+		verify(mvc).matches(request);
+		verifyNoInteractions(ant);
+		request = new MockHttpServletRequest("GET", "/services/endpoint") {
+			@Override
+			public HttpServletMapping getHttpServletMapping() {
+				return TestMockHttpServletMappings.path(this, "/services");
+			}
+		};
+		assertThat(requestMatcher.matches(request)).isFalse();
+		verify(ant).matches(request);
+		verifyNoMoreInteractions(mvc);
+	}
+
 	private void mockMvcIntrospector(boolean isPresent) {
 		ApplicationContext context = this.matcherRegistry.getApplicationContext();
 		given(context.containsBean("mvcHandlerMappingIntrospector")).willReturn(isPresent);