Sfoglia il codice sorgente

Cast FilterChainObservationContext Safely

Closes gh-12268
Josh Cummings 2 anni fa
parent
commit
701f754e37

+ 10 - 8
web/src/main/java/org/springframework/security/web/ObservationFilterChainDecorator.java

@@ -177,17 +177,19 @@ public final class ObservationFilterChainDecorator implements FilterChainProxy.F
 		private void wrapFilter(ServletRequest request, ServletResponse response, FilterChain chain)
 				throws IOException, ServletException {
 			AroundFilterObservation parent = observation((HttpServletRequest) request);
-			FilterChainObservationContext parentBefore = (FilterChainObservationContext) parent.before().getContext();
-			parentBefore.setChainSize(this.size);
-			parentBefore.setFilterName(this.name);
-			parentBefore.setChainPosition(this.position);
+			if (parent.before().getContext() instanceof FilterChainObservationContext parentBefore) {
+				parentBefore.setChainSize(this.size);
+				parentBefore.setFilterName(this.name);
+				parentBefore.setChainPosition(this.position);
+			}
 			parent.before().event(Observation.Event.of(this.name + " before"));
 			this.filter.doFilter(request, response, chain);
 			parent.start();
-			FilterChainObservationContext parentAfter = (FilterChainObservationContext) parent.after().getContext();
-			parentAfter.setChainSize(this.size);
-			parentAfter.setFilterName(this.name);
-			parentAfter.setChainPosition(this.size - this.position + 1);
+			if (parent.after().getContext() instanceof FilterChainObservationContext parentAfter) {
+				parentAfter.setChainSize(this.size);
+				parentAfter.setFilterName(this.name);
+				parentAfter.setChainPosition(this.size - this.position + 1);
+			}
 			parent.after().event(Observation.Event.of(this.name + " after"));
 		}
 

+ 10 - 10
web/src/main/java/org/springframework/security/web/server/ObservationWebFilterChainDecorator.java

@@ -196,18 +196,18 @@ public final class ObservationWebFilterChainDecorator implements WebFilterChainP
 
 		private Mono<Void> wrapFilter(ServerWebExchange exchange, WebFilterChain chain) {
 			AroundWebFilterObservation parent = observation(exchange);
-			WebFilterChainObservationContext parentBefore = (WebFilterChainObservationContext) parent.before()
-					.getContext();
-			parentBefore.setChainSize(this.size);
-			parentBefore.setFilterName(this.name);
-			parentBefore.setChainPosition(this.position);
+			if (parent.before().getContext() instanceof WebFilterChainObservationContext parentBefore) {
+				parentBefore.setChainSize(this.size);
+				parentBefore.setFilterName(this.name);
+				parentBefore.setChainPosition(this.position);
+			}
 			return this.filter.filter(exchange, chain).doOnSuccess((result) -> {
 				parent.start();
-				WebFilterChainObservationContext parentAfter = (WebFilterChainObservationContext) parent.after()
-						.getContext();
-				parentAfter.setChainSize(this.size);
-				parentAfter.setFilterName(this.name);
-				parentAfter.setChainPosition(this.size - this.position + 1);
+				if (parent.after().getContext() instanceof WebFilterChainObservationContext parentAfter) {
+					parentAfter.setChainSize(this.size);
+					parentAfter.setFilterName(this.name);
+					parentAfter.setChainPosition(this.size - this.position + 1);
+				}
 			});
 		}
 

+ 64 - 0
web/src/test/java/org/springframework/security/web/ObservationFilterChainDecoratorTests.java

@@ -0,0 +1,64 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web;
+
+import io.micrometer.observation.ObservationHandler;
+import io.micrometer.observation.ObservationRegistry;
+import jakarta.servlet.FilterChain;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.mock.web.MockHttpServletRequest;
+import org.springframework.mock.web.MockHttpServletResponse;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+/**
+ * Tests for {@link ObservationFilterChainDecorator}
+ */
+public class ObservationFilterChainDecoratorTests {
+
+	@Test
+	void decorateWhenDefaultsThenObserves() throws Exception {
+		ObservationHandler<?> handler = mock(ObservationHandler.class);
+		given(handler.supportsContext(any())).willReturn(true);
+		ObservationRegistry registry = ObservationRegistry.create();
+		registry.observationConfig().observationHandler(handler);
+		ObservationFilterChainDecorator decorator = new ObservationFilterChainDecorator(registry);
+		FilterChain chain = mock(FilterChain.class);
+		FilterChain decorated = decorator.decorate(chain);
+		decorated.doFilter(new MockHttpServletRequest("GET", "/"), new MockHttpServletResponse());
+		verify(handler).onStart(any());
+	}
+
+	@Test
+	void decorateWhenNoopThenDoesNotObserve() throws Exception {
+		ObservationHandler<?> handler = mock(ObservationHandler.class);
+		given(handler.supportsContext(any())).willReturn(true);
+		ObservationRegistry registry = ObservationRegistry.NOOP;
+		registry.observationConfig().observationHandler(handler);
+		ObservationFilterChainDecorator decorator = new ObservationFilterChainDecorator(registry);
+		FilterChain chain = mock(FilterChain.class);
+		FilterChain decorated = decorator.decorate(chain);
+		decorated.doFilter(new MockHttpServletRequest("GET", "/"), new MockHttpServletResponse());
+		verifyNoInteractions(handler);
+	}
+
+}

+ 67 - 0
web/src/test/java/org/springframework/security/web/server/ObservationWebFilterChainDecoratorTests.java

@@ -0,0 +1,67 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.web.server;
+
+import io.micrometer.observation.ObservationHandler;
+import io.micrometer.observation.ObservationRegistry;
+import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
+
+import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
+import org.springframework.mock.web.server.MockServerWebExchange;
+import org.springframework.web.server.WebFilterChain;
+
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+/**
+ * Tests for {@link ObservationWebFilterChainDecorator}
+ */
+public class ObservationWebFilterChainDecoratorTests {
+
+	@Test
+	void decorateWhenDefaultsThenObserves() {
+		ObservationHandler<?> handler = mock(ObservationHandler.class);
+		given(handler.supportsContext(any())).willReturn(true);
+		ObservationRegistry registry = ObservationRegistry.create();
+		registry.observationConfig().observationHandler(handler);
+		ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
+		WebFilterChain chain = mock(WebFilterChain.class);
+		given(chain.filter(any())).willReturn(Mono.empty());
+		WebFilterChain decorated = decorator.decorate(chain);
+		decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())).block();
+		verify(handler).onStart(any());
+	}
+
+	@Test
+	void decorateWhenNoopThenDoesNotObserve() {
+		ObservationHandler<?> handler = mock(ObservationHandler.class);
+		given(handler.supportsContext(any())).willReturn(true);
+		ObservationRegistry registry = ObservationRegistry.NOOP;
+		registry.observationConfig().observationHandler(handler);
+		ObservationWebFilterChainDecorator decorator = new ObservationWebFilterChainDecorator(registry);
+		WebFilterChain chain = mock(WebFilterChain.class);
+		given(chain.filter(any())).willReturn(Mono.empty());
+		WebFilterChain decorated = decorator.decorate(chain);
+		decorated.filter(MockServerWebExchange.from(MockServerHttpRequest.get("/").build())).block();
+		verifyNoInteractions(handler);
+	}
+
+}