Browse Source

Mark Observations with CSRF Failures

Closes gh-11993
Josh Cummings 2 năm trước cách đây
mục cha
commit
46ab84684b

+ 19 - 0
config/src/main/java/org/springframework/security/config/annotation/web/configurers/CsrfConfigurer.java

@@ -20,6 +20,7 @@ import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.List;
 
+import io.micrometer.observation.ObservationRegistry;
 import jakarta.servlet.http.HttpServletRequest;
 
 import org.springframework.context.ApplicationContext;
@@ -29,7 +30,9 @@ import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.config.annotation.web.builders.HttpSecurity;
 import org.springframework.security.web.access.AccessDeniedHandler;
 import org.springframework.security.web.access.AccessDeniedHandlerImpl;
+import org.springframework.security.web.access.CompositeAccessDeniedHandler;
 import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
+import org.springframework.security.web.access.ObservationMarkingAccessDeniedHandler;
 import org.springframework.security.web.authentication.session.SessionAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
@@ -221,6 +224,11 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 			filter.setRequireCsrfProtectionMatcher(requireCsrfProtectionMatcher);
 		}
 		AccessDeniedHandler accessDeniedHandler = createAccessDeniedHandler(http);
+		ObservationRegistry registry = getObservationRegistry();
+		if (!registry.isNoop()) {
+			ObservationMarkingAccessDeniedHandler observable = new ObservationMarkingAccessDeniedHandler(registry);
+			accessDeniedHandler = new CompositeAccessDeniedHandler(observable, accessDeniedHandler);
+		}
 		if (accessDeniedHandler != null) {
 			filter.setAccessDeniedHandler(accessDeniedHandler);
 		}
@@ -331,6 +339,17 @@ public final class CsrfConfigurer<H extends HttpSecurityBuilder<H>>
 		return csrfAuthenticationStrategy;
 	}
 
