Browse Source

Address Observation Bean Name Collisions

Closes gh-16161
Josh Cummings 9 months ago
parent
commit
2b5a2eef82

+ 169 - 0
config/src/integration-test/java/org/springframework/security/config/annotation/rsocket/HelloRSocketWithWebFluxITests.java

@@ -0,0 +1,169 @@
+/*
+ * Copyright 2019 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.config.annotation.rsocket;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import io.rsocket.core.RSocketServer;
+import io.rsocket.exceptions.RejectedSetupException;
+import io.rsocket.frame.decoder.PayloadDecoder;
+import io.rsocket.transport.netty.server.CloseableChannel;
+import io.rsocket.transport.netty.server.TcpServerTransport;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.context.annotation.Bean;
+import org.springframework.context.annotation.Configuration;
+import org.springframework.messaging.handler.annotation.MessageMapping;
+import org.springframework.messaging.rsocket.RSocketRequester;
+import org.springframework.messaging.rsocket.RSocketStrategies;
+import org.springframework.messaging.rsocket.annotation.support.RSocketMessageHandler;
+import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity;
+import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
+import org.springframework.security.core.userdetails.User;
+import org.springframework.security.core.userdetails.UserDetails;
+import org.springframework.security.rsocket.core.SecuritySocketAcceptorInterceptor;
+import org.springframework.security.rsocket.metadata.BasicAuthenticationEncoder;
+import org.springframework.stereotype.Controller;
+import org.springframework.test.context.ContextConfiguration;
+import org.springframework.test.context.junit.jupiter.SpringExtension;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
+/**
+ * @author Rob Winch
+ */
+@ContextConfiguration
+@ExtendWith(SpringExtension.class)
+public class HelloRSocketWithWebFluxITests {
+
+	@Autowired
+	RSocketMessageHandler handler;
+
+	@Autowired
+	SecuritySocketAcceptorInterceptor interceptor;
+
+	@Autowired
+	ServerController controller;
+
+	private CloseableChannel server;
+
+	private RSocketRequester requester;
+
+	@BeforeEach
+	public void setup() {
+		// @formatter:off
+		this.server = RSocketServer.create()
+				.payloadDecoder(PayloadDecoder.ZERO_COPY)
+				.interceptors((registry) ->
+					registry.forSocketAcceptor(this.interceptor)
+				)
+				.acceptor(this.handler.responder())
+				.bind(TcpServerTransport.create("localhost", 0))
+				.block();
+		// @formatter:on
+	}
+
+	@AfterEach
+	public void dispose() {
+		this.requester.rsocket().dispose();
+		this.server.dispose();
+		this.controller.payloads.clear();
+	}
+
+	// gh-16161
+	@Test
+	public void retrieveMonoWhenSecureThenDenied() {
+		// @formatter:off
+		this.requester = RSocketRequester.builder()
+			.rsocketStrategies(this.handler.getRSocketStrategies())
+			.connectTcp("localhost", this.server.address().getPort())
+			.block();
+		// @formatter:on
+		String data = "rob";
+		// @formatter:off
+		assertThatExceptionOfType(Exception.class).isThrownBy(
+				() -> this.requester.route("secure.retrieve-mono")
+						.data(data)
+						.retrieveMono(String.class)
+						.block()
+				)
+				.matches((ex) -> ex instanceof RejectedSetupException
+						|| ex.getClass().toString().contains("ReactiveException"));
+		// @formatter:on
+		assertThat(this.controller.payloads).isEmpty();
+	}
+
+	@Configuration
+	@EnableRSocketSecurity
+	@EnableWebFluxSecurity
+	static class Config {
+
+		@Bean
+		ServerController controller() {
+			return new ServerController();
+		}
+
+		@Bean
+		RSocketMessageHandler messageHandler() {
+			RSocketMessageHandler handler = new RSocketMessageHandler();
+			handler.setRSocketStrategies(rsocketStrategies());
+			return handler;
+		}
+
+		@Bean
+		RSocketStrategies rsocketStrategies() {
+			return RSocketStrategies.builder().encoder(new BasicAuthenticationEncoder()).build();
+		}
+
+		@Bean
+		MapReactiveUserDetailsService uds() {
+			// @formatter:off
+			UserDetails rob = User.withDefaultPasswordEncoder()
+					.username("rob")
+					.password("password")
+					.roles("USER", "ADMIN")
+					.build();
+			// @formatter:on
+			return new MapReactiveUserDetailsService(rob);
+		}
+
+	}
+
+	@Controller
+	static class ServerController {
+
+		private List<String> payloads = new ArrayList<>();
+
+		@MessageMapping("**")
+		String retrieveMono(String payload) {
+			add(payload);
+			return "Hi " + payload;
+		}
+
+		private void add(String p) {
+			this.payloads.add(p);
+		}
+
+	}
+
+}

