Răsfoiți Sursa

Verify ReactorContext when using Virtual Threads

Closes gh-12791
Steve Riesenberg 1 an în urmă
părinte
comite
ff374935fb

+ 62 - 1
config/src/test/java/org/springframework/security/config/annotation/web/configuration/SecurityReactorContextConfigurationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,14 +17,20 @@
 package org.springframework.security.config.annotation.web.configuration;
 
 import java.net.URI;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadFactory;
 
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledOnJre;
+import org.junit.jupiter.api.condition.JRE;
 import org.junit.jupiter.api.extension.ExtendWith;
 import reactor.core.CoreSubscriber;
 import reactor.core.publisher.BaseSubscriber;
@@ -35,6 +41,8 @@ import reactor.util.context.Context;
 
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
+import org.springframework.core.task.SimpleAsyncTaskExecutor;
+import org.springframework.core.task.VirtualThreadTaskExecutor;
 import org.springframework.http.HttpMethod;
 import org.springframework.http.HttpStatus;
 import org.springframework.mock.web.MockHttpServletRequest;
@@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security
 import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
@@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests {
 		verify(strategy, times(2)).getContext();
 	}
 
+	@Test
+	public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception {
+		this.spring.register(SecurityConfig.class).autowire();
+
+		ThreadFactory threadFactory = Executors.defaultThreadFactory();
+		assertContextAttributesAvailable(threadFactory);
+	}
+
+	@Test
+	@DisabledOnJre(JRE.JAVA_17)
+	public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception {
+		this.spring.register(SecurityConfig.class).autowire();
+
+		ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
+		assertContextAttributesAvailable(threadFactory);
+	}
+
+	private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception {
+		Map<Object, Object> expectedContextAttributes = new HashMap<>();
+		expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
+		expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
+		expectedContextAttributes.put(Authentication.class, this.authentication);
+
+		try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) {
+			Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes);
+			assertThat(future.get()).isEqualTo(expectedContextAttributes);
+		}
+	}
+
+	private Map<Object, Object> propagateRequestAttributes() {
+		RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse);
+		RequestContextHolder.setRequestAttributes(requestAttributes);
+
+		SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+		securityContext.setAuthentication(this.authentication);
+		SecurityContextHolder.setContext(securityContext);
+
+		// @formatter:off
+		return Mono.deferContextual(Mono::just)
+				.filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
+				.map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
+				.map((attributes) -> {
+					Map<Object, Object> map = new HashMap<>();
+					// Copy over items from lazily loaded map
+					Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class)
+							.forEach((key) -> map.put(key, attributes.get(key)));
+					return map;
+				})
+				.block();
+		// @formatter:on
+	}
+
 	@Configuration
 	@EnableWebSecurity
 	static class SecurityConfig {

+ 57 - 1
core/src/test/java/org/springframework/security/core/context/ReactiveSecurityContextHolderTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,10 +16,17 @@
 
 package org.springframework.security.core.context;
 
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
+
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledOnJre;
+import org.junit.jupiter.api.condition.JRE;
 import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
 import reactor.test.StepVerifier;
 
+import org.springframework.core.task.VirtualThreadTaskExecutor;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 
@@ -99,4 +106,53 @@ public class ReactiveSecurityContextHolderTests {
 		// @formatter:on
 	}
 
+	@Test
+	public void getContextWhenThreadFactoryIsPlatformThenPropagated() {
+		verifySecurityContextIsPropagated(Executors.defaultThreadFactory());
+	}
+
+	@Test
+	@DisabledOnJre(JRE.JAVA_17)
+	public void getContextWhenThreadFactoryIsVirtualThenPropagated() {
+		verifySecurityContextIsPropagated(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
+	}
+
+	private static void verifySecurityContextIsPropagated(ThreadFactory threadFactory) {
+		Authentication authentication = new TestingAuthenticationToken("user", null);
+
+		// @formatter:off
+		Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
+				.map(SecurityContext::getAuthentication)
+				.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
+				.subscribeOn(Schedulers.newSingle(threadFactory));
+		// @formatter:on
+
+		StepVerifier.create(publisher).expectNext(authentication).verifyComplete();
+	}
+
+	@Test
+	public void clearContextWhenThreadFactoryIsPlatformThenCleared() {
+		verifySecurityContextIsCleared(Executors.defaultThreadFactory());
+	}
+
+	@Test
+	@DisabledOnJre(JRE.JAVA_17)
+	public void clearContextWhenThreadFactoryIsVirtualThenCleared() {
+		verifySecurityContextIsCleared(new VirtualThreadTaskExecutor().getVirtualThreadFactory());
+	}
+
+	private static void verifySecurityContextIsCleared(ThreadFactory threadFactory) {
+		Authentication authentication = new TestingAuthenticationToken("user", null);
+
+		// @formatter:off
+		Mono<Authentication> publisher = ReactiveSecurityContextHolder.getContext()
+				.map(SecurityContext::getAuthentication)
+				.contextWrite(ReactiveSecurityContextHolder.clearContext())
+				.contextWrite((context) -> ReactiveSecurityContextHolder.withAuthentication(authentication))
+				.subscribeOn(Schedulers.newSingle(threadFactory));
+		// @formatter:on
+
+		StepVerifier.create(publisher).verifyComplete();
+	}
+
 }

+ 35 - 1
web/src/test/java/org/springframework/security/web/server/context/ReactorContextWebFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
 package org.springframework.security.web.server.context;
 
 import java.util.List;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
 
 import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledOnJre;
+import org.junit.jupiter.api.condition.JRE;
 import org.junit.jupiter.api.extension.ExtendWith;
 import org.mockito.Mock;
 import org.mockito.junit.jupiter.MockitoExtension;
 import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
 import reactor.test.StepVerifier;
 import reactor.test.publisher.TestPublisher;
 import reactor.util.context.Context;
 
+import org.springframework.core.task.VirtualThreadTaskExecutor;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.core.Authentication;
@@ -117,4 +123,32 @@ public class ReactorContextWebFilterTests {
 		StepVerifier.create(filter).expectAccessibleContext().hasKey(contextKey).then().verifyComplete();
 	}
 
+	@Test
+	public void filterWhenThreadFactoryIsPlatformThenSecurityContextLoaded() {
+		ThreadFactory threadFactory = Executors.defaultThreadFactory();
+		assertSecurityContextLoaded(threadFactory);
+	}
+
+	@Test
+	@DisabledOnJre(JRE.JAVA_17)
+	public void filterWhenThreadFactoryIsVirtualThenSecurityContextLoaded() {
+		ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
+		assertSecurityContextLoaded(threadFactory);
+	}
+
+	private void assertSecurityContextLoaded(ThreadFactory threadFactory) {
+		SecurityContextImpl context = new SecurityContextImpl(this.principal);
+		given(this.repository.load(any())).willReturn(Mono.just(context));
+		// @formatter:off
+		WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
+				.subscribeOn(Schedulers.newSingle(threadFactory));
+		WebFilter assertSecurityContext = (exchange, chain) -> ReactiveSecurityContextHolder.getContext()
+				.map(SecurityContext::getAuthentication)
+				.doOnSuccess((authentication) -> assertThat(authentication).isSameAs(this.principal))
+				.then(chain.filter(exchange));
+		// @formatter:on
+		this.handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter, assertSecurityContext);
+		this.handler.exchange(this.exchange);
+	}
+
 }

