فهرست منبع

Polish requestMatchers Logic

Issue gh-13551
Josh Cummings 7 ماه پیش
والد
کامیت
75a35793dc

+ 18 - 89
config/src/main/java/org/springframework/security/config/annotation/web/AbstractRequestMatcherRegistry.java

@@ -40,6 +40,7 @@ import org.springframework.core.ResolvableType;
 import org.springframework.http.HttpMethod;
 import org.springframework.lang.Nullable;
 import org.springframework.security.config.ObjectPostProcessor;
+import org.springframework.security.config.annotation.web.ServletRegistrationsSupport.RegistrationMapping;
 import org.springframework.security.config.annotation.web.configurers.AbstractConfigAttributeRequestMatcherRegistry;
 import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
 import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -235,103 +236,31 @@ public abstract class AbstractRequestMatcherRegistry<C> {
 	}
 
 	private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) {
-		Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
-		if (registrations.isEmpty()) {
+		ServletRegistrationsSupport registrations = new ServletRegistrationsSupport(servletContext);
+		Collection<RegistrationMapping> mappings = registrations.mappings();
+		if (mappings.isEmpty()) {
 			return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
 		}
-		if (!hasDispatcherServlet(registrations)) {
+		Collection<RegistrationMapping> dispatcherServletMappings = registrations.dispatcherServletMappings();
+		if (dispatcherServletMappings.isEmpty()) {
 			return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
 		}
-		ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
-		if (dispatcherServlet != null) {
-			if (registrations.size() == 1) {
-				return mvc;
-			}
-			return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext);
+		if (dispatcherServletMappings.size() > 1) {
+			String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values());
+			throw new IllegalArgumentException(errorMessage);
 		}
-		dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
-		if (dispatcherServlet != null) {
-			String mapping = dispatcherServlet.getMappings().iterator().next();
-			mvc.setServletPath(mapping.substring(0, mapping.length() - 2));
-			return mvc;
-		}
-		String errorMessage = computeErrorMessage(registrations.values());
-		throw new IllegalArgumentException(errorMessage);
-	}
-
-	private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
-		Map<String, ServletRegistration> mappable = new LinkedHashMap<>();
-		for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations()
-			.entrySet()) {
-			if (!entry.getValue().getMappings().isEmpty()) {
-				mappable.put(entry.getKey(), entry.getValue());
-			}
+		RegistrationMapping dispatcherServlet = dispatcherServletMappings.iterator().next();
+		if (mappings.size() > 1 && !dispatcherServlet.isDefault()) {
+			String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values());
+			throw new IllegalArgumentException(errorMessage);
 		}
-		return mappable;
-	}
-
-	private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
-		if (registrations == null) {
-			return false;
-		}
-		for (ServletRegistration registration : registrations.values()) {
-			if (isDispatcherServlet(registration)) {
-				return true;
-			}
-		}
-		return false;
-	}
-
-	private ServletRegistration requireOneRootDispatcherServlet(
-			Map<String, ? extends ServletRegistration> registrations) {
-		ServletRegistration rootDispatcherServlet = null;
-		for (ServletRegistration registration : registrations.values()) {
-			if (!isDispatcherServlet(registration)) {
-				continue;
-			}
-			if (registration.getMappings().size() > 1) {
-				return null;
-			}
-			if (!"/".equals(registration.getMappings().iterator().next())) {
-				return null;
-			}
-			rootDispatcherServlet = registration;
-		}
-		return rootDispatcherServlet;
-	}
-
-	private ServletRegistration requireOnlyPathMappedDispatcherServlet(
-			Map<String, ? extends ServletRegistration> registrations) {
-		ServletRegistration pathDispatcherServlet = null;
-		for (ServletRegistration registration : registrations.values()) {
-			if (!isDispatcherServlet(registration)) {
-				return null;
-			}
-			if (registration.getMappings().size() > 1) {
-				return null;
-			}
-			String mapping = registration.getMappings().iterator().next();
-			if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
-				return null;
-			}
-			if (pathDispatcherServlet != null) {
-				return null;
+		if (dispatcherServlet.isDefault()) {
+			if (mappings.size() == 1) {
+				return mvc;
 			}
-			pathDispatcherServlet = registration;
-		}
-		return pathDispatcherServlet;
-	}
-
-	private boolean isDispatcherServlet(ServletRegistration registration) {
-		Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
-				null);
-		try {
-			Class<?> clazz = Class.forName(registration.getClassName());
-			return dispatcherServlet.isAssignableFrom(clazz);
-		}
-		catch (ClassNotFoundException ex) {
-			return false;
+			return new DispatcherServletDelegatingRequestMatcher(ant, mvc);
 		}
+		return mvc;
 	}
 
 	private static String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {

+ 77 - 0
config/src/main/java/org/springframework/security/config/annotation/web/ServletRegistrationsSupport.java

@@ -0,0 +1,77 @@
+/*
+ * Copyright 2002-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.web;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Map;
+
+import jakarta.servlet.ServletContext;
+import jakarta.servlet.ServletRegistration;
+
+import org.springframework.util.ClassUtils;
+
+class ServletRegistrationsSupport {
+
+	private final Collection<RegistrationMapping> registrations;
+
+	ServletRegistrationsSupport(ServletContext servletContext) {
+		Map<String, ? extends ServletRegistration> registrations = servletContext.getServletRegistrations();
+		Collection<RegistrationMapping> mappings = new ArrayList<>();
+		for (Map.Entry<String, ? extends ServletRegistration> entry : registrations.entrySet()) {
+			if (!entry.getValue().getMappings().isEmpty()) {
+				for (String mapping : entry.getValue().getMappings()) {
+					mappings.add(new RegistrationMapping(entry.getValue(), mapping));
+				}
+			}
+		}
+		this.registrations = mappings;
+	}
+
+	Collection<RegistrationMapping> dispatcherServletMappings() {
+		Collection<RegistrationMapping> mappings = new ArrayList<>();
+		for (RegistrationMapping registration : this.registrations) {
+			if (registration.isDispatcherServlet()) {
+				mappings.add(registration);
+			}
+		}
+		return mappings;
+	}
+
+	Collection<RegistrationMapping> mappings() {
+		return this.registrations;
+	}
+
+	record RegistrationMapping(ServletRegistration registration, String mapping) {
+		boolean isDispatcherServlet() {
+			Class<?> dispatcherServlet = ClassUtils
+				.resolveClassName("org.springframework.web.servlet.DispatcherServlet", null);
+			try {
+				Class<?> clazz = Class.forName(this.registration.getClassName());
+				return dispatcherServlet.isAssignableFrom(clazz);
+			}
+			catch (ClassNotFoundException ex) {
+				return false;
+			}
+		}
+
+		boolean isDefault() {
+			return "/".equals(this.mapping);
+		}
+	}
+
+}