浏览代码

Mark Observations with Firewall Failures

Closes gh-11994
Josh Cummings 2 年之前
父节点
当前提交
2713075d08

+ 6 - 0
config/src/main/java/org/springframework/security/config/annotation/web/builders/WebSecurity.java

@@ -57,6 +57,7 @@ import org.springframework.security.web.access.intercept.AuthorizationFilter;
 import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
 import org.springframework.security.web.debug.DebugFilter;
 import org.springframework.security.web.firewall.HttpFirewall;
+import org.springframework.security.web.firewall.ObservationMarkingRequestRejectedHandler;
 import org.springframework.security.web.firewall.RequestRejectedHandler;
 import org.springframework.security.web.firewall.StrictHttpFirewall;
 import org.springframework.security.web.util.matcher.RequestMatcher;
@@ -307,6 +308,10 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
 		if (this.requestRejectedHandler != null) {
 			filterChainProxy.setRequestRejectedHandler(this.requestRejectedHandler);
 		}
+		else if (!this.observationRegistry.isNoop()) {
+			filterChainProxy
+					.setRequestRejectedHandler(new ObservationMarkingRequestRejectedHandler(this.observationRegistry));
+		}
 		filterChainProxy.setFilterChainDecorator(getFilterChainDecorator());
 		filterChainProxy.afterPropertiesSet();
 
@@ -319,6 +324,7 @@ public final class WebSecurity extends AbstractConfiguredSecurityBuilder<Filter,
 					+ "********************************************************************\n\n");
 			result = new DebugFilter(filterChainProxy);
 		}
+
 		this.postBuildAction.run();
 		return result;
 	}

+ 1 - 1
config/src/main/java/org/springframework/security/config/http/HttpFirewallBeanDefinitionParser.java

@@ -40,7 +40,7 @@ public class HttpFirewallBeanDefinitionParser implements BeanDefinitionParser {
 			pc.getReaderContext().error("ref attribute is required", pc.extractSource(element));
 		}
 		// Ensure the FCP is registered.
-		HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, pc.extractSource(element));
+		HttpSecurityBeanDefinitionParser.registerFilterChainProxyIfNecessary(pc, element);
 		BeanDefinition filterChainProxy = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAIN_PROXY);
 		filterChainProxy.getPropertyValues().addPropertyValue("firewall", new RuntimeBeanReference(ref));
 		return null;

+ 19 - 3
config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java

@@ -58,6 +58,7 @@ import org.springframework.security.web.DefaultSecurityFilterChain;
 import org.springframework.security.web.FilterChainProxy;
 import org.springframework.security.web.ObservationFilterChainDecorator;
 import org.springframework.security.web.PortResolverImpl;
+import org.springframework.security.web.firewall.ObservationMarkingRequestRejectedHandler;
 import org.springframework.security.web.util.matcher.AnyRequestMatcher;
 import org.springframework.util.StringUtils;
 import org.springframework.util.xml.DomUtils;
@@ -120,7 +121,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 		CompositeComponentDefinition compositeDef = new CompositeComponentDefinition(element.getTagName(),
 				pc.extractSource(element));
 		pc.pushContainingComponent(compositeDef);
-		registerFilterChainProxyIfNecessary(pc, pc.extractSource(element));
+		registerFilterChainProxyIfNecessary(pc, element);
 		// Obtain the filter chains and add the new chain to it
 		BeanDefinition listFactoryBean = pc.getRegistry().getBeanDefinition(BeanIds.FILTER_CHAINS);
 		List<BeanReference> filterChains = (List<BeanReference>) listFactoryBean.getPropertyValues()
@@ -351,7 +352,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 		return customFilters;
 	}
 
-	static void registerFilterChainProxyIfNecessary(ParserContext pc, Object source) {
+	static void registerFilterChainProxyIfNecessary(ParserContext pc, Element element) {
+		Object source = pc.extractSource(element);
 		BeanDefinitionRegistry registry = pc.getRegistry();
 		if (registry.containsBeanDefinition(BeanIds.FILTER_CHAIN_PROXY)) {
 			return;
@@ -378,6 +380,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 		requestRejected.addConstructorArgValue("requestRejectedHandler");
 		requestRejected.addConstructorArgValue(BeanIds.FILTER_CHAIN_PROXY);
 		requestRejected.addConstructorArgValue("requestRejectedHandler");
+		requestRejected.addPropertyValue("observationRegistry", getObservationRegistry(element));
 		AbstractBeanDefinition requestRejectedBean = requestRejected.getBeanDefinition();
 		String requestRejectedPostProcessorName = pc.getReaderContext().generateBeanName(requestRejectedBean);
 		registry.registerBeanDefinition(requestRejectedPostProcessorName, requestRejectedBean);
@@ -391,7 +394,7 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 		return BeanDefinitionBuilder.rootBeanDefinition(ObservationRegistryFactory.class).getBeanDefinition();
 	}
 
-	static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor {
+	public static class RequestRejectedHandlerPostProcessor implements BeanDefinitionRegistryPostProcessor {
 
 		private final String beanName;
 
@@ -399,6 +402,8 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 
 		private final String targetPropertyName;
 
+		private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
+
 		RequestRejectedHandlerPostProcessor(String beanName, String targetBeanName, String targetPropertyName) {
 			this.beanName = beanName;
 			this.targetBeanName = targetBeanName;
@@ -412,6 +417,13 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 				beanDefinition.getPropertyValues().add(this.targetPropertyName,
 						new RuntimeBeanReference(this.beanName));
 			}
+			else if (!this.observationRegistry.isNoop()) {
+				BeanDefinition observable = BeanDefinitionBuilder
+						.rootBeanDefinition(ObservationMarkingRequestRejectedHandler.class)
+						.addConstructorArgValue(this.observationRegistry).getBeanDefinition();
+				BeanDefinition beanDefinition = registry.getBeanDefinition(this.targetBeanName);
+				beanDefinition.getPropertyValues().add(this.targetPropertyName, observable);
+			}
 		}
 
 		@Override
@@ -419,6 +431,10 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 
 		}
 
+		public void setObservationRegistry(ObservationRegistry registry) {
+			this.observationRegistry = registry;
+		}
+
 	}
 
 	/**

+ 44 - 0
web/src/main/java/org/springframework/security/web/firewall/ObservationMarkingRequestRejectedHandler.java

@@ -0,0 +1,44 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.firewall;
+
+import java.io.IOException;
+
+import io.micrometer.observation.Observation;
+import io.micrometer.observation.ObservationRegistry;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+
+public final class ObservationMarkingRequestRejectedHandler implements RequestRejectedHandler {
+
+	private final ObservationRegistry registry;
+
+	public ObservationMarkingRequestRejectedHandler(ObservationRegistry registry) {
+		this.registry = registry;
+	}
+
+	@Override
+	public void handle(HttpServletRequest request, HttpServletResponse response, RequestRejectedException exception)
+			throws IOException, ServletException {
+		Observation observation = this.registry.getCurrentObservation();
+		if (observation != null) {
+			observation.error(exception);
+		}
+	}
+
+}