Browse Source

Align Filter Chain Observability Lineage

Closes gh-12849
Josh Cummings 2 năm trước cách đây
mục cha
commit
6db2b0dcd0

+ 149 - 12
web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java

@@ -22,6 +22,7 @@ import java.util.ListIterator;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 import java.util.concurrent.atomic.AtomicReference;
 
 
+import io.micrometer.common.KeyValue;
 import io.micrometer.common.KeyValues;
 import io.micrometer.common.KeyValues;
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationConvention;
 import io.micrometer.observation.ObservationConvention;
@@ -265,20 +266,23 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 			}
 			}
 
 
 			@Override
 			@Override
-			public void start() {
+			public Observation start() {
 				if (this.currentObservation.compareAndSet(ObservationReference.NOOP, this.before)) {
 				if (this.currentObservation.compareAndSet(ObservationReference.NOOP, this.before)) {
 					this.before.start();
 					this.before.start();
-					return;
+					return this.before.observation;
 				}
 				}
 				if (this.currentObservation.compareAndSet(this.before, this.after)) {
 				if (this.currentObservation.compareAndSet(this.before, this.after)) {
 					this.before.stop();
 					this.before.stop();
 					this.after.start();
 					this.after.start();
+					return this.after.observation;
 				}
 				}
+				return Observation.NOOP;
 			}
 			}
 
 
 			@Override
 			@Override
-			public void error(Throwable ex) {
+			public Observation error(Throwable ex) {
 				this.currentObservation.get().error(ex);
 				this.currentObservation.get().error(ex);
+				return this.currentObservation.get().observation;
 			}
 			}
 
 
 			@Override
 			@Override
@@ -286,6 +290,46 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 				this.currentObservation.get().stop();
 				this.currentObservation.get().stop();
 			}
 			}
 
 
+			@Override
+			public Observation contextualName(String contextualName) {
+				return this.currentObservation.get().observation.contextualName(contextualName);
+			}
+
+			@Override
+			public Observation parentObservation(Observation parentObservation) {
+				return this.currentObservation.get().observation.parentObservation(parentObservation);
+			}
+
+			@Override
+			public Observation lowCardinalityKeyValue(KeyValue keyValue) {
+				return this.currentObservation.get().observation.lowCardinalityKeyValue(keyValue);
+			}
+
+			@Override
+			public Observation highCardinalityKeyValue(KeyValue keyValue) {
+				return this.currentObservation.get().observation.highCardinalityKeyValue(keyValue);
+			}
+
+			@Override
+			public Observation observationConvention(ObservationConvention<?> observationConvention) {
+				return this.currentObservation.get().observation.observationConvention(observationConvention);
+			}
+
+			@Override
+			public Observation event(Event event) {
+				return this.currentObservation.get().observation.event(event);
+			}
+
+			@Override
+			public Context getContext() {
+				return this.currentObservation.get().observation.getContext();
+			}
+
+			@Override
+			public Scope openScope() {
+				return this.currentObservation.get().observation.openScope();
+			}
+
 			@Override
 			@Override
 			public WebFilterChain wrap(WebFilterChain chain) {
 			public WebFilterChain wrap(WebFilterChain chain) {
 				return (exchange) -> {
 				return (exchange) -> {
@@ -313,7 +357,8 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 							.doOnError((t) -> {
 							.doOnError((t) -> {
 								error(t);
 								error(t);
 								stop();
 								stop();
-							});
+							})
+							.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, this));
 					// @formatter:on
 					// @formatter:on
 				};
 				};
 			}
 			}
@@ -328,6 +373,11 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 				return this.after.observation;
 				return this.after.observation;
 			}
 			}
 
 
+			@Override
+			public String toString() {
+				return this.currentObservation.get().observation.toString();
+			}
+
 			private static final class ObservationReference {
 			private static final class ObservationReference {
 
 
 				private static final ObservationReference NOOP = new ObservationReference(Observation.NOOP);
 				private static final ObservationReference NOOP = new ObservationReference(Observation.NOOP);
@@ -364,7 +414,7 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 
 
 	}
 	}
 
 
-	interface WebFilterObservation {
+	interface WebFilterObservation extends Observation {
 
 
 		WebFilterObservation NOOP = new WebFilterObservation() {
 		WebFilterObservation NOOP = new WebFilterObservation() {
 		};
 		};
@@ -376,13 +426,59 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 			return new SimpleWebFilterObservation(observation);
 			return new SimpleWebFilterObservation(observation);
 		}
 		}
 
 
