浏览代码

PathPatternRequestParser Retains Servlet Path

Issue gh-16765
Josh Cummings 5 月之前
父节点
当前提交
c53bf2befe

+ 17 - 83
web/src/main/java/org/springframework/security/web/servlet/util/matcher/PathPatternRequestMatcher.java

@@ -16,14 +16,8 @@
 
 package org.springframework.security.web.servlet.util.matcher;
 
-import java.util.Collection;
-import java.util.LinkedHashMap;
-import java.util.Map;
 import java.util.Objects;
-import java.util.concurrent.atomic.AtomicReference;
 
-import jakarta.servlet.ServletContext;
-import jakarta.servlet.ServletRegistration;
 import jakarta.servlet.http.HttpServletRequest;
 
 import org.springframework.http.HttpMethod;
@@ -40,11 +34,12 @@ import org.springframework.web.util.pattern.PathPatternParser;
 
 /**
  * A {@link RequestMatcher} that uses {@link PathPattern}s to match against each
- * {@link HttpServletRequest}. The provided path should be relative to the servlet (that
- * is, it should exclude any context or servlet path).
+ * {@link HttpServletRequest}. The provided path should be relative to the context path
+ * (that is, it should exclude any context path).
  *
  * <p>
- * To also match the servlet, please see {@link PathPatternRequestMatcher#servletPath}
+ * You can provide the servlet path in {@link PathPatternRequestMatcher#servletPath} and
+ * reuse for multiple matchers.
  *
  * <p>
  * Note that the {@link org.springframework.web.servlet.HandlerMapping} that contains the
@@ -113,7 +108,7 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 		if (!this.method.matches(request)) {
 			return MatchResult.notMatch();
 		}
-		PathContainer path = getRequestPath(request).pathWithinApplication();
+		PathContainer path = getPathContainer(request);
 		PathPattern.PathMatchInfo info = this.pattern.matchAndExtract(path);
 		return (info != null) ? MatchResult.match(info.getUriVariables()) : MatchResult.notMatch();
 	}
@@ -122,11 +117,7 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 		this.method = method;
 	}
 
-	void setServletPath(RequestMatcher servletPath) {
-		this.servletPath = servletPath;
-	}
-
-	private RequestPath getRequestPath(HttpServletRequest request) {
+	private PathContainer getPathContainer(HttpServletRequest request) {
 		RequestPath path;
 		if (ServletRequestPathUtils.hasParsedRequestPath(request)) {
 			path = ServletRequestPathUtils.getParsedRequestPath(request);
@@ -135,7 +126,8 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 			path = ServletRequestPathUtils.parseAndCache(request);
 			ServletRequestPathUtils.clearParsedRequestPath(request);
 		}
-		return path;
+		PathContainer contextPath = path.contextPath();
+		return path.subPath(contextPath.elements().size());
 	}
 
 	/**
@@ -166,9 +158,6 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 		if (this.method instanceof HttpMethodRequestMatcher m) {
 			request.append(m.method.name()).append(' ');
 		}
-		if (this.servletPath instanceof ServletPathRequestMatcher s) {
-			request.append(s.path);
-		}
 		return "PathPattern [" + request + this.pattern + "]";
 	}
 
@@ -194,17 +183,17 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 
 		private final PathPatternParser parser;
 
-		private final RequestMatcher servletPath;
+		private final String servletPath;
 
 		Builder() {
 			this(PathPatternParser.defaultInstance);
 		}
 
 		Builder(PathPatternParser parser) {
-			this(parser, AnyRequestMatcher.INSTANCE);
+			this(parser, "");
 		}
 
-		Builder(PathPatternParser parser, RequestMatcher servletPath) {
+		Builder(PathPatternParser parser, String servletPath) {
 			this.parser = parser;
 			this.servletPath = servletPath;
 		}
@@ -215,7 +204,11 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 		 * @return the {@link Builder} for more configuration
 		 */
 		public Builder servletPath(String servletPath) {
-			return new Builder(this.parser, new ServletPathRequestMatcher(servletPath));
+			Assert.notNull(servletPath, "servletPath cannot be null");
+			Assert.isTrue(servletPath.startsWith("/"), "servletPath must start with '/'");
+			Assert.isTrue(!servletPath.endsWith("/"), "servletPath must not end with a slash");
+			Assert.isTrue(!servletPath.contains("*"), "servletPath must not contain a star");
+			return new Builder(this.parser, servletPath);
 		}
 
 		/**
@@ -286,14 +279,11 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 		public PathPatternRequestMatcher matcher(@Nullable HttpMethod method, String path) {
 			Assert.notNull(path, "pattern cannot be null");
 			Assert.isTrue(path.startsWith("/"), "pattern must start with a /");
-			PathPattern pathPattern = this.parser.parse(path);
+			PathPattern pathPattern = this.parser.parse(this.servletPath + path);
 			PathPatternRequestMatcher requestMatcher = new PathPatternRequestMatcher(pathPattern);
 			if (method != null) {
 				requestMatcher.setMethod(new HttpMethodRequestMatcher(method));
 			}
-			if (this.servletPath != AnyRequestMatcher.INSTANCE) {
-				requestMatcher.setServletPath(this.servletPath);
-			}
 			return requestMatcher;
 		}
 
@@ -319,60 +309,4 @@ public final class PathPatternRequestMatcher implements RequestMatcher {
 
 	}
 
-	private static final class ServletPathRequestMatcher implements RequestMatcher {
-
-		private final String path;
-
-		private final AtomicReference<Boolean> servletExists = new AtomicReference<>();
-
-		ServletPathRequestMatcher(String servletPath) {
-			Assert.notNull(servletPath, "servletPath cannot be null");
-			Assert.isTrue(servletPath.startsWith("/"), "servletPath must start with '/'");
-			Assert.isTrue(!servletPath.endsWith("/"), "servletPath must not end with a slash");
-			Assert.isTrue(!servletPath.contains("*"), "servletPath must not contain a star");
-			this.path = servletPath;
-		}
-
-		@Override
-		public boolean matches(HttpServletRequest request) {
-			Assert.isTrue(servletExists(request), () -> this.path + "/* does not exist in your servlet registration "
-					+ registrationMappings(request));
-			return Objects.equals(this.path, ServletRequestPathUtils.getServletPathPrefix(request));
-		}
-
-		private boolean servletExists(HttpServletRequest request) {
-			return this.servletExists.updateAndGet((value) -> {
-				if (value != null) {
-					return value;
-				}
-				if (request.getAttribute("org.springframework.test.web.servlet.MockMvc.MVC_RESULT_ATTRIBUTE") != null) {
-					return true;
-				}
-				for (ServletRegistration registration : request.getServletContext()
-					.getServletRegistrations()
-					.values()) {
-					if (registration.getMappings().contains(this.path + "/*")) {
-						return true;
-					}
-				}
-				return false;
-			});
-		}
-
-		private Map<String, Collection<String>> registrationMappings(HttpServletRequest request) {
-			Map<String, Collection<String>> map = new LinkedHashMap<>();
-			ServletContext servletContext = request.getServletContext();
-			for (ServletRegistration registration : servletContext.getServletRegistrations().values()) {
-				map.put(registration.getName(), registration.getMappings());
-			}
-			return map;
-		}
-
-		@Override
-		public String toString() {
-			return "ServletPath [" + this.path + "]";
-		}
-
-	}
-
 }

+ 18 - 9
web/src/test/java/org/springframework/security/web/servlet/util/matcher/PathPatternRequestMatcherTests.java

@@ -49,15 +49,15 @@ public class PathPatternRequestMatcherTests {
 	}
 
 	@Test
-	void matcherWhenOnlyPathInfoMatchesThenMatches() {
+	void matcherWhenOnlyPathInfoMatchesThenNoMatch() {
 		RequestMatcher matcher = PathPatternRequestMatcher.withDefaults().matcher("/uri");
-		assertThat(matcher.matches(request("GET", "/mvc/uri", "/mvc"))).isTrue();
+		assertThat(matcher.matches(request("GET", "/mvc/uri", "/mvc"))).isFalse();
 	}
 
 	@Test
-	void matcherWhenUriContainsServletPathThenNoMatch() {
+	void matcherWhenUriContainsServletPathThenMatch() {
 		RequestMatcher matcher = PathPatternRequestMatcher.withDefaults().matcher("/mvc/uri");
-		assertThat(matcher.matches(request("GET", "/mvc/uri", "/mvc"))).isFalse();
+		assertThat(matcher.matches(request("GET", "/mvc/uri", "/mvc"))).isTrue();
 	}
 
 	@Test
@@ -101,24 +101,33 @@ public class PathPatternRequestMatcherTests {
 	}
 
 	@Test
-	void matcherWhenRequestPathThenIgnoresServletPath() {
+	void matcherWhenRequestPathThenRequiresServletPath() {
 		PathPatternRequestMatcher.Builder request = PathPatternRequestMatcher.withDefaults();
 		RequestMatcher matcher = request.matcher(HttpMethod.GET, "/endpoint");
 		MockHttpServletRequest mock = get("/servlet/path/endpoint").servletPath("/servlet/path").buildRequest(null);
 		ServletRequestPathUtils.parseAndCache(mock);
-		assertThat(matcher.matches(mock)).isTrue();
+		assertThat(matcher.matches(mock)).isFalse();
 		mock = get("/endpoint").servletPath("/endpoint").buildRequest(null);
 		ServletRequestPathUtils.parseAndCache(mock);
 		assertThat(matcher.matches(mock)).isTrue();
 	}
 
 	@Test
-	void matcherWhenServletPathThenRequiresServletPathToExist() {
+	void matcherWhenMultiServletPathThenMatches() {
+		PathPatternRequestMatcher.Builder servlet = PathPatternRequestMatcher.withDefaults()
+			.servletPath("/servlet/path");
+		RequestMatcher matcher = servlet.matcher(HttpMethod.GET, "/endpoint");
+		MockHttpServletRequest mock = get("/servlet/path/endpoint").servletPath("/servlet/path").buildRequest(null);
+		assertThat(matcher.matches(mock)).isTrue();
+	}
+
+	@Test
+	void matcherWhenMultiContextPathThenMatches() {
 		PathPatternRequestMatcher.Builder servlet = PathPatternRequestMatcher.withDefaults()
 			.servletPath("/servlet/path");
 		RequestMatcher matcher = servlet.matcher(HttpMethod.GET, "/endpoint");
-		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(
-				() -> matcher.matches(get("/servlet/path/endpoint").servletPath("/servlet/path").buildRequest(null)));
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> matcher.matches(
+				get("/servlet/path/endpoint").servletPath("/servlet/path").contextPath("/app").buildRequest(null)));
 	}
 
 	@Test