+ 36 - 1
web/src/test/java/org/springframework/security/web/server/context/SecurityContextServerWebExchangeWebFilterTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2017 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,17 +17,25 @@
 package org.springframework.security.web.server.context;
 
 import java.util.Collections;
+import java.util.concurrent.Executors;
+import java.util.concurrent.ThreadFactory;
 
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.condition.DisabledOnJre;
+import org.junit.jupiter.api.condition.JRE;
 import reactor.core.publisher.Mono;
+import reactor.core.scheduler.Schedulers;
 import reactor.test.StepVerifier;
 
+import org.springframework.core.task.VirtualThreadTaskExecutor;
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
+import org.springframework.security.test.web.reactive.server.WebTestHandler;
 import org.springframework.web.server.ServerWebExchange;
+import org.springframework.web.server.WebFilter;
 import org.springframework.web.server.handler.DefaultWebFilterChain;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -80,4 +88,31 @@ public class SecurityContextServerWebExchangeWebFilterTests {
 		StepVerifier.create(result).verifyComplete();
 	}
 
+	@Test
+	public void filterWhenThreadFactoryIsPlatformThenContextPopulated() {
+		ThreadFactory threadFactory = Executors.defaultThreadFactory();
+		assertPrincipalPopulated(threadFactory);
+	}
+
+	@Test
+	@DisabledOnJre(JRE.JAVA_17)
+	public void filterWhenThreadFactoryIsVirtualThenContextPopulated() {
+		ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
+		assertPrincipalPopulated(threadFactory);
+	}
+
+	private void assertPrincipalPopulated(ThreadFactory threadFactory) {
+		// @formatter:off
+		WebFilter subscribeOnThreadFactory = (exchange, chain) -> chain.filter(exchange)
+				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(this.principal))
+				.subscribeOn(Schedulers.newSingle(threadFactory));
+		WebFilter assertPrincipal = (exchange, chain) -> exchange.getPrincipal()
+				.doOnSuccess((principal) -> assertThat(principal).isSameAs(this.principal))
+				.then(chain.filter(exchange));
+		// @formatter:on
+		WebTestHandler handler = WebTestHandler.bindToWebFilters(subscribeOnThreadFactory, this.filter,
+				assertPrincipal);
+		handler.exchange(this.exchange);
+	}
+
 }