-		default void start() {
+		@Override
+		default Observation contextualName(String contextualName) {
+			return Observation.NOOP;
+		}
+
+		@Override
+		default Observation parentObservation(Observation parentObservation) {
+			return Observation.NOOP;
+		}
+
+		@Override
+		default Observation lowCardinalityKeyValue(KeyValue keyValue) {
+			return Observation.NOOP;
 		}
 		}
 
 
-		default void error(Throwable ex) {
+		@Override
+		default Observation highCardinalityKeyValue(KeyValue keyValue) {
+			return Observation.NOOP;
 		}
 		}
 
 
+		@Override
+		default Observation observationConvention(ObservationConvention<?> observationConvention) {
+			return Observation.NOOP;
+		}
+
+		@Override
+		default Observation error(Throwable error) {
+			return Observation.NOOP;
+		}
+
+		@Override
+		default Observation event(Event event) {
+			return Observation.NOOP;
+		}
+
+		@Override
+		default Observation start() {
+			return Observation.NOOP;
+		}
+
+		@Override
+		default Context getContext() {
+			return new Observation.Context();
+		}
+
+		@Override
 		default void stop() {
 		default void stop() {
+
+		}
+
+		@Override
+		default Scope openScope() {
+			return Scope.NOOP;
 		}
 		}
 
 
 		default WebFilter wrap(WebFilter filter) {
 		default WebFilter wrap(WebFilter filter) {
@@ -402,13 +498,13 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 			}
 			}
 
 
 			@Override
 			@Override
-			public void start() {
-				this.observation.start();
+			public Observation start() {
+				return this.observation.start();
 			}
 			}
 
 
 			@Override
 			@Override
-			public void error(Throwable ex) {
-				this.observation.error(ex);
+			public Observation error(Throwable ex) {
+				return this.observation.error(ex);
 			}
 			}
 
 
 			@Override
 			@Override
@@ -416,6 +512,46 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 				this.observation.stop();
 				this.observation.stop();
 			}
 			}
 
 
+			@Override
+			public Observation contextualName(String contextualName) {
+				return this.observation.contextualName(contextualName);
+			}
+
+			@Override
+			public Observation parentObservation(Observation parentObservation) {
+				return this.observation.parentObservation(parentObservation);
+			}
+
+			@Override
+			public Observation lowCardinalityKeyValue(KeyValue keyValue) {
+				return this.observation.lowCardinalityKeyValue(keyValue);
+			}
+
+			@Override
+			public Observation highCardinalityKeyValue(KeyValue keyValue) {
+				return this.observation.highCardinalityKeyValue(keyValue);
+			}
+
+			@Override
+			public Observation observationConvention(ObservationConvention<?> observationConvention) {
+				return this.observation.observationConvention(observationConvention);
+			}
+
+			@Override
+			public Observation event(Event event) {
+				return this.observation.event(event);
+			}
+
+			@Override
+			public Context getContext() {
+				return this.observation.getContext();
+			}
+
+			@Override
+			public Scope openScope() {
+				return this.observation.openScope();
+			}
+
 			@Override
 			@Override
 			public WebFilter wrap(WebFilter filter) {
 			public WebFilter wrap(WebFilter filter) {
 				if (this.observation.isNoop()) {
 				if (this.observation.isNoop()) {
@@ -442,7 +578,8 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 							.doOnCancel(this.observation::stop).doOnError((t) -> {
 							.doOnCancel(this.observation::stop).doOnError((t) -> {
 								this.observation.error(t);
 								this.observation.error(t);
 								this.observation.stop();
 								this.observation.stop();
-							});
+							}).contextWrite(
+									(context) -> context.put(ObservationThreadLocalAccessor.KEY, this.observation));
 				};
 				};
 			}
 			}
 
 

+ 150 - 0
web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java

@@ -16,15 +16,22 @@
 
 
 package org.springframework.security.web.server;
 package org.springframework.security.web.server;
 
 
+import java.util.ArrayList;
+import java.util.List;
+
+import io.micrometer.observation.Observation;
 import io.micrometer.observation.ObservationHandler;
 import io.micrometer.observation.ObservationHandler;
 import io.micrometer.observation.ObservationRegistry;
 import io.micrometer.observation.ObservationRegistry;
+import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.Test;
 import reactor.core.publisher.Mono;
 import reactor.core.publisher.Mono;
 
 
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
 import org.springframework.web.server.WebFilterChain;
 
 
+import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
@@ -64,4 +71,147 @@ public class ObservationWebFilterChainDecoratorTests {
 		verifyNoInteractions(handler);
 		verifyNoInteractions(handler);
 	}
 	}
 
 
