Selaa lähdekoodia

Use SecurityContextHolderStrategy for Context Propagation

Issue gh-11060
Josh Cummings 3 vuotta sitten
vanhempi
commit
38cb6c3172

+ 16 - 3
core/src/main/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextSupport.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -19,6 +19,9 @@ package org.springframework.security.concurrent;
 import java.util.concurrent.Callable;
 
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.util.Assert;
 
 /**
  * An internal support class that wraps {@link Callable} with
@@ -30,6 +33,9 @@ import org.springframework.security.core.context.SecurityContext;
  */
 abstract class AbstractDelegatingSecurityContextSupport {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private final SecurityContext securityContext;
 
 	/**
@@ -44,12 +50,19 @@ abstract class AbstractDelegatingSecurityContextSupport {
 		this.securityContext = securityContext;
 	}
 
+	void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	protected final Runnable wrap(Runnable delegate) {
-		return DelegatingSecurityContextRunnable.create(delegate, this.securityContext);
+		return DelegatingSecurityContextRunnable.create(delegate, this.securityContext,
+				this.securityContextHolderStrategy);
 	}
 
 	protected final <T> Callable<T> wrap(Callable<T> delegate) {
-		return DelegatingSecurityContextCallable.create(delegate, this.securityContext);
+		return DelegatingSecurityContextCallable.create(delegate, this.securityContext,
+				this.securityContextHolderStrategy);
 	}
 
 }

+ 49 - 12
core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextCallable.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@ import java.util.concurrent.Callable;
 
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.util.Assert;
 
 /**
@@ -40,10 +41,15 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
 
 	private final Callable<V> delegate;
 
+	private final boolean explicitSecurityContextProvided;
+
 	/**
 	 * The {@link SecurityContext} that the delegate {@link Callable} will be ran as.
 	 */
-	private final SecurityContext delegateSecurityContext;
+	private SecurityContext delegateSecurityContext;
+
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
 
 	/**
 	 * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to
@@ -60,10 +66,7 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
 	 * {@link Callable}. Cannot be null.
 	 */
 	public DelegatingSecurityContextCallable(Callable<V> delegate, SecurityContext securityContext) {
-		Assert.notNull(delegate, "delegate cannot be null");
-		Assert.notNull(securityContext, "securityContext cannot be null");
-		this.delegate = delegate;
-		this.delegateSecurityContext = securityContext;
+		this(delegate, securityContext, true);
 	}
 
 	/**
@@ -73,28 +76,51 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
 	 * {@link SecurityContext}. Cannot be null.
 	 */
 	public DelegatingSecurityContextCallable(Callable<V> delegate) {
-		this(delegate, SecurityContextHolder.getContext());
+		this(delegate, SecurityContextHolder.getContext(), false);
+	}
+
+	private DelegatingSecurityContextCallable(Callable<V> delegate, SecurityContext securityContext,
+			boolean explicitSecurityContextProvided) {
+		Assert.notNull(delegate, "delegate cannot be null");
+		Assert.notNull(securityContext, "securityContext cannot be null");
+		this.delegate = delegate;
+		this.delegateSecurityContext = securityContext;
+		this.explicitSecurityContextProvided = explicitSecurityContextProvided;
 	}
 
 	@Override
 	public V call() throws Exception {
-		this.originalSecurityContext = SecurityContextHolder.getContext();
+		this.originalSecurityContext = this.securityContextHolderStrategy.getContext();
 		try {
-			SecurityContextHolder.setContext(this.delegateSecurityContext);
+			this.securityContextHolderStrategy.setContext(this.delegateSecurityContext);
 			return this.delegate.call();
 		}
 		finally {
-			SecurityContext emptyContext = SecurityContextHolder.createEmptyContext();
+			SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext();
 			if (emptyContext.equals(this.originalSecurityContext)) {
-				SecurityContextHolder.clearContext();
+				this.securityContextHolderStrategy.clearContext();
 			}
 			else {
-				SecurityContextHolder.setContext(this.originalSecurityContext);
+				this.securityContextHolderStrategy.setContext(this.originalSecurityContext);
 			}
 			this.originalSecurityContext = null;
 		}
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+		if (!this.explicitSecurityContextProvided) {
+			this.delegateSecurityContext = securityContextHolderStrategy.getContext();
+		}
+	}
+
 	@Override
 	public String toString() {
 		return this.delegate.toString();
@@ -116,4 +142,15 @@ public final class DelegatingSecurityContextCallable<V> implements Callable<V> {
 				: new DelegatingSecurityContextCallable<>(delegate);
 	}
 
+	static <V> Callable<V> create(Callable<V> delegate, SecurityContext securityContext,
+			SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(delegate, "delegate cannot be null");
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		DelegatingSecurityContextCallable<V> callable = (securityContext != null)
+				? new DelegatingSecurityContextCallable<>(delegate, securityContext)
+				: new DelegatingSecurityContextCallable<>(delegate);
+		callable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
+		return callable;
+	}
+
 }

+ 11 - 0
core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextExecutor.java

@@ -20,6 +20,7 @@ import java.util.concurrent.Executor;
 
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.util.Assert;
 
 /**
@@ -66,4 +67,14 @@ public class DelegatingSecurityContextExecutor extends AbstractDelegatingSecurit
 		return this.delegate;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		super.setSecurityContextHolderStrategy(securityContextHolderStrategy);
+	}
+
 }

+ 49 - 12
core/src/main/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnable.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@ package org.springframework.security.concurrent;
 
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.util.Assert;
 
 /**
@@ -38,10 +39,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
 
 	private final Runnable delegate;
 
+	private final boolean explicitSecurityContextProvided;
+
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	/**
 	 * The {@link SecurityContext} that the delegate {@link Runnable} will be ran as.
 	 */
-	private final SecurityContext delegateSecurityContext;
+	private SecurityContext delegateSecurityContext;
 
 	/**
 	 * The {@link SecurityContext} that was on the {@link SecurityContextHolder} prior to
@@ -58,10 +64,7 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
 	 * {@link Runnable}. Cannot be null.
 	 */
 	public DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext) {
-		Assert.notNull(delegate, "delegate cannot be null");
-		Assert.notNull(securityContext, "securityContext cannot be null");
-		this.delegate = delegate;
-		this.delegateSecurityContext = securityContext;
+		this(delegate, securityContext, true);
 	}
 
 	/**
@@ -71,28 +74,51 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
 	 * {@link SecurityContext}. Cannot be null.
 	 */
 	public DelegatingSecurityContextRunnable(Runnable delegate) {
-		this(delegate, SecurityContextHolder.getContext());
+		this(delegate, SecurityContextHolder.getContext(), false);
+	}
+
+	private DelegatingSecurityContextRunnable(Runnable delegate, SecurityContext securityContext,
+			boolean explicitSecurityContextProvided) {
+		Assert.notNull(delegate, "delegate cannot be null");
+		Assert.notNull(securityContext, "securityContext cannot be null");
+		this.delegate = delegate;
+		this.delegateSecurityContext = securityContext;
+		this.explicitSecurityContextProvided = explicitSecurityContextProvided;
 	}
 
 	@Override
 	public void run() {
-		this.originalSecurityContext = SecurityContextHolder.getContext();
+		this.originalSecurityContext = this.securityContextHolderStrategy.getContext();
 		try {
-			SecurityContextHolder.setContext(this.delegateSecurityContext);
+			this.securityContextHolderStrategy.setContext(this.delegateSecurityContext);
 			this.delegate.run();
 		}
 		finally {
-			SecurityContext emptyContext = SecurityContextHolder.createEmptyContext();
+			SecurityContext emptyContext = this.securityContextHolderStrategy.createEmptyContext();
 			if (emptyContext.equals(this.originalSecurityContext)) {
-				SecurityContextHolder.clearContext();
+				this.securityContextHolderStrategy.clearContext();
 			}
 			else {
-				SecurityContextHolder.setContext(this.originalSecurityContext);
+				this.securityContextHolderStrategy.setContext(this.originalSecurityContext);
 			}
 			this.originalSecurityContext = null;
 		}
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+		if (!this.explicitSecurityContextProvided) {
+			this.delegateSecurityContext = this.securityContextHolderStrategy.getContext();
+		}
+	}
+
 	@Override
 	public String toString() {
 		return this.delegate.toString();
@@ -114,4 +140,15 @@ public final class DelegatingSecurityContextRunnable implements Runnable {
 				: new DelegatingSecurityContextRunnable(delegate);
 	}
 
+	static Runnable create(Runnable delegate, SecurityContext securityContext,
+			SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(delegate, "delegate cannot be  null");
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		DelegatingSecurityContextRunnable runnable = (securityContext != null)
+				? new DelegatingSecurityContextRunnable(delegate, securityContext)
+				: new DelegatingSecurityContextRunnable(delegate);
+		runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
+		return runnable;
+	}
+
 }

