|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2002-2022 the original author or authors.
|
|
|
+ * Copyright 2002-2023 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -17,14 +17,20 @@
|
|
|
package org.springframework.security.config.annotation.web.configuration;
|
|
|
|
|
|
import java.net.URI;
|
|
|
+import java.util.Arrays;
|
|
|
import java.util.HashMap;
|
|
|
import java.util.Map;
|
|
|
+import java.util.concurrent.Executors;
|
|
|
+import java.util.concurrent.Future;
|
|
|
+import java.util.concurrent.ThreadFactory;
|
|
|
|
|
|
import jakarta.servlet.http.HttpServletRequest;
|
|
|
import jakarta.servlet.http.HttpServletResponse;
|
|
|
import org.junit.jupiter.api.AfterEach;
|
|
|
import org.junit.jupiter.api.BeforeEach;
|
|
|
import org.junit.jupiter.api.Test;
|
|
|
+import org.junit.jupiter.api.condition.DisabledOnJre;
|
|
|
+import org.junit.jupiter.api.condition.JRE;
|
|
|
import org.junit.jupiter.api.extension.ExtendWith;
|
|
|
import reactor.core.CoreSubscriber;
|
|
|
import reactor.core.publisher.BaseSubscriber;
|
|
@@ -35,6 +41,8 @@ import reactor.util.context.Context;
|
|
|
|
|
|
import org.springframework.context.annotation.Bean;
|
|
|
import org.springframework.context.annotation.Configuration;
|
|
|
+import org.springframework.core.task.SimpleAsyncTaskExecutor;
|
|
|
+import org.springframework.core.task.VirtualThreadTaskExecutor;
|
|
|
import org.springframework.http.HttpMethod;
|
|
|
import org.springframework.http.HttpStatus;
|
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
@@ -46,6 +54,7 @@ import org.springframework.security.config.annotation.web.configuration.Security
|
|
|
import org.springframework.security.config.test.SpringTestContext;
|
|
|
import org.springframework.security.config.test.SpringTestContextExtension;
|
|
|
import org.springframework.security.core.Authentication;
|
|
|
+import org.springframework.security.core.context.SecurityContext;
|
|
|
import org.springframework.security.core.context.SecurityContextHolder;
|
|
|
import org.springframework.security.core.context.SecurityContextHolderStrategy;
|
|
|
import org.springframework.security.oauth2.client.web.reactive.function.client.MockExchangeFunction;
|
|
@@ -271,6 +280,58 @@ public class SecurityReactorContextConfigurationTests {
|
|
|
verify(strategy, times(2)).getContext();
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void createPublisherWhenThreadFactoryIsPlatformThenSecurityContextAttributesAvailable() throws Exception {
|
|
|
+ this.spring.register(SecurityConfig.class).autowire();
|
|
|
+
|
|
|
+ ThreadFactory threadFactory = Executors.defaultThreadFactory();
|
|
|
+ assertContextAttributesAvailable(threadFactory);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ @DisabledOnJre(JRE.JAVA_17)
|
|
|
+ public void createPublisherWhenThreadFactoryIsVirtualThenSecurityContextAttributesAvailable() throws Exception {
|
|
|
+ this.spring.register(SecurityConfig.class).autowire();
|
|
|
+
|
|
|
+ ThreadFactory threadFactory = new VirtualThreadTaskExecutor().getVirtualThreadFactory();
|
|
|
+ assertContextAttributesAvailable(threadFactory);
|
|
|
+ }
|
|
|
+
|
|
|
+ private void assertContextAttributesAvailable(ThreadFactory threadFactory) throws Exception {
|
|
|
+ Map<Object, Object> expectedContextAttributes = new HashMap<>();
|
|
|
+ expectedContextAttributes.put(HttpServletRequest.class, this.servletRequest);
|
|
|
+ expectedContextAttributes.put(HttpServletResponse.class, this.servletResponse);
|
|
|
+ expectedContextAttributes.put(Authentication.class, this.authentication);
|
|
|
+
|
|
|
+ try (SimpleAsyncTaskExecutor taskExecutor = new SimpleAsyncTaskExecutor(threadFactory)) {
|
|
|
+ Future<Map<Object, Object>> future = taskExecutor.submit(this::propagateRequestAttributes);
|
|
|
+ assertThat(future.get()).isEqualTo(expectedContextAttributes);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ private Map<Object, Object> propagateRequestAttributes() {
|
|
|
+ RequestAttributes requestAttributes = new ServletRequestAttributes(this.servletRequest, this.servletResponse);
|
|
|
+ RequestContextHolder.setRequestAttributes(requestAttributes);
|
|
|
+
|
|
|
+ SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
|
|
|
+ securityContext.setAuthentication(this.authentication);
|
|
|
+ SecurityContextHolder.setContext(securityContext);
|
|
|
+
|
|
|
+ // @formatter:off
|
|
|
+ return Mono.deferContextual(Mono::just)
|
|
|
+ .filter((ctx) -> ctx.hasKey(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
|
|
|
+ .map((ctx) -> ctx.<Map<Object, Object>>get(SecurityReactorContextSubscriber.SECURITY_CONTEXT_ATTRIBUTES))
|
|
|
+ .map((attributes) -> {
|
|
|
+ Map<Object, Object> map = new HashMap<>();
|
|
|
+ // Copy over items from lazily loaded map
|
|
|
+ Arrays.asList(HttpServletRequest.class, HttpServletResponse.class, Authentication.class)
|
|
|
+ .forEach((key) -> map.put(key, attributes.get(key)));
|
|
|
+ return map;
|
|
|
+ })
|
|
|
+ .block();
|
|
|
+ // @formatter:on
|
|
|
+ }
|
|
|
+
|
|
|
@Configuration
|
|
|
@EnableWebSecurity
|
|
|
static class SecurityConfig {
|