Browse Source

Lookup Parent Observation

Closes gh-12524
Josh Cummings 2 years ago
parent
commit
4d2dab9b6b

+ 11 - 7
core/src/main/java/org/springframework/security/authentication/ObservationReactiveAuthenticationManager.java

@@ -18,6 +18,7 @@ package org.springframework.security.authentication;
 
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Mono;
 
 import org.springframework.security.core.Authentication;
@@ -48,13 +49,16 @@ public class ObservationReactiveAuthenticationManager implements ReactiveAuthent
 		AuthenticationObservationContext context = new AuthenticationObservationContext();
 		context.setAuthenticationRequest(authentication);
 		context.setAuthenticationManagerClass(this.delegate.getClass());
-		Observation observation = Observation.createNotStarted(this.convention, () -> context, this.registry).start();
-		return this.delegate.authenticate(authentication).doOnSuccess((result) -> {
-			context.setAuthenticationResult(result);
-			observation.stop();
-		}).doOnCancel(observation::stop).doOnError((t) -> {
-			observation.error(t);
-			observation.stop();
+		return Mono.deferContextual((contextView) -> {
+			Observation observation = Observation.createNotStarted(this.convention, () -> context, this.registry)
+					.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
+			return this.delegate.authenticate(authentication).doOnSuccess((result) -> {
+				context.setAuthenticationResult(result);
+				observation.stop();
+			}).doOnCancel(observation::stop).doOnError((t) -> {
+				observation.error(t);
+				observation.stop();
+			});
 		});
 	}
 

+ 14 - 10
core/src/main/java/org/springframework/security/authorization/ObservationReactiveAuthorizationManager.java

@@ -18,6 +18,7 @@ package org.springframework.security.authorization;
 
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Mono;
 
 import org.springframework.security.access.AccessDeniedException;
@@ -50,16 +51,19 @@ public final class ObservationReactiveAuthorizationManager<T> implements Reactiv
 			context.setAuthentication(auth);
 			return context.getAuthentication();
 		});
-		Observation observation = Observation.createNotStarted(this.convention, () -> context, this.registry).start();
-		return this.delegate.check(wrapped, object).doOnSuccess((decision) -> {
-			context.setDecision(decision);
-			if (decision == null || !decision.isGranted()) {
-				observation.error(new AccessDeniedException("Access Denied"));
-			}
-			observation.stop();
-		}).doOnCancel(observation::stop).doOnError((t) -> {
-			observation.error(t);
-			observation.stop();
+		return Mono.deferContextual((contextView) -> {
+			Observation observation = Observation.createNotStarted(this.convention, () -> context, this.registry)
+					.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();
+			return this.delegate.check(wrapped, object).doOnSuccess((decision) -> {
+				context.setDecision(decision);
+				if (decision == null || !decision.isGranted()) {
+					observation.error(new AccessDeniedException("Access Denied"));
+				}
+				observation.stop();
+			}).doOnCancel(observation::stop).doOnError((t) -> {
+				observation.error(t);
+				observation.stop();
+			});
 		});
 	}
 

+ 19 - 11
web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java

@@ -27,6 +27,7 @@ import io.micrometer.common.KeyValues;
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationConvention;
 import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import reactor.core.publisher.Mono;
 
 import org.springframework.lang.Nullable;
@@ -73,20 +74,22 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 	}
 
 	private WebFilterChain wrapSecured(WebFilterChain original) {
-		return (exchange) -> {
+		return (exchange) -> Mono.deferContextual((contextView) -> {
 			AroundWebFilterObservation parent = observation(exchange);
+			Observation parentObservation = contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null);
 			Observation observation = Observation.createNotStarted(SECURED_OBSERVATION_NAME, this.registry)
-					.contextualName("secured request");
+					.contextualName("secured request").parentObservation(parentObservation);
 			return parent.wrap(WebFilterObservation.create(observation).wrap(original)).filter(exchange);
-		};
+		});
 	}
 
 	private WebFilterChain wrapUnsecured(WebFilterChain original) {
-		return (exchange) -> {
+		return (exchange) -> Mono.deferContextual((contextView) -> {
+			Observation parentObservation = contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null);
 			Observation observation = Observation.createNotStarted(UNSECURED_OBSERVATION_NAME, this.registry)
-					.contextualName("unsecured request");
+					.contextualName("unsecured request").parentObservation(parentObservation);
 			return WebFilterObservation.create(observation).wrap(original).filter(exchange);
-		};
+		});
 	}
 
 	private List<ObservationWebFilter> wrap(List<WebFilter> filters) {
@@ -186,8 +189,11 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 		@Override
 		public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 			if (this.position == 1) {
-				AroundWebFilterObservation parent = parent(exchange);
-				return parent.wrap(this::wrapFilter).filter(exchange, chain);
+				return Mono.deferContextual((contextView) -> {
+					Observation parentObservation = contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null);
+					AroundWebFilterObservation parent = parent(exchange, parentObservation);
+					return parent.wrap(this::wrapFilter).filter(exchange, chain);
+				});
 			}
 			else {
 				return wrapFilter(exchange, chain);
@@ -211,11 +217,13 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 			});
 		}
 
-		private AroundWebFilterObservation parent(ServerWebExchange exchange) {
+		private AroundWebFilterObservation parent(ServerWebExchange exchange, Observation parentObservation) {
 			WebFilterChainObservationContext beforeContext = WebFilterChainObservationContext.before();
 			WebFilterChainObservationContext afterContext = WebFilterChainObservationContext.after();
-			Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry);
-			Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry);
+			Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry)
+					.parentObservation(parentObservation);
+			Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry)
+					.parentObservation(parentObservation);
 			AroundWebFilterObservation parent = AroundWebFilterObservation.create(before, after);
 			exchange.getAttributes().put(ATTRIBUTE, parent);
 			return parent;