|
@@ -1,5 +1,5 @@
|
|
/*
|
|
/*
|
|
- * Copyright 2019 the original author or authors.
|
|
|
|
|
|
+ * Copyright 2019-2021 the original author or authors.
|
|
*
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -19,6 +19,8 @@ package org.springframework.security.rsocket.core;
|
|
import io.rsocket.Payload;
|
|
import io.rsocket.Payload;
|
|
import io.rsocket.RSocket;
|
|
import io.rsocket.RSocket;
|
|
import io.rsocket.metadata.WellKnownMimeType;
|
|
import io.rsocket.metadata.WellKnownMimeType;
|
|
|
|
+import io.rsocket.util.ByteBufPayload;
|
|
|
|
+import io.rsocket.util.DefaultPayload;
|
|
import io.rsocket.util.RSocketProxy;
|
|
import io.rsocket.util.RSocketProxy;
|
|
import org.junit.Test;
|
|
import org.junit.Test;
|
|
import org.junit.runner.RunWith;
|
|
import org.junit.runner.RunWith;
|
|
@@ -28,7 +30,9 @@ import org.mockito.Mock;
|
|
import org.mockito.runners.MockitoJUnitRunner;
|
|
import org.mockito.runners.MockitoJUnitRunner;
|
|
import org.mockito.stubbing.Answer;
|
|
import org.mockito.stubbing.Answer;
|
|
import org.reactivestreams.Publisher;
|
|
import org.reactivestreams.Publisher;
|
|
|
|
+import org.reactivestreams.Subscription;
|
|
import org.springframework.http.MediaType;
|
|
import org.springframework.http.MediaType;
|
|
|
|
+import org.springframework.security.access.AccessDeniedException;
|
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
import org.springframework.security.core.Authentication;
|
|
import org.springframework.security.core.Authentication;
|
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
@@ -41,6 +45,8 @@ import org.springframework.security.rsocket.core.DefaultPayloadExchange;
|
|
import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
|
|
import org.springframework.security.rsocket.core.PayloadInterceptorRSocket;
|
|
import org.springframework.util.MimeType;
|
|
import org.springframework.util.MimeType;
|
|
import org.springframework.util.MimeTypeUtils;
|
|
import org.springframework.util.MimeTypeUtils;
|
|
|
|
+import reactor.util.context.Context;
|
|
|
|
+import reactor.core.CoreSubscriber;
|
|
import reactor.core.publisher.Flux;
|
|
import reactor.core.publisher.Flux;
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.core.publisher.Mono;
|
|
import reactor.test.StepVerifier;
|
|
import reactor.test.StepVerifier;
|
|
@@ -50,10 +56,13 @@ import reactor.test.publisher.TestPublisher;
|
|
import java.util.Arrays;
|
|
import java.util.Arrays;
|
|
import java.util.Collections;
|
|
import java.util.Collections;
|
|
import java.util.List;
|
|
import java.util.List;
|
|
|
|
+import java.util.concurrent.Executors;
|
|
|
|
+import java.util.concurrent.ExecutorService;
|
|
|
|
|
|
import static org.assertj.core.api.Assertions.*;
|
|
import static org.assertj.core.api.Assertions.*;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.any;
|
|
import static org.mockito.ArgumentMatchers.eq;
|
|
import static org.mockito.ArgumentMatchers.eq;
|
|
|
|
+import static org.mockito.Mockito.times;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verify;
|
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
|
import static org.mockito.Mockito.verifyZeroInteractions;
|
|
import static org.mockito.Mockito.when;
|
|
import static org.mockito.Mockito.when;
|
|
@@ -315,6 +324,57 @@ public class PayloadInterceptorRSocketTests {
|
|
verify(this.delegate).requestChannel(any());
|
|
verify(this.delegate).requestChannel(any());
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ // gh-9345
|
|
|
|
+ @Test
|
|
|
|
+ public void requestChannelWhenInterceptorCompletesThenAllPayloadsRetained() {
|
|
|
|
+ ExecutorService executors = Executors.newSingleThreadExecutor();
|
|
|
|
+ Payload payload = ByteBufPayload.create("data");
|
|
|
|
+ Payload payloadTwo = ByteBufPayload.create("moredata");
|
|
|
|
+ Payload payloadThree = ByteBufPayload.create("stillmoredata");
|
|
|
|
+ Context ctx = Context.empty();
|
|
|
|
+ Flux<Payload> payloads = this.payloadResult.flux();
|
|
|
|
+ when(this.interceptor.intercept(any(), any())).thenReturn(Mono.empty())
|
|
|
|
+ .thenReturn(Mono.error(() -> new AccessDeniedException("Access Denied")));
|
|
|
|
+ when(this.delegate.requestChannel(any())).thenAnswer((invocation) -> {
|
|
|
|
+ Flux<Payload> input = invocation.getArgument(0);
|
|
|
|
+ return Flux.from(input).switchOnFirst((signal, innerFlux) -> innerFlux.map(Payload::getDataUtf8)
|
|
|
|
+ .transform((data) -> Flux.<String>create((emitter) -> {
|
|
|
|
+ Runnable run = () -> data.subscribe(new CoreSubscriber<String>() {
|
|
|
|
+ @Override
|
|
|
|
+ public void onSubscribe(Subscription s) {
|
|
|
|
+ s.request(3);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void onNext(String s) {
|
|
|
|
+ emitter.next(s);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void onError(Throwable t) {
|
|
|
|
+ emitter.error(t);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void onComplete() {
|
|
|
|
+ emitter.complete();
|
|
|
|
+ }
|
|
|
|
+ });
|
|
|
|
+ executors.execute(run);
|
|
|
|
+ })).map(DefaultPayload::create));
|
|
|
|
+ });
|
|
|
|
+ PayloadInterceptorRSocket interceptor = new PayloadInterceptorRSocket(this.delegate,
|
|
|
|
+ Arrays.asList(this.interceptor), this.metadataMimeType, this.dataMimeType, ctx);
|
|
|
|
+ StepVerifier.create(interceptor.requestChannel(payloads).doOnDiscard(Payload.class, Payload::release))
|
|
|
|
+ .then(() -> this.payloadResult.assertSubscribers())
|
|
|
|
+ .then(() -> this.payloadResult.emit(payload, payloadTwo, payloadThree))
|
|
|
|
+ .assertNext((next) -> assertThat(next.getDataUtf8()).isEqualTo(payload.getDataUtf8()))
|
|
|
|
+ .verifyError(AccessDeniedException.class);
|
|
|
|
+ verify(this.interceptor, times(2)).intercept(this.exchange.capture(), any());
|
|
|
|
+ assertThat(this.exchange.getValue().getPayload()).isEqualTo(payloadTwo);
|
|
|
|
+ verify(this.delegate).requestChannel(any());
|
|
|
|
+ }
|
|
|
|
+
|
|
@Test
|
|
@Test
|
|
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
|
|
public void requestChannelWhenInterceptorErrorsThenDelegateNotSubscribed() {
|
|
RuntimeException expected = new RuntimeException("Oops");
|
|
RuntimeException expected = new RuntimeException("Oops");
|