+	// gh-12849
+	@Test
+	void decorateWhenCustomAfterFilterThenObserves() {
+		AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
+		ObservationRegistry registry = ObservationRegistry.create();
+		registry.observationConfig().observationHandler(handler);
+		ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
+		WebFilter mock = mock(WebFilter.class);
+		given(mock.filter(any(), any())).willReturn(Mono.empty());
+		WebFilterChain chain = mock(WebFilterChain.class);
+		given(chain.filter(any())).willReturn(Mono.empty());
+		WebFilterChain decorated = decorator.decorate(chain,
+				List.of((e, c) -> c.filter(e).then(Mono.deferContextual((context) -> {
+					Observation parentObservation = context.getOrDefault(ObservationThreadLocalAccessor.KEY, null);
+					Observation observation = Observation.createNotStarted("custom", registry)
+							.parentObservation(parentObservation).contextualName("custom").start();
+					return Mono.just("3").doOnSuccess((v) -> observation.stop()).doOnCancel(observation::stop)
+							.doOnError((t) -> {
+								observation.error(t);
+								observation.stop();
+							}).then(Mono.empty());
+				}))));
+		Observation http = Observation.start("http", registry).contextualName("http");
+		try {
+			decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
+					.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http)).block();
+		}
+		finally {
+			http.stop();
+		}
+		handler.assertSpanStart(0, "http", null);
+		handler.assertSpanStart(1, "spring.security.filterchains", "http");
+		handler.assertSpanStop(2, "security filterchain before");
+		handler.assertSpanStart(3, "secured request", "security filterchain before");
+		handler.assertSpanStop(4, "secured request");
+		handler.assertSpanStart(5, "spring.security.filterchains", "http");
+		handler.assertSpanStart(6, "custom", "spring.security.filterchains");
+		handler.assertSpanStop(7, "custom");
+		handler.assertSpanStop(8, "security filterchain after");
+		handler.assertSpanStop(9, "http");
+	}
+
+	static class AccumulatingObservationHandler implements ObservationHandler<Observation.Context> {
+
+		List<Event> contexts = new ArrayList<>();
+
+		@Override
+		public boolean supportsContext(Observation.Context context) {
+			return true;
+		}
+
+		@Override
+		public void onStart(Observation.Context context) {
+			this.contexts.add(new Event("start", context));
+		}
+
+		@Override
+		public void onError(Observation.Context context) {
+			this.contexts.add(new Event("error", context));
+		}
+
+		@Override
+		public void onEvent(Observation.Event event, Observation.Context context) {
+			this.contexts.add(new Event("event", context));
+		}
+
+		@Override
+		public void onScopeOpened(Observation.Context context) {
+			this.contexts.add(new Event("opened", context));
+		}
+
+		@Override
+		public void onScopeClosed(Observation.Context context) {
+			this.contexts.add(new Event("closed", context));
+		}
+
+		@Override
+		public void onScopeReset(Observation.Context context) {
+			this.contexts.add(new Event("reset", context));
+		}
+
+		@Override
+		public void onStop(Observation.Context context) {
+			this.contexts.add(new Event("stop", context));
+		}
+
+		private void assertSpanStart(int index, String name, String parentName) {
+			Event event = this.contexts.get(index);
+			assertThat(event.event).isEqualTo("start");
+			if (event.contextualName == null) {
+				assertThat(event.name).isEqualTo(name);
+			}
+			else {
+				assertThat(event.contextualName).isEqualTo(name);
+			}
+			if (parentName == null) {
+				return;
+			}
+			if (event.parentContextualName == null) {
+				assertThat(event.parentName).isEqualTo(parentName);
+			}
+			else {
+				assertThat(event.parentContextualName).isEqualTo(parentName);
+			}
+		}
+
+		private void assertSpanStop(int index, String name) {
+			Event event = this.contexts.get(index);
+			assertThat(event.event).isEqualTo("stop");
+			if (event.contextualName == null) {
+				assertThat(event.name).isEqualTo(name);
+			}
+			else {
+				assertThat(event.contextualName).isEqualTo(name);
+			}
+		}
+
+		static class Event {
+
+			String event;
+
+			String name;
+
+			String contextualName;
+
+			String parentName;
+
+			String parentContextualName;
+
+			Event(String event, Observation.Context context) {
+				this.event = event;
+				this.name = context.getName();
+				this.contextualName = context.getContextualName();
+				if (context.getParentObservation() != null) {
+					this.parentName = context.getParentObservation().getContextView().getName();
+					this.parentContextualName = context.getParentObservation().getContextView().getContextualName();
+				}
+			}
+
+		}
+
+	}
+
 }
 }