+ 10 - 8
core/src/test/java/org/springframework/security/concurrent/AbstractDelegatingSecurityContextTestSupport.java

@@ -30,7 +30,9 @@ import org.mockito.junit.jupiter.MockitoExtension;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 
+import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
+import static org.mockito.ArgumentMatchers.isNull;
 
 /**
  * Abstract base class for testing classes that extend
@@ -71,18 +73,18 @@ public abstract class AbstractDelegatingSecurityContextTestSupport {
 	protected MockedStatic<DelegatingSecurityContextRunnable> delegatingSecurityContextRunnable;
 
 	public final void explicitSecurityContextSetup() throws Exception {
-		this.delegatingSecurityContextCallable.when(
-				() -> DelegatingSecurityContextCallable.create(eq(this.callable), this.securityContextCaptor.capture()))
-				.thenReturn(this.wrappedCallable);
-		this.delegatingSecurityContextRunnable.when(
-				() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), this.securityContextCaptor.capture()))
-				.thenReturn(this.wrappedRunnable);
+		this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(eq(this.callable),
+				this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedCallable);
+		this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable),
+				this.securityContextCaptor.capture(), any())).thenReturn(this.wrappedRunnable);
 	}
 
 	public final void currentSecurityContextSetup() throws Exception {
-		this.delegatingSecurityContextCallable.when(() -> DelegatingSecurityContextCallable.create(this.callable, null))
+		this.delegatingSecurityContextCallable
+				.when(() -> DelegatingSecurityContextCallable.create(eq(this.callable), isNull(), any()))
 				.thenReturn(this.wrappedCallable);
-		this.delegatingSecurityContextRunnable.when(() -> DelegatingSecurityContextRunnable.create(this.runnable, null))
+		this.delegatingSecurityContextRunnable
+				.when(() -> DelegatingSecurityContextRunnable.create(eq(this.runnable), isNull(), any()))
 				.thenReturn(this.wrappedRunnable);
 	}
 

+ 24 - 1
core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextCallableTests.java

@@ -30,12 +30,16 @@ import org.mockito.internal.stubbing.answers.Returns;
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.junit.jupiter.MockitoExtension;
 
+import org.springframework.security.core.context.MockSecurityContextHolderStrategy;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 
 /**
@@ -68,10 +72,15 @@ public class DelegatingSecurityContextCallableTests {
 	}
 
 	private void givenDelegateCallWillAnswerWithCurrentSecurityContext() throws Exception {
+		givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolder.getContextHolderStrategy());
+	}
+
+	private void givenDelegateCallWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy)
+			throws Exception {
 		given(this.delegate.call()).willAnswer(new Returns(this.callableResult) {
 			@Override
 			public Object answer(InvocationOnMock invocation) throws Throwable {
-				assertThat(SecurityContextHolder.getContext())
+				assertThat(strategy.getContext())
 						.isEqualTo(DelegatingSecurityContextCallableTests.this.securityContext);
 				return super.answer(invocation);
 			}
@@ -122,6 +131,20 @@ public class DelegatingSecurityContextCallableTests {
 		assertWrapped(this.callable);
 	}
 
+	@Test
+	public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception {
+		SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy());
+		givenDelegateCallWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy);
+		securityContextHolderStrategy.setContext(this.securityContext);
+		DelegatingSecurityContextCallable<Object> callable = new DelegatingSecurityContextCallable<>(this.delegate);
+		callable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
+		this.callable = callable;
+		// ensure callable is what sets up the SecurityContextHolder
+		securityContextHolderStrategy.clearContext();
+		assertWrapped(this.callable);
+		verify(securityContextHolderStrategy, atLeastOnce()).getContext();
+	}
+
 	// SEC-3031
 	@Test
 	public void callOnSameThread() throws Exception {

+ 25 - 0
core/src/test/java/org/springframework/security/concurrent/DelegatingSecurityContextRunnableTests.java

@@ -30,12 +30,16 @@ import org.mockito.stubbing.Answer;
 
 import org.springframework.core.task.SyncTaskExecutor;
 import org.springframework.core.task.support.ExecutorServiceAdapter;
+import org.springframework.security.core.context.MockSecurityContextHolderStrategy;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
 import static org.mockito.BDDMockito.willAnswer;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 
 /**
@@ -73,6 +77,13 @@ public class DelegatingSecurityContextRunnableTests {
 		}).given(this.delegate).run();
 	}
 
+	private void givenDelegateRunWillAnswerWithCurrentSecurityContext(SecurityContextHolderStrategy strategy) {
+		willAnswer((Answer<Object>) (invocation) -> {
+			assertThat(strategy.getContext()).isEqualTo(this.securityContext);
+			return null;
+		}).given(this.delegate).run();
+	}
+
 	@AfterEach
 	public void tearDown() {
 		SecurityContextHolder.clearContext();
@@ -117,6 +128,20 @@ public class DelegatingSecurityContextRunnableTests {
 		assertWrapped(this.runnable);
 	}
 
+	@Test
+	public void callDefaultSecurityContextWithCustomSecurityContextHolderStrategy() throws Exception {
+		SecurityContextHolderStrategy securityContextHolderStrategy = spy(new MockSecurityContextHolderStrategy());
+		givenDelegateRunWillAnswerWithCurrentSecurityContext(securityContextHolderStrategy);
+		securityContextHolderStrategy.setContext(this.securityContext);
+		DelegatingSecurityContextRunnable runnable = new DelegatingSecurityContextRunnable(this.delegate);
+		runnable.setSecurityContextHolderStrategy(securityContextHolderStrategy);
+		this.runnable = runnable;
+		// ensure callable is what sets up the SecurityContextHolder
+		securityContextHolderStrategy.clearContext();
+		assertWrapped(this.runnable);
+		verify(securityContextHolderStrategy, atLeastOnce()).getContext();
+	}
+
 	// SEC-3031
 	@Test
 	public void callOnSameThread() throws Exception {

+ 43 - 0
core/src/test/java/org/springframework/security/core/context/MockSecurityContextHolderStrategy.java

@@ -0,0 +1,43 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.core.context;
+
+public class MockSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
+
+	private SecurityContext context;
+
+	@Override
+	public void clearContext() {
+		this.context = null;
+	}
+
+	@Override
+	public SecurityContext getContext() {
+		return this.context;
+	}
+
+	@Override
+	public void setContext(SecurityContext context) {
+		this.context = context;
+	}
+
+	@Override
+	public SecurityContext createEmptyContext() {
+		return new SecurityContextImpl();
+	}
+
+}