浏览代码

Close Both Observations

Depending on when a request is cancelled, the before and after observation
starts and stops may be called out of order due to the order in
which their doOnCancel handlers are invoked.

To address this, the before filter-wrapper now always closes both the
before observation and the after observation. Since the before filter-
wrapper wraps the entire request, this ensures that either that was
started is stopped, and either that has not been started yet cannot
inadvertently be started by any unexpected ordering of events that
follows.

Closes gh-14031
Josh Cummings 1 年之前
父节点
当前提交
5dce82c48b

+ 25 - 4
web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java

@@ -292,7 +292,13 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 
 			@Override
 			public void stop() {
-				this.currentObservation.get().stop();
+				this.before.stop();
+				this.after.stop();
+			}
+
+			private void close() {
+				this.before.close();
+				this.after.close();
 			}
 
 			@Override
@@ -357,11 +363,11 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 					start();
 					// @formatter:off
 					return filter.filter(exchange, chain)
-							.doOnSuccess((v) -> stop())
-							.doOnCancel(this::stop)
+							.doOnSuccess((v) -> close())
+							.doOnCancel(this::close)
 							.doOnError((t) -> {
 								error(t);
-								stop();
+								close();
 							})
 							.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, this));
 					// @formatter:on
@@ -433,6 +439,21 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 					}
 				}
 
+				private void close() {
+					try {
+						this.lock.lock();
+						if (this.state.compareAndSet(1, 3)) {
+							this.observation.stop();
+						}
+						else {
+							this.state.set(3);
+						}
+					}
+					finally {
+						this.lock.unlock();
+					}
+				}
+
 			}
 
 		}

+ 115 - 0
web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java

@@ -78,6 +78,98 @@ public class ObservationWebFilterChainDecoratorTests {
 		verifyNoInteractions(handler);
 	}
 
+	@Test
+	void decorateWhenTerminatingFilterThenObserves() {
+		AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
+		ObservationRegistry registry = ObservationRegistry.create();
+		registry.observationConfig().observationHandler(handler);
+		ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
+		WebFilterChain chain = mock(WebFilterChain.class);
+		given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack")));
+		WebFilterChain decorated = decorator.decorate(chain,
+				List.of(new BasicAuthenticationFilter(), new TerminatingFilter()));
+		Observation http = Observation.start("http", registry).contextualName("http");
+		try {
+			decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
+				.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
+				.block();
+		}
+		catch (Exception ex) {
+			http.error(ex);
+		}
+		finally {
+			http.stop();
+		}
+		handler.assertSpanStart(0, "http", null);
+		handler.assertSpanStart(1, "spring.security.filterchains", "http");
+		handler.assertSpanStop(2, "security filterchain before");
+		handler.assertSpanStart(3, "spring.security.filterchains", "http");
+		handler.assertSpanStop(4, "security filterchain after");
+		handler.assertSpanStop(5, "http");
+	}
+
+	@Test
+	void decorateWhenFilterErrorThenStopsObservation() {
+		AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
+		ObservationRegistry registry = ObservationRegistry.create();
+		registry.observationConfig().observationHandler(handler);
+		ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
+		WebFilterChain chain = mock(WebFilterChain.class);
+		WebFilterChain decorated = decorator.decorate(chain, List.of(new ErroringFilter()));
+		Observation http = Observation.start("http", registry).contextualName("http");
+		try {
+			decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
+				.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
+				.block();
+		}
+		catch (Exception ex) {
+			http.error(ex);
+		}
+		finally {
+			http.stop();
+		}
+		handler.assertSpanStart(0, "http", null);
+		handler.assertSpanStart(1, "spring.security.filterchains", "http");
+		handler.assertSpanError(2);
+		handler.assertSpanStop(3, "security filterchain before");
+		handler.assertSpanError(4);
+		handler.assertSpanStop(5, "http");
+	}
+
+	@Test
+	void decorateWhenErrorSignalThenStopsObservation() {
+		AccumulatingObservationHandler handler = new AccumulatingObservationHandler();
+		ObservationRegistry registry = ObservationRegistry.create();
+		registry.observationConfig().observationHandler(handler);
+		ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
+		WebFilterChain chain = mock(WebFilterChain.class);
+		given(chain.filter(any())).willReturn(Mono.error(() -> new Exception("ack")));
+		WebFilterChain decorated = decorator.decorate(chain, List.of(new BasicAuthenticationFilter()));
+		Observation http = Observation.start("http", registry).contextualName("http");
+		try {
+			decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build()))
+				.contextWrite((context) -> context.put(ObservationThreadLocalAccessor.KEY, http))
+				.block();
+		}
+		catch (Exception ex) {
+			http.error(ex);
+		}
+		finally {
+			http.stop();
+		}
+		handler.assertSpanStart(0, "http", null);
+		handler.assertSpanStart(1, "spring.security.filterchains", "http");
+		handler.assertSpanStop(2, "security filterchain before");
+		handler.assertSpanStart(3, "secured request", "security filterchain before");
+		handler.assertSpanError(4);
+		handler.assertSpanStop(5, "secured request");
+		handler.assertSpanStart(6, "spring.security.filterchains", "http");
+		handler.assertSpanError(7);
+		handler.assertSpanStop(8, "security filterchain after");
+		handler.assertSpanError(9);
+		handler.assertSpanStop(10, "http");
+	}
+
 	// gh-12849
 	@Test
 	void decorateWhenCustomAfterFilterThenObserves() {
@@ -171,6 +263,24 @@ public class ObservationWebFilterChainDecoratorTests {
 
 	}
 
+	static class ErroringFilter implements WebFilter {
+
+		@Override
+		public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
+			return Mono.error(() -> new RuntimeException("ack"));
+		}
+
+	}
+
+	static class TerminatingFilter implements WebFilter {
+
+		@Override
+		public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
+			return Mono.empty();
+		}
+
+	}
+
 	static class AccumulatingObservationHandler implements ObservationHandler<Observation.Context> {
 
 		List<Event> contexts = new ArrayList<>();
@@ -246,6 +356,11 @@ public class ObservationWebFilterChainDecoratorTests {
 			}
 		}
 
+		private void assertSpanError(int index) {
+			Event event = this.contexts.get(index);
+			assertThat(event.event).isEqualTo("error");
+		}
+
 		static class Event {
 
 			String event;