Bladeren bron

Merge remote-tracking branch 'origin/6.4.x' into 6.5.x

Josh Cummings 1 dag geleden
bovenliggende
commit
857ca9c412

+ 12 - 1
web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java

@@ -46,13 +46,14 @@ import org.springframework.util.StringUtils;
  * wraps the chain in before and after observations
  *
  * @author Josh Cummings
+ * @author Nikita Konev
  * @since 6.0
  */
 public final class ObservationFilterChainDecorator implements FilterChainProxy.FilterChainDecorator {
 
 	private static final Log logger = LogFactory.getLog(FilterChainProxy.class);
 
-	private static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation";
+	static final String ATTRIBUTE = ObservationFilterChainDecorator.class + ".observation";
 
 	static final String UNSECURED_OBSERVATION_NAME = "spring.security.http.unsecured.requests";
 
@@ -250,6 +251,16 @@ public final class ObservationFilterChainDecorator implements FilterChainProxy.F
 		private AroundFilterObservation parent(HttpServletRequest request) {
 			FilterChainObservationContext beforeContext = FilterChainObservationContext.before();
 			FilterChainObservationContext afterContext = FilterChainObservationContext.after();
+
+			AroundFilterObservation existingParentObservation = (AroundFilterObservation) request
+				.getAttribute(ATTRIBUTE);
+			if (existingParentObservation != null) {
+				beforeContext
+					.setParentObservation(existingParentObservation.before().getContext().getParentObservation());
+				afterContext
+					.setParentObservation(existingParentObservation.after().getContext().getParentObservation());
+			}
+
 			Observation before = Observation.createNotStarted(this.convention, () -> beforeContext, this.registry);
 			Observation after = Observation.createNotStarted(this.convention, () -> afterContext, this.registry);
 			AroundFilterObservation parent = AroundFilterObservation.create(before, after);

+ 59 - 0
web/src/test/java/org/springframework/security/web/FilterChainProxyTests.java

@@ -314,6 +314,65 @@ public class FilterChainProxyTests {
 		assertFilterChainObservation(contexts.next(), "after", 1);
 	}
 
+	// gh-12610
+	@Test
+	void parentObservationIsTakenIntoAccountDuringDispatchError() throws Exception {
+		ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);
+		given(handler.supportsContext(any())).willReturn(true);
+		ObservationRegistry registry = ObservationRegistry.create();
+		registry.observationConfig().observationHandler(handler);
+
+		given(this.matcher.matches(any())).willReturn(true);
+		SecurityFilterChain sec = new DefaultSecurityFilterChain(this.matcher, Arrays.asList(this.filter));
+		FilterChainProxy fcp = new FilterChainProxy(sec);
+		fcp.setFilterChainDecorator(new ObservationFilterChainDecorator(registry));
+		Filter initialFilter = ObservationFilterChainDecorator.FilterObservation
+			.create(Observation.createNotStarted("wrap", registry))
+			.wrap(fcp);
+
+		ServletRequest initialRequest = new MockHttpServletRequest("GET", "/");
+		initialFilter.doFilter(initialRequest, new MockHttpServletResponse(), this.chain);
+
+		// simulate request attribute copying in case dispatching to ERROR
+		ObservationFilterChainDecorator.AroundFilterObservation parentObservation = (ObservationFilterChainDecorator.AroundFilterObservation) initialRequest
+			.getAttribute(ObservationFilterChainDecorator.ATTRIBUTE);
+		assertThat(parentObservation).isNotNull();
+
+		// simulate dispatching error-related request
+		Filter errorRelatedFilter = ObservationFilterChainDecorator.FilterObservation
+			.create(Observation.createNotStarted("wrap", registry))
+			.wrap(fcp);
+		ServletRequest errorRelatedRequest = new MockHttpServletRequest("GET", "/error");
+		errorRelatedRequest.setAttribute(ObservationFilterChainDecorator.ATTRIBUTE, parentObservation);
+		errorRelatedFilter.doFilter(errorRelatedRequest, new MockHttpServletResponse(), this.chain);
+
+		ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
+		verify(handler, times(8)).onStart(captor.capture());
+		verify(handler, times(8)).onStop(any());
+		List<Observation.Context> contexts = captor.getAllValues();
+
+		Observation.Context initialRequestObservationContextBefore = contexts.get(1);
+		Observation.Context initialRequestObservationContextAfter = contexts.get(3);
+		assertFilterChainObservation(initialRequestObservationContextBefore, "before", 1);
+		assertFilterChainObservation(initialRequestObservationContextAfter, "after", 1);
+
+		assertThat(initialRequestObservationContextBefore.getParentObservation()).isNotNull();
+		assertThat(initialRequestObservationContextBefore.getParentObservation())
+			.isSameAs(initialRequestObservationContextAfter.getParentObservation());
+
+		Observation.Context errorRelatedRequestObservationContextBefore = contexts.get(5);
+		Observation.Context errorRelatedRequestObservationContextAfter = contexts.get(7);
+		assertFilterChainObservation(errorRelatedRequestObservationContextBefore, "before", 1);
+		assertFilterChainObservation(errorRelatedRequestObservationContextAfter, "after", 1);
+
+		assertThat(errorRelatedRequestObservationContextBefore.getParentObservation()).isNotNull();
+		assertThat(errorRelatedRequestObservationContextBefore.getParentObservation())
+			.isSameAs(initialRequestObservationContextBefore.getParentObservation());
+		assertThat(errorRelatedRequestObservationContextAfter.getParentObservation()).isNotNull();
+		assertThat(errorRelatedRequestObservationContextAfter.getParentObservation())
+			.isSameAs(initialRequestObservationContextBefore.getParentObservation());
+	}
+
 	@Test
 	public void doFilterWhenMultipleFiltersThenObservationRegistryObserves() throws Exception {
 		ObservationHandler<Observation.Context> handler = mock(ObservationHandler.class);