浏览代码

Use ReactorSecurityContextHolder

Issue gh-4713
Rob Winch 7 年之前
父节点
当前提交
747473257f

+ 3 - 7
config/src/test/java/org/springframework/security/config/annotation/method/configuration/EnableReactiveMethodSecurityTests.java

@@ -25,7 +25,7 @@ import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.annotation.Bean;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.authentication.TestingAuthenticationToken;
-import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.test.context.ContextConfiguration;
 import org.springframework.test.context.junit4.SpringRunner;
 import reactor.core.publisher.Flux;
@@ -34,8 +34,6 @@ import reactor.test.StepVerifier;
 import reactor.test.publisher.TestPublisher;
 import reactor.util.context.Context;
 
-import java.util.function.Function;
-
 import static org.mockito.Mockito.*;
 
 /**
@@ -49,10 +47,8 @@ public class EnableReactiveMethodSecurityTests {
 	ReactiveMessageService delegate;
 	TestPublisher<String> result = TestPublisher.create();
 
-	Function<Context, Context> withAdmin = context -> context.put(Authentication.class, Mono
-		.just(new TestingAuthenticationToken("admin","password","ROLE_USER", "ROLE_ADMIN")));
-	Function<Context, Context> withUser = context -> context.put(Authentication.class, Mono
-		.just(new TestingAuthenticationToken("user","password","ROLE_USER")));
+	Context withAdmin = ReactiveSecurityContextHolder.withAuthentication(new TestingAuthenticationToken("admin","password","ROLE_USER", "ROLE_ADMIN"));
+	Context withUser = ReactiveSecurityContextHolder.withAuthentication(new TestingAuthenticationToken("user","password","ROLE_USER"));
 
 	@After
 	public void cleanup() {

+ 8 - 6
config/src/test/java/org/springframework/security/config/annotation/web/reactive/EnableWebFluxSecurityTests.java

@@ -31,6 +31,8 @@ import org.springframework.security.config.test.SpringTestRule;
 import org.springframework.security.config.users.ReactiveAuthenticationTestConfiguration;
 import org.springframework.security.config.web.server.ServerHttpSecurity;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.userdetails.MapReactiveUserDetailsService;
 import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
 import org.springframework.security.core.userdetails.User;
@@ -106,8 +108,8 @@ public class EnableWebFluxSecurityTests {
 				chain.filter(exchange.mutate().principal(Mono.just(currentPrincipal)).build()),
 			this.springSecurityFilterChain,
 			(exchange,chain) ->
-				Mono.subscriberContext()
-					.flatMap( c -> c.<Mono<Principal>>get(Authentication.class))
+				ReactiveSecurityContextHolder.getContext()
+					.map(SecurityContext::getAuthentication)
 					.flatMap( principal -> exchange.getResponse()
 						.writeWith(Mono.just(toDataBuffer(principal.getName()))))
 		).build();
@@ -126,8 +128,8 @@ public class EnableWebFluxSecurityTests {
 		WebTestClient client = WebTestClientBuilder.bindToWebFilters(
 			this.springSecurityFilterChain,
 			(exchange,chain) ->
-				Mono.subscriberContext()
-					.flatMap( c -> c.<Mono<Principal>>get(Authentication.class))
+				ReactiveSecurityContextHolder.getContext()
+					.map(SecurityContext::getAuthentication)
 					.flatMap( principal -> exchange.getResponse()
 						.writeWith(Mono.just(toDataBuffer(principal.getName()))))
 		)
@@ -154,8 +156,8 @@ public class EnableWebFluxSecurityTests {
 		WebTestClient client = WebTestClientBuilder.bindToWebFilters(
 			this.springSecurityFilterChain,
 			(exchange,chain) ->
-				Mono.subscriberContext()
-					.flatMap( c -> c.<Mono<Principal>>get(Authentication.class))
+				ReactiveSecurityContextHolder.getContext()
+					.map(SecurityContext::getAuthentication)
 					.flatMap( principal -> exchange.getResponse()
 						.writeWith(Mono.just(toDataBuffer(principal.getName()))))
 		)

+ 5 - 4
core/src/main/java/org/springframework/security/access/prepost/PrePostAdviceReactiveMethodInterceptor.java

@@ -25,11 +25,12 @@ import org.springframework.security.access.method.MethodSecurityMetadataSource;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.util.Assert;
 import reactor.core.Exceptions;
 import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
-import reactor.util.context.Context;
 
 import java.lang.reflect.Method;
 import java.util.Collection;
@@ -68,9 +69,9 @@ public class PrePostAdviceReactiveMethodInterceptor implements MethodInterceptor
 			.getAttributes(method, targetClass);
 
 		PreInvocationAttribute preAttr = findPreInvocationAttribute(attributes);
-		Mono<Authentication> toInvoke = Mono.subscriberContext()
-			.defaultIfEmpty(Context.empty())
-			.flatMap( cxt -> cxt.getOrDefault(Authentication.class, Mono.just(anonymous)))
+		Mono<Authentication> toInvoke = ReactiveSecurityContextHolder.getContext()
+			.map(SecurityContext::getAuthentication)
+			.defaultIfEmpty(this.anonymous)
 			.filter( auth -> this.preInvocationAdvice.before(auth, invocation, preAttr))
 			.switchIfEmpty(Mono.error(new AccessDeniedException("Denied")));
 

+ 3 - 2
test/src/main/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListener.java

@@ -18,6 +18,7 @@ package org.springframework.security.test.context.support;
 
 import org.reactivestreams.Subscription;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.test.context.TestSecurityContextHolder;
 import org.springframework.test.context.TestContext;
 import org.springframework.test.context.TestExecutionListener;
@@ -25,7 +26,6 @@ import org.springframework.test.context.support.AbstractTestExecutionListener;
 import org.springframework.util.ClassUtils;
 import reactor.core.CoreSubscriber;
 import reactor.core.publisher.Hooks;
-import reactor.core.publisher.Mono;
 import reactor.core.publisher.Operators;
 import reactor.util.context.Context;
 
@@ -76,7 +76,8 @@ public class ReactorContextTestExecutionListener
 				if (authentication == null) {
 					return context;
 				}
-				return context.put(Authentication.class, Mono.just(authentication));
+				Context toMerge = ReactiveSecurityContextHolder.withAuthentication(authentication);
+				return context.putAll(toMerge);
 			}
 
 			@Override

+ 4 - 2
test/src/test/java/org/springframework/security/test/context/annotation/SecurityTestExecutionListenerTests.java

@@ -20,6 +20,8 @@ import static org.assertj.core.api.Assertions.assertThat;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.test.context.support.WithMockUser;
 import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
@@ -42,8 +44,8 @@ public class SecurityTestExecutionListenerTests {
 	@WithMockUser
 	@Test
 	public void reactorContextTestSecurityContextHolderExecutionListenerTestIsRegistered() {
-		Mono<String> name = Mono.subscriberContext()
-			.flatMap( context -> context.<Mono<Authentication>>get(Authentication.class))
+		Mono<String> name = ReactiveSecurityContextHolder.getContext()
+			.map(SecurityContext::getAuthentication)
 			.map(Principal::getName);
 
 		StepVerifier.create(name)

+ 4 - 2
test/src/test/java/org/springframework/security/test/context/support/ReactorContextTestExecutionListenerTests.java

@@ -26,6 +26,8 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
 import reactor.core.publisher.Hooks;
 import reactor.core.publisher.Mono;
 import reactor.test.StepVerifier;
@@ -108,8 +110,8 @@ public class ReactorContextTestExecutionListenerTests {
 	}
 
 	public void assertAuthentication(Authentication expected) {
-		Mono<Authentication> authentication = Mono.subscriberContext()
-			.flatMap( context -> context.<Mono<Authentication>>get(Authentication.class));
+		Mono<Authentication> authentication = ReactiveSecurityContextHolder.getContext()
+			.map(SecurityContext::getAuthentication);
 
 		StepVerifier.create(authentication)
 			.expectNext(expected)

+ 10 - 1
web/src/main/java/org/springframework/security/web/server/context/AuthenticationReactorContextWebFilter.java

@@ -17,6 +17,8 @@
 package org.springframework.security.web.server.context;
 
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.WebFilterChain;
@@ -38,6 +40,13 @@ public class AuthenticationReactorContextWebFilter implements WebFilter {
 	public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
 
 		return chain.filter(exchange)
-				.subscriberContext((Context context) -> context.put(Authentication.class, exchange.getPrincipal()));
+				.subscriberContext(createContext(exchange));
+	}
+
+	private Context createContext(ServerWebExchange exchange) {
+		return exchange.getPrincipal()
+			.cast(Authentication.class)
+			.map(SecurityContextImpl::new)
+			.as(ReactiveSecurityContextHolder::withSecurityContext);
 	}
 }

+ 18 - 19
web/src/test/java/org/springframework/security/web/server/context/AuthenticationReactorContextWebFilterTests.java

@@ -21,11 +21,12 @@ import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.handler.DefaultWebFilterChain;
 import reactor.core.publisher.Mono;
 import reactor.test.StepVerifier;
-import reactor.util.context.Context;
 
 import java.security.Principal;
 
@@ -47,12 +48,12 @@ public class AuthenticationReactorContextWebFilterTests {
 		exchange = exchange.mutate().principal(Mono.just(principal)).build();
 		StepVerifier.create(filter.filter(exchange,
 			new DefaultWebFilterChain( e ->
-				Mono.subscriberContext().doOnSuccess( context -> {
-					Principal contextPrincipal = context.<Mono<Principal>>get(Authentication.class).block();
-					assertThat(contextPrincipal).isEqualTo(principal);
-					assertThat(context.<String>get("foo")).isEqualTo("bar");
-				})
-				.then()
+				ReactiveSecurityContextHolder.getContext()
+					.map(SecurityContext::getAuthentication)
+					.doOnSuccess(contextPrincipal -> assertThat(contextPrincipal).isEqualTo(principal))
+					.flatMap( contextPrincipal -> Mono.subscriberContext())
+					.doOnSuccess( context -> assertThat(context.<String>get("foo")).isEqualTo("bar"))
+					.then()
 			)
 		)
 		.subscriberContext( context -> context.put("foo", "bar")))
@@ -64,11 +65,10 @@ public class AuthenticationReactorContextWebFilterTests {
 		exchange = exchange.mutate().principal(Mono.just(principal)).build();
 		StepVerifier.create(filter.filter(exchange,
 			new DefaultWebFilterChain( e ->
-				Mono.subscriberContext().doOnSuccess( context -> {
-					Principal contextPrincipal = context.<Mono<Principal>>get(Authentication.class).block();
-					assertThat(contextPrincipal).isEqualTo(principal);
-				})
-				.then()
+				ReactiveSecurityContextHolder.getContext()
+					.map(SecurityContext::getAuthentication)
+					.doOnSuccess(contextPrincipal -> assertThat(contextPrincipal).isEqualTo(principal))
+					.then()
 			)
 		))
 		.verifyComplete();
@@ -76,15 +76,14 @@ public class AuthenticationReactorContextWebFilterTests {
 
 	@Test
 	public void filterWhenPrincipalNullThenContextEmpty() {
-		Context defaultContext = Context.empty();
+		Authentication defaultAuthentication = new TestingAuthenticationToken("anonymouse","anonymous", "TEST");
 		StepVerifier.create(filter.filter(exchange,
 			new DefaultWebFilterChain( e ->
-				Mono.subscriberContext()
-					.defaultIfEmpty(defaultContext)
-					.doOnSuccess( context -> {
-					Principal contextPrincipal = context.<Mono<Principal>>get(Authentication.class).block();
-					assertThat(contextPrincipal).isNull();
-				})
+				ReactiveSecurityContextHolder.getContext()
+					.map(SecurityContext::getAuthentication)
+					.defaultIfEmpty(defaultAuthentication)
+					.doOnSuccess( contextPrincipal -> assertThat(contextPrincipal).isEqualTo(defaultAuthentication)
+				)
 				.then()
 			)
 		))