+ 8 - 2
config/src/main/java/org/springframework/security/config/annotation/rsocket/RSocketSecurityConfiguration.java

@@ -16,6 +16,8 @@
 
 
 package org.springframework.security.config.annotation.rsocket;
 package org.springframework.security.config.annotation.rsocket;
 
 
+import java.util.Map;
+
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Bean;
@@ -62,8 +64,12 @@ class RSocketSecurityConfiguration {
 	}
 	}
 
 
 	@Autowired(required = false)
 	@Autowired(required = false)
-	void setAuthenticationManagerPostProcessor(ObjectPostProcessor<ReactiveAuthenticationManager> postProcessor) {
-		this.postProcessor = postProcessor;
+	void setAuthenticationManagerPostProcessor(
+			Map<String, ObjectPostProcessor<ReactiveAuthenticationManager>> postProcessors) {
+		if (postProcessors.size() == 1) {
+			this.postProcessor = postProcessors.values().iterator().next();
+		}
+		this.postProcessor = postProcessors.get("rSocketAuthenticationManagerPostProcessor");
 	}
 	}
 
 
 	@Bean(name = RSOCKET_SECURITY_BEAN_NAME)
 	@Bean(name = RSOCKET_SECURITY_BEAN_NAME)

+ 2 - 18
config/src/main/java/org/springframework/security/config/annotation/rsocket/ReactiveObservationConfiguration.java

@@ -29,9 +29,7 @@ import org.springframework.security.authorization.ObservationReactiveAuthorizati
 import org.springframework.security.authorization.ReactiveAuthorizationManager;
 import org.springframework.security.authorization.ReactiveAuthorizationManager;
 import org.springframework.security.config.ObjectPostProcessor;
 import org.springframework.security.config.ObjectPostProcessor;
 import org.springframework.security.config.observation.SecurityObservationSettings;
 import org.springframework.security.config.observation.SecurityObservationSettings;
-import org.springframework.security.web.server.ObservationWebFilterChainDecorator;
-import org.springframework.security.web.server.WebFilterChainProxy.WebFilterChainDecorator;
-import org.springframework.web.server.ServerWebExchange;
+import org.springframework.security.rsocket.api.PayloadExchange;
 
 
 @Configuration(proxyBeanMethods = false)
 @Configuration(proxyBeanMethods = false)
 @Role(BeanDefinition.ROLE_INFRASTRUCTURE)
 @Role(BeanDefinition.ROLE_INFRASTRUCTURE)
@@ -45,7 +43,7 @@ class ReactiveObservationConfiguration {
 
 
 	@Bean
 	@Bean
 	@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
 	@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
-	static ObjectPostProcessor<ReactiveAuthorizationManager<ServerWebExchange>> rSocketAuthorizationManagerPostProcessor(
+	static ObjectPostProcessor<ReactiveAuthorizationManager<PayloadExchange>> rSocketAuthorizationManagerPostProcessor(
 			ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
 			ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
 		return new ObjectPostProcessor<>() {
 		return new ObjectPostProcessor<>() {
 			@Override
 			@Override
@@ -71,18 +69,4 @@ class ReactiveObservationConfiguration {
 		};
 		};
 	}
 	}
 
 
-	@Bean
-	@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
-	static ObjectPostProcessor<WebFilterChainDecorator> rSocketFilterChainDecoratorPostProcessor(
-			ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
-		return new ObjectPostProcessor<>() {
-			@Override
-			public WebFilterChainDecorator postProcess(WebFilterChainDecorator object) {
-				ObservationRegistry r = registry.getIfUnique(() -> ObservationRegistry.NOOP);
-				boolean active = !r.isNoop() && predicate.getIfUnique(() -> all).shouldObserveRequests();
-				return active ? new ObservationWebFilterChainDecorator(r) : object;
-			}
-		};
-	}
-
 }
 }

