2
0
Эх сурвалжийг харах

Consistently handle RequestRejectedException if it is wrapped

Closes gh-11645
Marcus Da Coregio 3 жил өмнө
parent
commit
ead587c597

+ 12 - 2
web/src/main/java/org/springframework/security/web/FilterChainProxy.java

@@ -40,6 +40,7 @@ import org.springframework.security.web.firewall.HttpFirewall;
 import org.springframework.security.web.firewall.RequestRejectedException;
 import org.springframework.security.web.firewall.RequestRejectedHandler;
 import org.springframework.security.web.firewall.StrictHttpFirewall;
+import org.springframework.security.web.util.ThrowableAnalyzer;
 import org.springframework.security.web.util.UrlUtils;
 import org.springframework.security.web.util.matcher.RequestMatcher;
 import org.springframework.util.Assert;
@@ -154,6 +155,8 @@ public class FilterChainProxy extends GenericFilterBean {
 
 	private RequestRejectedHandler requestRejectedHandler = new DefaultRequestRejectedHandler();
 
+	private ThrowableAnalyzer throwableAnalyzer = new ThrowableAnalyzer();
+
 	public FilterChainProxy() {
 	}
 
@@ -182,8 +185,15 @@ public class FilterChainProxy extends GenericFilterBean {
 			request.setAttribute(FILTER_APPLIED, Boolean.TRUE);
 			doFilterInternal(request, response, chain);
 		}
-		catch (RequestRejectedException ex) {
-			this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response, ex);
+		catch (Exception ex) {
+			Throwable[] causeChain = this.throwableAnalyzer.determineCauseChain(ex);
+			Throwable requestRejectedException = this.throwableAnalyzer
+					.getFirstThrowableOfType(RequestRejectedException.class, causeChain);
+			if (!(requestRejectedException instanceof RequestRejectedException)) {
+				throw ex;
+			}
+			this.requestRejectedHandler.handle((HttpServletRequest) request, (HttpServletResponse) response,
+					(RequestRejectedException) requestRejectedException);
 		}
 		finally {
 			SecurityContextHolder.clearContext();

+ 15 - 0
web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java

@@ -49,6 +49,7 @@ import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.willAnswer;
+import static org.mockito.BDDMockito.willThrow;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyZeroInteractions;
@@ -252,4 +253,18 @@ public class FilterChainProxyTests {
 		verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException)));
 	}
 
+	@Test
+	public void requestRejectedHandlerIsCalledIfFirewallThrowsWrappedRequestRejectedException() throws Exception {
+		HttpFirewall fw = mock(HttpFirewall.class);
+		RequestRejectedHandler rjh = mock(RequestRejectedHandler.class);
+		this.fcp.setFirewall(fw);
+		this.fcp.setRequestRejectedHandler(rjh);
+		RequestRejectedException requestRejectedException = new RequestRejectedException("Contains illegal chars");
+		ServletException servletException = new ServletException(requestRejectedException);
+		given(fw.getFirewalledRequest(this.request)).willReturn(mock(FirewalledRequest.class));
+		willThrow(servletException).given(this.chain).doFilter(any(), any());
+		this.fcp.doFilter(this.request, this.response, this.chain);
+		verify(rjh).handle(eq(this.request), eq(this.response), eq((requestRejectedException)));
+	}
+
 }