Przeglądaj źródła

Address Observability Thread Safety

Closes gh-12829
Josh Cummings 2 lat temu
rodzic
commit
c06e604278

+ 56 - 30
web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java

@@ -18,9 +18,8 @@ package org.springframework.security.web;
 
 import java.io.IOException;
 import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
 import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 import io.micrometer.common.KeyValues;
@@ -227,49 +226,38 @@ public final class ObservationFilterChainDecorator implements FilterChainProxy.F
 
 		class SimpleAroundFilterObservation implements AroundFilterObservation {
 
-			private final Iterator<Observation> observations;
+			private final ObservationReference before;
 
-			private final Observation before;
+			private final ObservationReference after;
 
-			private final Observation after;
-
-			private final AtomicReference<Observation.Scope> currentScope = new AtomicReference<>(null);
+			private final AtomicReference<ObservationReference> reference = new AtomicReference<>(
+					ObservationReference.NOOP);
 
 			SimpleAroundFilterObservation(Observation before, Observation after) {
-				this.before = before;
-				this.after = after;
-				this.observations = Arrays.asList(before, after).iterator();
+				this.before = new ObservationReference(before);
+				this.after = new ObservationReference(after);
 			}
 
 			@Override
 			public void start() {
-				if (this.observations.hasNext()) {
-					stop();
-					Observation observation = this.observations.next();
-					observation.start();
-					Observation.Scope scope = observation.openScope();
-					this.currentScope.set(scope);
+				if (this.reference.compareAndSet(ObservationReference.NOOP, this.before)) {
+					this.before.start();
+					return;
+				}
+				if (this.reference.compareAndSet(this.before, this.after)) {
+					this.before.stop();
+					this.after.start();
 				}
 			}
 
 			@Override
 			public void error(Throwable ex) {
-				Observation.Scope scope = this.currentScope.get();
-				if (scope == null) {
-					return;
-				}
-				scope.close();
-				scope.getCurrentObservation().error(ex);
+				this.reference.get().error(ex);
 			}
 
 			@Override
 			public void stop() {
-				Observation.Scope scope = this.currentScope.getAndSet(null);
-				if (scope == null) {
-					return;
-				}
-				scope.close();
-				scope.getCurrentObservation().stop();
+				this.reference.get().stop();
 			}
 
 			@Override
@@ -304,12 +292,50 @@ public final class ObservationFilterChainDecorator implements FilterChainProxy.F
 
 			@Override
 			public Observation before() {
-				return this.before;
+				return this.before.observation;
 			}
 
 			@Override
 			public Observation after() {
-				return this.after;
+				return this.after.observation;
+			}
+
+			private static final class ObservationReference {
+
+				private static final ObservationReference NOOP = new ObservationReference(Observation.NOOP);
+
+				private final AtomicInteger state = new AtomicInteger(0);
+
+				private final Observation observation;
+
+				private volatile Observation.Scope scope;
+
+				private ObservationReference(Observation observation) {
+					this.observation = observation;
+					this.scope = Observation.Scope.NOOP;
+				}
+
+				private void start() {
+					if (this.state.compareAndSet(0, 1)) {
+						this.observation.start();
+						this.scope = this.observation.openScope();
+					}
+				}
+
+				private void error(Throwable error) {
+					if (this.state.get() == 1) {
+						this.scope.close();
+						this.scope.getCurrentObservation().error(error);
+					}
+				}
+
+				private void stop() {
+					if (this.state.compareAndSet(1, 2)) {
+						this.scope.close();
+						this.scope.getCurrentObservation().stop();
+					}
+				}
+
 			}
 
 		}

+ 51 - 28
web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,10 +17,9 @@
 package org.springframework.security.web.server;
 
 import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Iterator;
 import java.util.List;
 import java.util.ListIterator;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicReference;
 
 import io.micrometer.common.KeyValues;
@@ -253,46 +252,38 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 
 		class SimpleAroundWebFilterObservation implements AroundWebFilterObservation {
 
-			private final Iterator<Observation> observations;
+			private final ObservationReference before;
 
-			private final Observation before;
+			private final ObservationReference after;
 
-			private final Observation after;
-
-			private final AtomicReference<Observation> currentObservation = new AtomicReference<>(null);
+			private final AtomicReference<ObservationReference> currentObservation = new AtomicReference<>(
+					ObservationReference.NOOP);
 
 			SimpleAroundWebFilterObservation(Observation before, Observation after) {
-				this.before = before;
-				this.after = after;
-				this.observations = Arrays.asList(before, after).iterator();
+				this.before = new ObservationReference(before);
+				this.after = new ObservationReference(after);
 			}
 
 			@Override
 			public void start() {
-				if (this.observations.hasNext()) {
-					stop();
-					Observation observation = this.observations.next();
-					observation.start();
-					this.currentObservation.set(observation);
+				if (this.currentObservation.compareAndSet(ObservationReference.NOOP, this.before)) {
+					this.before.start();
+					return;
+				}
+				if (this.currentObservation.compareAndSet(this.before, this.after)) {
+					this.before.stop();
+					this.after.start();
 				}
 			}
 
 			@Override
 			public void error(Throwable ex) {
-				Observation observation = this.currentObservation.get();
-				if (observation == null) {
-					return;
-				}
-				observation.error(ex);
+				this.currentObservation.get().error(ex);
 			}
 
 			@Override
 			public void stop() {
-				Observation observation = this.currentObservation.getAndSet(null);
-				if (observation == null) {
-					return;
-				}
-				observation.stop();
+				this.currentObservation.get().stop();
 			}
 
 			@Override
@@ -329,12 +320,44 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 
 			@Override
 			public Observation before() {
-				return this.before;
+				return this.before.observation;
 			}
 
 			@Override
 			public Observation after() {
-				return this.after;
+				return this.after.observation;
+			}
+
+			private static final class ObservationReference {
+
+				private static final ObservationReference NOOP = new ObservationReference(Observation.NOOP);
+
+				private final AtomicInteger state = new AtomicInteger(0);
+
+				private final Observation observation;
+
+				private ObservationReference(Observation observation) {
+					this.observation = observation;
+				}
+
+				private void start() {
+					if (this.state.compareAndSet(0, 1)) {
+						this.observation.start();
+					}
+				}
+
+				private void error(Throwable ex) {
+					if (this.state.get() == 1) {
+						this.observation.error(ex);
+					}
+				}
+
+				private void stop() {
+					if (this.state.compareAndSet(1, 2)) {
+						this.observation.stop();
+					}
+				}
+
 			}
 
 		}