+ 1 - 1
config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveObservationConfiguration.java

@@ -59,7 +59,7 @@ class ReactiveObservationConfiguration {
 
 
 	@Bean
 	@Bean
 	@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
 	@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
-	static ObjectPostProcessor<ReactiveAuthenticationManager> authenticationManagerPostProcessor(
+	static ObjectPostProcessor<ReactiveAuthenticationManager> reactiveAuthenticationManagerPostProcessor(
 			ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
 			ObjectProvider<ObservationRegistry> registry, ObjectProvider<SecurityObservationSettings> predicate) {
 		return new ObjectPostProcessor<>() {
 		return new ObjectPostProcessor<>() {
 			@Override
 			@Override

+ 8 - 2
config/src/main/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfiguration.java

@@ -16,6 +16,8 @@
 
 
 package org.springframework.security.config.annotation.web.reactive;
 package org.springframework.security.config.annotation.web.reactive;
 
 
+import java.util.Map;
+
 import org.springframework.beans.BeansException;
 import org.springframework.beans.BeansException;
 import org.springframework.beans.factory.BeanFactory;
 import org.springframework.beans.factory.BeanFactory;
 import org.springframework.beans.factory.ObjectProvider;
 import org.springframework.beans.factory.ObjectProvider;
@@ -96,8 +98,12 @@ class ServerHttpSecurityConfiguration {
 	}
 	}
 
 
 	@Autowired(required = false)
 	@Autowired(required = false)
-	void setAuthenticationManagerPostProcessor(ObjectPostProcessor<ReactiveAuthenticationManager> postProcessor) {
-		this.postProcessor = postProcessor;
+	void setAuthenticationManagerPostProcessor(
+			Map<String, ObjectPostProcessor<ReactiveAuthenticationManager>> postProcessors) {
+		if (postProcessors.size() == 1) {
+			this.postProcessor = postProcessors.values().iterator().next();
+		}
+		this.postProcessor = postProcessors.get("reactiveAuthenticationManagerPostProcessor");
 	}
 	}
 
 
 	@Autowired(required = false)
 	@Autowired(required = false)

+ 25 - 0
config/src/test/java/org/springframework/security/config/annotation/web/reactive/ServerHttpSecurityConfigurationTests.java

@@ -242,6 +242,31 @@ public class ServerHttpSecurityConfigurationTests {
 		assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after");
 		assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after");
 	}
 	}
 
 
+	// gh-16161
+	@Test
+	public void getWhenUsingRSocketThenObservesRequest() {
+		this.spring.register(ObservationRegistryConfig.class, RSocketSecurityConfig.class).autowire();
+		// @formatter:off
+		this.webClient
+				.get()
+				.uri("/hello")
+				.headers((headers) -> headers.setBasicAuth("user", "password"))
+				.exchange()
+				.expectStatus()
+				.isNotFound();
+		// @formatter:on
+		ObservationHandler<Observation.Context> handler = this.spring.getContext().getBean(ObservationHandler.class);
+		ArgumentCaptor<Observation.Context> captor = ArgumentCaptor.forClass(Observation.Context.class);
+		verify(handler, times(6)).onStart(captor.capture());
+		Iterator<Observation.Context> contexts = captor.getAllValues().iterator();
+		assertThat(contexts.next().getContextualName()).isEqualTo("http get");
+		assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain before");
+		assertThat(contexts.next().getName()).isEqualTo("spring.security.authentications");
+		assertThat(contexts.next().getName()).isEqualTo("spring.security.authorizations");
+		assertThat(contexts.next().getName()).isEqualTo("spring.security.http.secured.requests");
+		assertThat(contexts.next().getContextualName()).isEqualTo("security filterchain after");
+	}
+
 	@Configuration
 	@Configuration
 	static class SubclassConfig extends ServerHttpSecurityConfiguration {
 	static class SubclassConfig extends ServerHttpSecurityConfiguration {