浏览代码

Add subscriberContext to PayloadSocketAcceptor delegate.accept

Closes gh-8654
Rob Winch 5 年之前
父节点
当前提交
24a04f9c5f

+ 1 - 0
rsocket/src/main/java/org/springframework/security/rsocket/core/PayloadSocketAcceptor.java

@@ -72,6 +72,7 @@ class PayloadSocketAcceptor implements SocketAcceptor {
 		return intercept(setup, dataMimeType, metadataMimeType)
 			.flatMap(ctx -> this.delegate.accept(setup, sendingSocket)
 				.map(acceptingSocket -> new PayloadInterceptorRSocket(acceptingSocket, this.interceptors, metadataMimeType, dataMimeType, ctx))
+				.subscriberContext(ctx)
 			);
 	}
 

+ 50 - 0
rsocket/src/test/java/org/springframework/security/rsocket/core/CaptureSecurityContextSocketAcceptor.java

@@ -0,0 +1,50 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.rsocket.core;
+
+import io.rsocket.ConnectionSetupPayload;
+import io.rsocket.RSocket;
+import io.rsocket.SocketAcceptor;
+import reactor.core.publisher.Mono;
+
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
+
+/**
+ * A {@link SocketAcceptor} that captures the {@link SecurityContext} and then continues with the {@link RSocket}
+ * @author Rob Winch
+ */
+class CaptureSecurityContextSocketAcceptor implements SocketAcceptor {
+	private final RSocket accept;
+
+	private SecurityContext securityContext;
+
+	CaptureSecurityContextSocketAcceptor(RSocket accept) {
+		this.accept = accept;
+	}
+
+	@Override
+	public Mono<RSocket> accept(ConnectionSetupPayload setup, RSocket sendingSocket) {
+		return ReactiveSecurityContextHolder.getContext()
+			.doOnNext(securityContext -> this.securityContext = securityContext)
+			.thenReturn(this.accept);
+	}
+
+	public SecurityContext getSecurityContext() {
+		return this.securityContext;
+	}
+}

+ 32 - 7
rsocket/src/test/java/org/springframework/security/rsocket/core/PayloadSocketAcceptorTests.java

@@ -16,6 +16,10 @@
 
 package org.springframework.security.rsocket.core;
 
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
 import io.rsocket.ConnectionSetupPayload;
 import io.rsocket.Payload;
 import io.rsocket.RSocket;
@@ -27,16 +31,16 @@ import org.junit.runner.RunWith;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
 import org.mockito.runners.MockitoJUnitRunner;
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
+
 import org.springframework.http.MediaType;
+import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.rsocket.api.PayloadExchange;
 import org.springframework.security.rsocket.api.PayloadInterceptor;
-import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
-import org.springframework.security.rsocket.core.PayloadSocketAcceptor;
-import reactor.core.publisher.Mono;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatCode;
@@ -144,6 +148,27 @@ public class PayloadSocketAcceptorTests {
 		assertThat(exchange.getDataMimeType()).isEqualTo(MediaType.APPLICATION_JSON);
 	}
 
+
+	@Test
+	// gh-8654
+	public void acceptWhenDelegateAcceptRequiresReactiveSecurityContext() {
+		when(this.setupPayload.metadataMimeType()).thenReturn(MediaType.TEXT_PLAIN_VALUE);
+		when(this.setupPayload.dataMimeType()).thenReturn(MediaType.APPLICATION_JSON_VALUE);
+		SecurityContext expectedSecurityContext = new SecurityContextImpl(new TestingAuthenticationToken("user", "password", "ROLE_USER"));
+		CaptureSecurityContextSocketAcceptor captureSecurityContext = new CaptureSecurityContextSocketAcceptor(this.rSocket);
+		PayloadInterceptor authenticateInterceptor = (exchange, chain) -> {
+			Context withSecurityContext = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(expectedSecurityContext));
+			return chain.next(exchange)
+				.subscriberContext(withSecurityContext);
+		};
+		List<PayloadInterceptor> interceptors = Arrays.asList(authenticateInterceptor);
+		this.acceptor = new PayloadSocketAcceptor(captureSecurityContext, interceptors);
+
+		this.acceptor.accept(this.setupPayload, this.rSocket).block();
+
+		assertThat(captureSecurityContext.getSecurityContext()).isEqualTo(expectedSecurityContext);
+	}
+
 	private PayloadExchange captureExchange() {
 		when(this.delegate.accept(any(), any())).thenReturn(Mono.just(this.rSocket));
 		when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty());