|
@@ -40,6 +40,7 @@ import org.springframework.core.ResolvableType;
|
|
|
import org.springframework.http.HttpMethod;
|
|
|
import org.springframework.lang.Nullable;
|
|
|
import org.springframework.security.config.ObjectPostProcessor;
|
|
|
+import org.springframework.security.config.annotation.web.ServletRegistrationsSupport.RegistrationMapping;
|
|
|
import org.springframework.security.config.annotation.web.configurers.AbstractConfigAttributeRequestMatcherRegistry;
|
|
|
import org.springframework.security.web.servlet.util.matcher.MvcRequestMatcher;
|
|
|
import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
|
@@ -235,103 +236,31 @@ public abstract class AbstractRequestMatcherRegistry<C> {
|
|
|
}
|
|
|
|
|
|
private RequestMatcher resolve(AntPathRequestMatcher ant, MvcRequestMatcher mvc, ServletContext servletContext) {
|
|
|
- Map<String, ? extends ServletRegistration> registrations = mappableServletRegistrations(servletContext);
|
|
|
- if (registrations.isEmpty()) {
|
|
|
+ ServletRegistrationsSupport registrations = new ServletRegistrationsSupport(servletContext);
|
|
|
+ Collection<RegistrationMapping> mappings = registrations.mappings();
|
|
|
+ if (mappings.isEmpty()) {
|
|
|
return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
|
|
|
}
|
|
|
- if (!hasDispatcherServlet(registrations)) {
|
|
|
+ Collection<RegistrationMapping> dispatcherServletMappings = registrations.dispatcherServletMappings();
|
|
|
+ if (dispatcherServletMappings.isEmpty()) {
|
|
|
return new DispatcherServletDelegatingRequestMatcher(ant, mvc, new MockMvcRequestMatcher());
|
|
|
}
|
|
|
- ServletRegistration dispatcherServlet = requireOneRootDispatcherServlet(registrations);
|
|
|
- if (dispatcherServlet != null) {
|
|
|
- if (registrations.size() == 1) {
|
|
|
- return mvc;
|
|
|
- }
|
|
|
- return new DispatcherServletDelegatingRequestMatcher(ant, mvc, servletContext);
|
|
|
+ if (dispatcherServletMappings.size() > 1) {
|
|
|
+ String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values());
|
|
|
+ throw new IllegalArgumentException(errorMessage);
|
|
|
}
|
|
|
- dispatcherServlet = requireOnlyPathMappedDispatcherServlet(registrations);
|
|
|
- if (dispatcherServlet != null) {
|
|
|
- String mapping = dispatcherServlet.getMappings().iterator().next();
|
|
|
- mvc.setServletPath(mapping.substring(0, mapping.length() - 2));
|
|
|
- return mvc;
|
|
|
- }
|
|
|
- String errorMessage = computeErrorMessage(registrations.values());
|
|
|
- throw new IllegalArgumentException(errorMessage);
|
|
|
- }
|
|
|
-
|
|
|
- private Map<String, ? extends ServletRegistration> mappableServletRegistrations(ServletContext servletContext) {
|
|
|
- Map<String, ServletRegistration> mappable = new LinkedHashMap<>();
|
|
|
- for (Map.Entry<String, ? extends ServletRegistration> entry : servletContext.getServletRegistrations()
|
|
|
- .entrySet()) {
|
|
|
- if (!entry.getValue().getMappings().isEmpty()) {
|
|
|
- mappable.put(entry.getKey(), entry.getValue());
|
|
|
- }
|
|
|
+ RegistrationMapping dispatcherServlet = dispatcherServletMappings.iterator().next();
|
|
|
+ if (mappings.size() > 1 && !dispatcherServlet.isDefault()) {
|
|
|
+ String errorMessage = computeErrorMessage(servletContext.getServletRegistrations().values());
|
|
|
+ throw new IllegalArgumentException(errorMessage);
|
|
|
}
|
|
|
- return mappable;
|
|
|
- }
|
|
|
-
|
|
|
- private boolean hasDispatcherServlet(Map<String, ? extends ServletRegistration> registrations) {
|
|
|
- if (registrations == null) {
|
|
|
- return false;
|
|
|
- }
|
|
|
- for (ServletRegistration registration : registrations.values()) {
|
|
|
- if (isDispatcherServlet(registration)) {
|
|
|
- return true;
|
|
|
- }
|
|
|
- }
|
|
|
- return false;
|
|
|
- }
|
|
|
-
|
|
|
- private ServletRegistration requireOneRootDispatcherServlet(
|
|
|
- Map<String, ? extends ServletRegistration> registrations) {
|
|
|
- ServletRegistration rootDispatcherServlet = null;
|
|
|
- for (ServletRegistration registration : registrations.values()) {
|
|
|
- if (!isDispatcherServlet(registration)) {
|
|
|
- continue;
|
|
|
- }
|
|
|
- if (registration.getMappings().size() > 1) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- if (!"/".equals(registration.getMappings().iterator().next())) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- rootDispatcherServlet = registration;
|
|
|
- }
|
|
|
- return rootDispatcherServlet;
|
|
|
- }
|
|
|
-
|
|
|
- private ServletRegistration requireOnlyPathMappedDispatcherServlet(
|
|
|
- Map<String, ? extends ServletRegistration> registrations) {
|
|
|
- ServletRegistration pathDispatcherServlet = null;
|
|
|
- for (ServletRegistration registration : registrations.values()) {
|
|
|
- if (!isDispatcherServlet(registration)) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- if (registration.getMappings().size() > 1) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- String mapping = registration.getMappings().iterator().next();
|
|
|
- if (!mapping.startsWith("/") || !mapping.endsWith("/*")) {
|
|
|
- return null;
|
|
|
- }
|
|
|
- if (pathDispatcherServlet != null) {
|
|
|
- return null;
|
|
|
+ if (dispatcherServlet.isDefault()) {
|
|
|
+ if (mappings.size() == 1) {
|
|
|
+ return mvc;
|
|
|
}
|
|
|
- pathDispatcherServlet = registration;
|
|
|
- }
|
|
|
- return pathDispatcherServlet;
|
|
|
- }
|
|
|
-
|
|
|
- private boolean isDispatcherServlet(ServletRegistration registration) {
|
|
|
- Class<?> dispatcherServlet = ClassUtils.resolveClassName("org.springframework.web.servlet.DispatcherServlet",
|
|
|
- null);
|
|
|
- try {
|
|
|
- Class<?> clazz = Class.forName(registration.getClassName());
|
|
|
- return dispatcherServlet.isAssignableFrom(clazz);
|
|
|
- }
|
|
|
- catch (ClassNotFoundException ex) {
|
|
|
- return false;
|
|
|
+ return new DispatcherServletDelegatingRequestMatcher(ant, mvc);
|
|
|
}
|
|
|
+ return mvc;
|
|
|
}
|
|
|
|
|
|
private static String computeErrorMessage(Collection<? extends ServletRegistration> registrations) {
|