+	private ObservationRegistry getObservationRegistry() {
+		ApplicationContext context = getBuilder().getSharedObject(ApplicationContext.class);
+		String[] names = context.getBeanNamesForType(ObservationRegistry.class);
+		if (names.length == 1) {
+			return context.getBean(ObservationRegistry.class);
+		}
+		else {
+			return ObservationRegistry.NOOP;
+		}
+	}
+
 	/**
 	 * Allows registering {@link RequestMatcher} instances that should be ignored (even if
 	 * the {@link HttpServletRequest} matches the

+ 18 - 1
config/src/main/java/org/springframework/security/config/http/CsrfBeanDefinitionParser.java

@@ -36,7 +36,9 @@ import org.springframework.beans.factory.xml.ParserContext;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.config.annotation.web.HttpSecurityBuilder;
 import org.springframework.security.web.access.AccessDeniedHandler;
+import org.springframework.security.web.access.CompositeAccessDeniedHandler;
 import org.springframework.security.web.access.DelegatingAccessDeniedHandler;
+import org.springframework.security.web.access.ObservationMarkingAccessDeniedHandler;
 import org.springframework.security.web.csrf.CsrfAuthenticationStrategy;
 import org.springframework.security.web.csrf.CsrfFilter;
 import org.springframework.security.web.csrf.CsrfLogoutHandler;
@@ -80,6 +82,8 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 
 	private String requestHandlerRef;
 
+	private BeanMetadataElement observationRegistry;
+
 	@Override
 	public BeanDefinition parse(Element element, ParserContext pc) {
 		boolean disabled = element != null && "true".equals(element.getAttribute("disabled"));
@@ -160,7 +164,16 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 				.rootBeanDefinition(DelegatingAccessDeniedHandler.class);
 		deniedBldr.addConstructorArgValue(handlers);
 		deniedBldr.addConstructorArgValue(defaultDeniedHandler);
-		return deniedBldr.getBeanDefinition();
+		BeanDefinition denied = deniedBldr.getBeanDefinition();
+		ManagedList compositeList = new ManagedList();
+		BeanDefinitionBuilder compositeBldr = BeanDefinitionBuilder
+				.rootBeanDefinition(CompositeAccessDeniedHandler.class);
+		BeanDefinition observing = BeanDefinitionBuilder.rootBeanDefinition(ObservationMarkingAccessDeniedHandler.class)
+				.addConstructorArgValue(this.observationRegistry).getBeanDefinition();
+		compositeList.add(denied);
+		compositeList.add(observing);
+		compositeBldr.addConstructorArgValue(compositeList);
+		return compositeBldr.getBeanDefinition();
 	}
 
 	BeanDefinition getCsrfAuthenticationStrategy() {
@@ -195,6 +208,10 @@ public class CsrfBeanDefinitionParser implements BeanDefinitionParser {
 		}
 	}
 
+	void setObservationRegistry(BeanMetadataElement observationRegistry) {
+		this.observationRegistry = observationRegistry;
+	}
+
 	private static final class DefaultRequiresCsrfMatcher implements RequestMatcher {
 
 		private final HashSet<String> allowedMethods = new HashSet<>(Arrays.asList("GET", "HEAD", "TRACE", "OPTIONS"));

+ 29 - 3
config/src/main/java/org/springframework/security/config/http/HttpConfigurationBuilder.java

@@ -19,6 +19,7 @@ package org.springframework.security.config.http;
 import java.util.ArrayList;
 import java.util.List;
 
+import io.micrometer.observation.ObservationRegistry;
 import jakarta.servlet.ServletRequest;
 import org.w3c.dom.Element;
 
@@ -106,6 +107,8 @@ class HttpConfigurationBuilder {
 
 	private static final String ATT_INVALID_SESSION_URL = "invalid-session-url";
 
+	private static final String ATT_OBSERVATION_REGISTRY_REF = "observation-registry-ref";
+
 	private static final String ATT_SESSION_AUTH_STRATEGY_REF = "session-authentication-strategy-ref";
 
 	private static final String ATT_SESSION_AUTH_ERROR_URL = "session-authentication-error-url";
@@ -211,7 +214,7 @@ class HttpConfigurationBuilder {
 	private boolean addAllAuth;
 
 	HttpConfigurationBuilder(Element element, boolean addAllAuth, ParserContext pc, BeanReference portMapper,
-			BeanReference portResolver, BeanReference authenticationManager) {
+			BeanReference portResolver, BeanReference authenticationManager, BeanMetadataElement observationRegistry) {
 		this.httpElt = element;
 		this.addAllAuth = addAllAuth;
 		this.pc = pc;
@@ -226,7 +229,7 @@ class HttpConfigurationBuilder {
 		createSecurityContextHolderStrategy();
 		createForceEagerSessionCreationFilter();
 		createDisableEncodeUrlFilter();
-		createCsrfFilter();
+		createCsrfFilter(observationRegistry);
 		createSecurityPersistence();
 		createSessionManagementFilters();
 		createWebAsyncManagerFilter();
@@ -812,9 +815,10 @@ class HttpConfigurationBuilder {
 		}
 	}
 
-	private void createCsrfFilter() {
+	private void createCsrfFilter(BeanMetadataElement observationRegistry) {
 		Element elmt = DomUtils.getChildElementByTagName(this.httpElt, Elements.CSRF);
 		this.csrfParser = new CsrfBeanDefinitionParser();
+		this.csrfParser.setObservationRegistry(observationRegistry);
 		this.csrfFilter = this.csrfParser.parse(elmt, this.pc);
 		if (this.csrfFilter == null) {
 			this.csrfParser = null;
@@ -897,6 +901,14 @@ class HttpConfigurationBuilder {
 		return filters;
 	}
 
+	private static BeanMetadataElement getObservationRegistry(Element httpElmt) {
+		String holderStrategyRef = httpElmt.getAttribute(ATT_OBSERVATION_REGISTRY_REF);
+		if (StringUtils.hasText(holderStrategyRef)) {
+			return new RuntimeBeanReference(holderStrategyRef);
+		}
+		return BeanDefinitionBuilder.rootBeanDefinition(ObservationRegistryFactory.class).getBeanDefinition();
+	}
+
 	static class RoleVoterBeanFactory extends AbstractGrantedAuthorityDefaultsBeanFactory {
 
 		private RoleVoter voter = new RoleVoter();
@@ -944,4 +956,18 @@ class HttpConfigurationBuilder {
 
 	}
 
+	static class ObservationRegistryFactory implements FactoryBean<ObservationRegistry> {
+
+		@Override
+		public ObservationRegistry getObject() throws Exception {
+			return ObservationRegistry.NOOP;
+		}
+
+		@Override
+		public Class<?> getObjectType() {
+			return ObservationRegistry.class;
+		}
+
+	}
+
 }

+ 2 - 1
config/src/main/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParser.java

@@ -150,8 +150,9 @@ public class HttpSecurityBeanDefinitionParser implements BeanDefinitionParser {
 		ManagedList<BeanReference> authenticationProviders = new ManagedList<>();
 		BeanReference authenticationManager = createAuthenticationManager(element, pc, authenticationProviders);
 		boolean forceAutoConfig = isDefaultHttpConfig(element);
+		BeanMetadataElement observationRegistry = getObservationRegistry(element);
 		HttpConfigurationBuilder httpBldr = new HttpConfigurationBuilder(element, forceAutoConfig, pc, portMapper,
-				portResolver, authenticationManager);
+				portResolver, authenticationManager, observationRegistry);
 		httpBldr.getSecurityContextRepositoryForAuthenticationFilters();
 		AuthenticationConfigBuilder authBldr = new AuthenticationConfigBuilder(element, forceAutoConfig, pc,
 				httpBldr.getSessionCreationPolicy(), httpBldr.getRequestCache(), authenticationManager,

+ 50 - 0
web/src/main/java/org/springframework/security/web/access/CompositeAccessDeniedHandler.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.access;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+
+import org.springframework.security.access.AccessDeniedException;
+
+public final class CompositeAccessDeniedHandler implements AccessDeniedHandler {
+
+	private Collection<AccessDeniedHandler> handlers;
+
+	public CompositeAccessDeniedHandler(AccessDeniedHandler... handlers) {
+		this(Arrays.asList(handlers));
+	}
+
+	public CompositeAccessDeniedHandler(Collection<AccessDeniedHandler> handlers) {
+		this.handlers = new ArrayList<>(handlers);
+	}
+
+	@Override
+	public void handle(HttpServletRequest request, HttpServletResponse response,
+			AccessDeniedException accessDeniedException) throws IOException, ServletException {
+		for (AccessDeniedHandler handler : this.handlers) {
+			handler.handle(request, response, accessDeniedException);
+		}
+	}
+
+}

+ 46 - 0
web/src/main/java/org/springframework/security/web/access/ObservationMarkingAccessDeniedHandler.java

@@ -0,0 +1,46 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.access;
+
+import java.io.IOException;
+
+import io.micrometer.observation.Observation;
+import io.micrometer.observation.ObservationRegistry;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+
+import org.springframework.security.access.AccessDeniedException;
+
+public final class ObservationMarkingAccessDeniedHandler implements AccessDeniedHandler {
+
+	private final ObservationRegistry registry;
+
+	public ObservationMarkingAccessDeniedHandler(ObservationRegistry registry) {
+		this.registry = registry;
+	}
+
+	@Override
+	public void handle(HttpServletRequest request, HttpServletResponse response, AccessDeniedException exception)
+			throws IOException, ServletException {
+		Observation observation = this.registry.getCurrentObservation();
+		if (observation != null) {
+			observation.error(exception);
+		}
+	}
+
+}