2
0
Эх сурвалжийг харах

SecurityContextHolder Deferred SecurityContext

Closes gh-10913
Rob Winch 3 жил өмнө
parent
commit
b6d43e58c0

+ 27 - 7
core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java

@@ -16,6 +16,8 @@
 
 package org.springframework.security.core.context;
 
+import java.util.function.Supplier;
+
 import org.springframework.util.Assert;
 
 /**
@@ -23,11 +25,12 @@ import org.springframework.util.Assert;
  * {@link org.springframework.security.core.context.SecurityContextHolderStrategy}.
  *
  * @author Ben Alex
+ * @author Rob Winch
  * @see java.lang.ThreadLocal
  */
 final class InheritableThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
 
-	private static final ThreadLocal<SecurityContext> contextHolder = new InheritableThreadLocal<>();
+	private static final ThreadLocal<Supplier<SecurityContext>> contextHolder = new InheritableThreadLocal<>();
 
 	@Override
 	public void clearContext() {
@@ -36,18 +39,35 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur
 
 	@Override
 	public SecurityContext getContext() {
-		SecurityContext ctx = contextHolder.get();
-		if (ctx == null) {
-			ctx = createEmptyContext();
-			contextHolder.set(ctx);
+		return getDeferredContext().get();
+	}
+
+	@Override
+	public Supplier<SecurityContext> getDeferredContext() {
+		Supplier<SecurityContext> result = contextHolder.get();
+		if (result == null) {
+			SecurityContext context = createEmptyContext();
+			result = () -> context;
+			contextHolder.set(result);
 		}
-		return ctx;
+		return result;
 	}
 
 	@Override
 	public void setContext(SecurityContext context) {
 		Assert.notNull(context, "Only non-null SecurityContext instances are permitted");
-		contextHolder.set(context);
+		contextHolder.set(() -> context);
+	}
+
+	@Override
+	public void setDeferredContext(Supplier<SecurityContext> deferredContext) {
+		Assert.notNull(deferredContext, "Only non-null Supplier instances are permitted");
+		Supplier<SecurityContext> notNullDeferredContext = () -> {
+			SecurityContext result = deferredContext.get();
+			Assert.notNull(result, "A Supplier<SecurityContext> returned null and is not allowed.");
+			return result;
+		};
+		contextHolder.set(notNullDeferredContext);
 	}
 
 	@Override

+ 22 - 0
core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java

@@ -17,6 +17,7 @@
 package org.springframework.security.core.context;
 
 import java.lang.reflect.Constructor;
+import java.util.function.Supplier;
 
 import org.springframework.util.Assert;
 import org.springframework.util.ReflectionUtils;
@@ -46,6 +47,7 @@ import org.springframework.util.StringUtils;
  * {@link #MODE_GLOBAL} is definitely inappropriate for server use).
  *
  * @author Ben Alex
+ * @author Rob Winch
  *
  */
 public class SecurityContextHolder {
@@ -123,6 +125,16 @@ public class SecurityContextHolder {
 		return strategy.getContext();
 	}
 
+	/**
+	 * Obtains a {@link Supplier} that returns the current context.
+	 * @return a {@link Supplier} that returns the current context (never
+	 * <code>null</code> - create a default implementation if necessary)
+	 * @since 5.8
+	 */
+	public static Supplier<SecurityContext> getDeferredContext() {
+		return strategy.getDeferredContext();
+	}
+
 	/**
 	 * Primarily for troubleshooting purposes, this method shows how many times the class
 	 * has re-initialized its <code>SecurityContextHolderStrategy</code>.
@@ -143,6 +155,16 @@ public class SecurityContextHolder {
 		strategy.setContext(context);
 	}
 
+	/**
+	 * Sets a {@link Supplier} that will return the current context. Implementations can
+	 * override the default to avoid invoking {@link Supplier#get()}.
+	 * @param deferredContext a {@link Supplier} that returns the {@link SecurityContext}
+	 * @since 5.8
+	 */
+	public static void setDeferredContext(Supplier<SecurityContext> deferredContext) {
+		strategy.setDeferredContext(deferredContext);
+	}
+
 	/**
 	 * Changes the preferred strategy. Do <em>NOT</em> call this method more than once for
 	 * a given JVM, as it will re-initialize the strategy and adversely affect any

+ 23 - 0
core/src/main/java/org/springframework/security/core/context/SecurityContextHolderStrategy.java

@@ -16,6 +16,8 @@
 
 package org.springframework.security.core.context;
 
+import java.util.function.Supplier;
+
 /**
  * A strategy for storing security context information against a thread.
  *
@@ -23,6 +25,7 @@ package org.springframework.security.core.context;
  * The preferred strategy is loaded by {@link SecurityContextHolder}.
  *
  * @author Ben Alex
+ * @author Rob Winch
  */
 public interface SecurityContextHolderStrategy {
 
@@ -38,6 +41,16 @@ public interface SecurityContextHolderStrategy {
 	 */
 	SecurityContext getContext();
 
+	/**
+	 * Obtains a {@link Supplier} that returns the current context.
+	 * @return a {@link Supplier} that returns the current context (never
+	 * <code>null</code> - create a default implementation if necessary)
+	 * @since 5.8
+	 */
+	default Supplier<SecurityContext> getDeferredContext() {
+		return () -> getContext();
+	}
+
 	/**
 	 * Sets the current context.
 	 * @param context to the new argument (should never be <code>null</code>, although
@@ -46,6 +59,16 @@ public interface SecurityContextHolderStrategy {
 	 */
 	void setContext(SecurityContext context);
 
+	/**
+	 * Sets a {@link Supplier} that will return the current context. Implementations can
+	 * override the default to avoid invoking {@link Supplier#get()}.
+	 * @param deferredContext a {@link Supplier} that returns the {@link SecurityContext}
+	 * @since 5.8
+	 */
+	default void setDeferredContext(Supplier<SecurityContext> deferredContext) {
+		setContext(deferredContext.get());
+	}
+
 	/**
 	 * Creates a new, empty context implementation, for use by
 	 * <tt>SecurityContextRepository</tt> implementations, when creating a new context for

+ 27 - 7
core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java

@@ -16,6 +16,8 @@
 
 package org.springframework.security.core.context;
 
+import java.util.function.Supplier;
+
 import org.springframework.util.Assert;
 
 /**
@@ -23,12 +25,13 @@ import org.springframework.util.Assert;
  * {@link SecurityContextHolderStrategy}.
  *
  * @author Ben Alex
+ * @author Rob Winch
  * @see java.lang.ThreadLocal
  * @see org.springframework.security.core.context.web.SecurityContextPersistenceFilter
  */
 final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
 
-	private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>();
+	private static final ThreadLocal<Supplier<SecurityContext>> contextHolder = new ThreadLocal<>();
 
 	@Override
 	public void clearContext() {
@@ -37,18 +40,35 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH
 
 	@Override
 	public SecurityContext getContext() {
-		SecurityContext ctx = contextHolder.get();
-		if (ctx == null) {
-			ctx = createEmptyContext();
-			contextHolder.set(ctx);
+		return getDeferredContext().get();
+	}
+
+	@Override
+	public Supplier<SecurityContext> getDeferredContext() {
+		Supplier<SecurityContext> result = contextHolder.get();
+		if (result == null) {
+			SecurityContext context = createEmptyContext();
+			result = () -> context;
+			contextHolder.set(result);
 		}
-		return ctx;
+		return result;
 	}
 
 	@Override
 	public void setContext(SecurityContext context) {
 		Assert.notNull(context, "Only non-null SecurityContext instances are permitted");
-		contextHolder.set(context);
+		contextHolder.set(() -> context);
+	}
+
+	@Override
+	public void setDeferredContext(Supplier<SecurityContext> deferredContext) {
+		Assert.notNull(deferredContext, "Only non-null Supplier instances are permitted");
+		Supplier<SecurityContext> notNullDeferredContext = () -> {
+			SecurityContext result = deferredContext.get();
+			Assert.notNull(result, "A Supplier<SecurityContext> returned null and is not allowed.");
+			return result;
+		};
+		contextHolder.set(notNullDeferredContext);
 	}
 
 	@Override

+ 84 - 0
core/src/test/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategyTests.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.core.context;
+
+import java.util.function.Supplier;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.security.core.Authentication;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+class InheritableThreadLocalSecurityContextHolderStrategyTests {
+
+	InheritableThreadLocalSecurityContextHolderStrategy strategy = new InheritableThreadLocalSecurityContextHolderStrategy();
+
+	@AfterEach
+	void clearContext() {
+		this.strategy.clearContext();
+	}
+
+	@Test
+	void deferredNotInvoked() {
+		Supplier<SecurityContext> deferredContext = mock(Supplier.class);
+		this.strategy.setDeferredContext(deferredContext);
+		verifyNoInteractions(deferredContext);
+	}
+
+	@Test
+	void deferredContext() {
+		Authentication authentication = mock(Authentication.class);
+		Supplier<SecurityContext> deferredContext = () -> new SecurityContextImpl(authentication);
+		this.strategy.setDeferredContext(deferredContext);
+		assertThat(this.strategy.getDeferredContext().get()).isEqualTo(deferredContext.get());
+		assertThat(this.strategy.getContext()).isEqualTo(deferredContext.get());
+	}
+
+	@Test
+	void deferredContextValidates() {
+		this.strategy.setDeferredContext(() -> null);
+		Supplier<SecurityContext> deferredContext = this.strategy.getDeferredContext();
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> deferredContext.get());
+	}
+
+	@Test
+	void context() {
+		Authentication authentication = mock(Authentication.class);
+		SecurityContext context = new SecurityContextImpl(authentication);
+		this.strategy.setContext(context);
+		assertThat(this.strategy.getContext()).isEqualTo(context);
+		assertThat(this.strategy.getDeferredContext().get()).isEqualTo(context);
+	}
+
+	@Test
+	void contextValidates() {
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> this.strategy.setContext(null));
+	}
+
+	@Test
+	void getContextWhenEmptyThenReturnsSameInstance() {
+		Authentication authentication = mock(Authentication.class);
+		this.strategy.getContext().setAuthentication(authentication);
+		assertThat(this.strategy.getContext().getAuthentication()).isEqualTo(authentication);
+	}
+
+}

+ 84 - 0
core/src/test/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategyTests.java

@@ -0,0 +1,84 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.core.context;
+
+import java.util.function.Supplier;
+
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.security.core.Authentication;
+
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+class ThreadLocalSecurityContextHolderStrategyTests {
+
+	ThreadLocalSecurityContextHolderStrategy strategy = new ThreadLocalSecurityContextHolderStrategy();
+
+	@AfterEach
+	void clearContext() {
+		this.strategy.clearContext();
+	}
+
+	@Test
+	void deferredNotInvoked() {
+		Supplier<SecurityContext> deferredContext = mock(Supplier.class);
+		this.strategy.setDeferredContext(deferredContext);
+		verifyNoInteractions(deferredContext);
+	}
+
+	@Test
+	void deferredContext() {
+		Authentication authentication = mock(Authentication.class);
+		Supplier<SecurityContext> deferredContext = () -> new SecurityContextImpl(authentication);
+		this.strategy.setDeferredContext(deferredContext);
+		assertThat(this.strategy.getDeferredContext().get()).isEqualTo(deferredContext.get());
+		assertThat(this.strategy.getContext()).isEqualTo(deferredContext.get());
+	}
+
+	@Test
+	void deferredContextValidates() {
+		this.strategy.setDeferredContext(() -> null);
+		Supplier<SecurityContext> deferredContext = this.strategy.getDeferredContext();
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> deferredContext.get());
+	}
+
+	@Test
+	void context() {
+		Authentication authentication = mock(Authentication.class);
+		SecurityContext context = new SecurityContextImpl(authentication);
+		this.strategy.setContext(context);
+		assertThat(this.strategy.getContext()).isEqualTo(context);
+		assertThat(this.strategy.getDeferredContext().get()).isEqualTo(context);
+	}
+
+	@Test
+	void contextValidates() {
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> this.strategy.setContext(null));
+	}
+
+	@Test
+	void getContextWhenEmptyThenReturnsSameInstance() {
+		Authentication authentication = mock(Authentication.class);
+		this.strategy.getContext().setAuthentication(authentication);
+		assertThat(this.strategy.getContext().getAuthentication()).isEqualTo(authentication);
+	}
+
+}

+ 3 - 2
web/src/main/java/org/springframework/security/web/context/SecurityContextHolderFilter.java

@@ -17,6 +17,7 @@
 package org.springframework.security.web.context;
 
 import java.io.IOException;
+import java.util.function.Supplier;
 
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.ServletException;
@@ -62,9 +63,9 @@ public class SecurityContextHolderFilter extends OncePerRequestFilter {
 	@Override
 	protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
 			throws ServletException, IOException {
-		SecurityContext securityContext = this.securityContextRepository.loadContext(request).get();
+		Supplier<SecurityContext> deferredContext = this.securityContextRepository.loadContext(request);
 		try {
-			this.securityContextHolderStrategy.setContext(securityContext);
+			this.securityContextHolderStrategy.setDeferredContext(deferredContext);
 			filterChain.doFilter(request, response);
 		}
 		finally {

+ 5 - 1
web/src/test/java/org/springframework/security/web/context/SecurityContextHolderFilterTests.java

@@ -16,6 +16,8 @@
 
 package org.springframework.security.web.context;
 
+import java.util.function.Supplier;
+
 import jakarta.servlet.FilterChain;
 import jakarta.servlet.http.HttpServletRequest;
 import jakarta.servlet.http.HttpServletResponse;
@@ -93,7 +95,9 @@ class SecurityContextHolderFilterTests {
 		this.filter.setSecurityContextHolderStrategy(this.strategy);
 		this.filter.doFilter(this.request, this.response, filterChain);
 
-		verify(this.strategy).setContext(expectedContext);
+		ArgumentCaptor<Supplier<SecurityContext>> deferredContextArg = ArgumentCaptor.forClass(Supplier.class);
+		verify(this.strategy).setDeferredContext(deferredContextArg.capture());
+		assertThat(deferredContextArg.getValue().get()).isEqualTo(expectedContext);
 		verify(this.strategy).clearContext();
 	}