Просмотр исходного кода

Replace SecurityContextHolder#addListener

Closes gh-10226
Josh Cummings 4 лет назад
Родитель
Сommit
3e87ef84ae

+ 0 - 4
core/src/main/java/org/springframework/security/core/context/GlobalSecurityContextHolderStrategy.java

@@ -31,10 +31,6 @@ final class GlobalSecurityContextHolderStrategy implements SecurityContextHolder
 
 	private static SecurityContext contextHolder;
 
-	SecurityContext peek() {
-		return contextHolder;
-	}
-
 	@Override
 	public void clearContext() {
 		contextHolder = null;

+ 0 - 4
core/src/main/java/org/springframework/security/core/context/InheritableThreadLocalSecurityContextHolderStrategy.java

@@ -29,10 +29,6 @@ final class InheritableThreadLocalSecurityContextHolderStrategy implements Secur
 
 	private static final ThreadLocal<SecurityContext> contextHolder = new InheritableThreadLocal<>();
 
-	SecurityContext peek() {
-		return contextHolder.get();
-	}
-
 	@Override
 	public void clearContext() {
 		contextHolder.remove();

+ 93 - 36
core/src/main/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategy.java

@@ -16,73 +16,130 @@
 
 package org.springframework.security.core.context;
 
-import java.util.List;
-import java.util.concurrent.CopyOnWriteArrayList;
-import java.util.function.BiConsumer;
-import java.util.function.Supplier;
+import java.util.Arrays;
+import java.util.Collection;
 
-final class ListeningSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
+import org.springframework.util.Assert;
 
-	private static final BiConsumer<SecurityContext, SecurityContext> NULL_PUBLISHER = (previous, current) -> {
-	};
+/**
+ * An API for notifying when the {@link SecurityContext} changes.
+ *
+ * Note that this does not notify when the underlying authentication changes. To get
+ * notified about authentication changes, ensure that you are using {@link #setContext}
+ * when changing the authentication like so:
+ *
+ * <pre>
+ *	SecurityContext context = SecurityContextHolder.createEmptyContext();
+ *	context.setAuthentication(authentication);
+ *	SecurityContextHolder.setContext(context);
+ * </pre>
+ *
+ * To add a listener to the existing {@link SecurityContextHolder}, you can do:
+ *
+ * <pre>
+ *  SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy();
+ *  SecurityContextChangedListener listener = new YourListener();
+ *  SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(original, listener);
+ *  SecurityContextHolder.setContextHolderStrategy(strategy);
+ * </pre>
+ *
+ * NOTE: Any object that you supply to the {@link SecurityContextHolder} is now part of
+ * the static context and as such will not get garbage collected. To remove the reference,
+ * {@link SecurityContextHolder#setContextHolderStrategy reset the strategy} like so:
+ *
+ * <pre>
+ *   SecurityContextHolder.setContextHolderStrategy(original);
+ * </pre>
+ *
+ * This will then allow {@code YourListener} and its members to be garbage collected.
+ *
+ * @author Josh Cummings
+ * @since 5.6
+ */
+public final class ListeningSecurityContextHolderStrategy implements SecurityContextHolderStrategy {
 
-	private final Supplier<SecurityContext> peek;
+	private final Collection<SecurityContextChangedListener> listeners;
 
 	private final SecurityContextHolderStrategy delegate;
 
-	private final SecurityContextEventPublisher base = new SecurityContextEventPublisher();
-
-	private BiConsumer<SecurityContext, SecurityContext> publisher = NULL_PUBLISHER;
+	/**
+	 * Construct a {@link ListeningSecurityContextHolderStrategy}
+	 * @param listeners the listeners that should be notified when the
+	 * {@link SecurityContext} is {@link #setContext(SecurityContext) set} or
+	 * {@link #clearContext() cleared}
+	 * @param delegate the underlying {@link SecurityContextHolderStrategy}
+	 */
+	public ListeningSecurityContextHolderStrategy(SecurityContextHolderStrategy delegate,
+			Collection<SecurityContextChangedListener> listeners) {
+		Assert.notNull(delegate, "securityContextHolderStrategy cannot be null");
+		Assert.notNull(listeners, "securityContextChangedListeners cannot be null");
+		Assert.notEmpty(listeners, "securityContextChangedListeners cannot be empty");
+		Assert.noNullElements(listeners, "securityContextChangedListeners cannot contain null elements");
+		this.delegate = delegate;
+		this.listeners = listeners;
+	}
 
-	ListeningSecurityContextHolderStrategy(Supplier<SecurityContext> peek, SecurityContextHolderStrategy delegate) {
-		this.peek = peek;
+	/**
+	 * Construct a {@link ListeningSecurityContextHolderStrategy}
+	 * @param listeners the listeners that should be notified when the
+	 * {@link SecurityContext} is {@link #setContext(SecurityContext) set} or
+	 * {@link #clearContext() cleared}
+	 * @param delegate the underlying {@link SecurityContextHolderStrategy}
+	 */
+	public ListeningSecurityContextHolderStrategy(SecurityContextHolderStrategy delegate,
+			SecurityContextChangedListener... listeners) {
+		Assert.notNull(delegate, "securityContextHolderStrategy cannot be null");
+		Assert.notNull(listeners, "securityContextChangedListeners cannot be null");
+		Assert.notEmpty(listeners, "securityContextChangedListeners cannot be empty");
+		Assert.noNullElements(listeners, "securityContextChangedListeners cannot contain null elements");
 		this.delegate = delegate;
+		this.listeners = Arrays.asList(listeners);
 	}
 
+	/**
+	 * {@inheritDoc}
+	 */
 	@Override
 	public void clearContext() {
-		SecurityContext from = this.peek.get();
+		SecurityContext from = getContext();
 		this.delegate.clearContext();
-		this.publisher.accept(from, null);
+		publish(from, null);
 	}
 
+	/**
+	 * {@inheritDoc}
+	 */
 	@Override
 	public SecurityContext getContext() {
 		return this.delegate.getContext();
 	}
 
+	/**
+	 * {@inheritDoc}
+	 */
 	@Override
 	public void setContext(SecurityContext context) {
-		SecurityContext from = this.peek.get();
+		SecurityContext from = getContext();
 		this.delegate.setContext(context);
-		this.publisher.accept(from, context);
+		publish(from, context);
 	}
 
+	/**
+	 * {@inheritDoc}
+	 */
 	@Override
 	public SecurityContext createEmptyContext() {
 		return this.delegate.createEmptyContext();
 	}
 
-	void addListener(SecurityContextChangedListener listener) {
-		this.base.listeners.add(listener);
-		this.publisher = this.base;
-	}
-
-	private static class SecurityContextEventPublisher implements BiConsumer<SecurityContext, SecurityContext> {
-
-		private final List<SecurityContextChangedListener> listeners = new CopyOnWriteArrayList<>();
-
-		@Override
-		public void accept(SecurityContext previous, SecurityContext current) {
-			if (previous == current) {
-				return;
-			}
-			SecurityContextChangedEvent event = new SecurityContextChangedEvent(previous, current);
-			for (SecurityContextChangedListener listener : this.listeners) {
-				listener.securityContextChanged(event);
-			}
+	private void publish(SecurityContext previous, SecurityContext current) {
+		if (previous == current) {
+			return;
+		}
+		SecurityContextChangedEvent event = new SecurityContextChangedEvent(previous, current);
+		for (SecurityContextChangedListener listener : this.listeners) {
+			listener.securityContextChanged(event);
 		}
-
 	}
 
 }

+ 68 - 50
core/src/main/java/org/springframework/security/core/context/SecurityContextHolder.java

@@ -56,6 +56,8 @@ public class SecurityContextHolder {
 
 	public static final String MODE_GLOBAL = "MODE_GLOBAL";
 
+	private static final String MODE_PRE_INITIALIZED = "MODE_PRE_INITIALIZED";
+
 	public static final String SYSTEM_PROPERTY = "spring.security.strategy";
 
 	private static String strategyName = System.getProperty(SYSTEM_PROPERTY);
@@ -69,34 +71,41 @@ public class SecurityContextHolder {
 	}
 
 	private static void initialize() {
+		initializeStrategy();
+		initializeCount++;
+	}
+
+	private static void initializeStrategy() {
+		if (MODE_PRE_INITIALIZED.equals(strategyName)) {
+			Assert.state(strategy != null, "When using " + MODE_PRE_INITIALIZED
+					+ ", setContextHolderStrategy must be called with the fully constructed strategy");
+			return;
+		}
 		if (!StringUtils.hasText(strategyName)) {
 			// Set default
 			strategyName = MODE_THREADLOCAL;
 		}
 		if (strategyName.equals(MODE_THREADLOCAL)) {
-			ThreadLocalSecurityContextHolderStrategy delegate = new ThreadLocalSecurityContextHolderStrategy();
-			strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate);
+			strategy = new ThreadLocalSecurityContextHolderStrategy();
+			return;
 		}
-		else if (strategyName.equals(MODE_INHERITABLETHREADLOCAL)) {
-			InheritableThreadLocalSecurityContextHolderStrategy delegate = new InheritableThreadLocalSecurityContextHolderStrategy();
-			strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate);
+		if (strategyName.equals(MODE_INHERITABLETHREADLOCAL)) {
+			strategy = new InheritableThreadLocalSecurityContextHolderStrategy();
+			return;
 		}
-		else if (strategyName.equals(MODE_GLOBAL)) {
-			GlobalSecurityContextHolderStrategy delegate = new GlobalSecurityContextHolderStrategy();
-			strategy = new ListeningSecurityContextHolderStrategy(delegate::peek, delegate);
+		if (strategyName.equals(MODE_GLOBAL)) {
+			strategy = new GlobalSecurityContextHolderStrategy();
+			return;
 		}
-		else {
-			// Try to load a custom strategy
-			try {
-				Class<?> clazz = Class.forName(strategyName);
-				Constructor<?> customStrategy = clazz.getConstructor();
-				strategy = (SecurityContextHolderStrategy) customStrategy.newInstance();
-			}
-			catch (Exception ex) {
-				ReflectionUtils.handleReflectionException(ex);
-			}
+		// Try to load a custom strategy
+		try {
+			Class<?> clazz = Class.forName(strategyName);
+			Constructor<?> customStrategy = clazz.getConstructor();
+			strategy = (SecurityContextHolderStrategy) customStrategy.newInstance();
+		}
+		catch (Exception ex) {
+			ReflectionUtils.handleReflectionException(ex);
 		}
-		initializeCount++;
 	}
 
 	/**
@@ -118,7 +127,9 @@ public class SecurityContextHolder {
 	 * Primarily for troubleshooting purposes, this method shows how many times the class
 	 * has re-initialized its <code>SecurityContextHolderStrategy</code>.
 	 * @return the count (should be one unless you've called
-	 * {@link #setStrategyName(String)} to switch to an alternate strategy.
+	 * {@link #setStrategyName(String)} or
+	 * {@link #setContextHolderStrategy(SecurityContextHolderStrategy)} to switch to an
+	 * alternate strategy).
 	 */
 	public static int getInitializeCount() {
 		return initializeCount;
@@ -144,6 +155,41 @@ public class SecurityContextHolder {
 		initialize();
 	}
 
+	/**
+	 * Use this {@link SecurityContextHolderStrategy}.
+	 *
+	 * Call either {@link #setStrategyName(String)} or this method, but not both.
+	 *
+	 * This method is not thread safe. Changing the strategy while requests are in-flight
+	 * may cause race conditions.
+	 *
+	 * {@link SecurityContextHolder} maintains a static reference to the provided
+	 * {@link SecurityContextHolderStrategy}. This means that the strategy and its members
+	 * will not be garbage collected until you remove your strategy.
+	 *
+	 * To ensure garbage collection, remember the original strategy like so:
+	 *
+	 * <pre>
+	 *     SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy();
+	 *     SecurityContextHolder.setContextHolderStrategy(myStrategy);
+	 * </pre>
+	 *
+	 * And then when you are ready for {@code myStrategy} to be garbage collected you can
+	 * do:
+	 *
+	 * <pre>
+	 *     SecurityContextHolder.setContextHolderStrategy(original);
+	 * </pre>
+	 * @param strategy the {@link SecurityContextHolderStrategy} to use
+	 * @since 5.6
+	 */
+	public static void setContextHolderStrategy(SecurityContextHolderStrategy strategy) {
+		Assert.notNull(strategy, "securityContextHolderStrategy cannot be null");
+		SecurityContextHolder.strategyName = MODE_PRE_INITIALIZED;
+		SecurityContextHolder.strategy = strategy;
+		initialize();
+	}
+
 	/**
 	 * Allows retrieval of the context strategy. See SEC-1188.
 	 * @return the configured strategy for storing the security context.
@@ -159,38 +205,10 @@ public class SecurityContextHolder {
 		return strategy.createEmptyContext();
 	}
 
-	/**
-	 * Register a listener to be notified when the {@link SecurityContext} changes.
-	 *
-	 * Note that this does not notify when the underlying authentication changes. To get
-	 * notified about authentication changes, ensure that you are using
-	 * {@link #setContext} when changing the authentication like so:
-	 *
-	 * <pre>
-	 *	SecurityContext context = SecurityContextHolder.createEmptyContext();
-	 *	context.setAuthentication(authentication);
-	 *	SecurityContextHolder.setContext(context);
-	 * </pre>
-	 *
-	 * To integrate this with Spring's
-	 * {@link org.springframework.context.ApplicationEvent} support, you can add a
-	 * listener like so:
-	 *
-	 * <pre>
-	 *	SecurityContextHolder.addListener(this.applicationContext::publishEvent);
-	 * </pre>
-	 * @param listener a listener to be notified when the {@link SecurityContext} changes
-	 * @since 5.6
-	 */
-	public static void addListener(SecurityContextChangedListener listener) {
-		Assert.isInstanceOf(ListeningSecurityContextHolderStrategy.class, strategy,
-				"strategy must be of type ListeningSecurityContextHolderStrategy to add listeners");
-		((ListeningSecurityContextHolderStrategy) strategy).addListener(listener);
-	}
-
 	@Override
 	public String toString() {
-		return "SecurityContextHolder[strategy='" + strategyName + "'; initializeCount=" + initializeCount + "]";
+		return "SecurityContextHolder[strategy='" + strategy.getClass().getSimpleName() + "'; initializeCount="
+				+ initializeCount + "]";
 	}
 
 }

+ 0 - 4
core/src/main/java/org/springframework/security/core/context/ThreadLocalSecurityContextHolderStrategy.java

@@ -30,10 +30,6 @@ final class ThreadLocalSecurityContextHolderStrategy implements SecurityContextH
 
 	private static final ThreadLocal<SecurityContext> contextHolder = new ThreadLocal<>();
 
-	SecurityContext peek() {
-		return contextHolder.get();
-	}
-
 	@Override
 	public void clearContext() {
 		contextHolder.remove();

+ 70 - 0
core/src/test/java/org/springframework/security/core/context/ListeningSecurityContextHolderStrategyTests.java

@@ -0,0 +1,70 @@
+/*
+ * Copyright 2002-2021 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.core.context;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.verifyNoInteractions;
+
+public class ListeningSecurityContextHolderStrategyTests {
+
+	@Test
+	public void setContextWhenInvokedThenListenersAreNotified() {
+		SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class);
+		SecurityContextChangedListener one = mock(SecurityContextChangedListener.class);
+		SecurityContextChangedListener two = mock(SecurityContextChangedListener.class);
+		SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, one, two);
+		given(delegate.createEmptyContext()).willReturn(new SecurityContextImpl());
+		SecurityContext context = strategy.createEmptyContext();
+		strategy.setContext(context);
+		verify(delegate).setContext(context);
+		verify(one).securityContextChanged(any());
+		verify(two).securityContextChanged(any());
+	}
+
+	@Test
+	public void setContextWhenNoChangeToContextThenListenersAreNotNotified() {
+		SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class);
+		SecurityContextChangedListener listener = mock(SecurityContextChangedListener.class);
+		SecurityContextHolderStrategy strategy = new ListeningSecurityContextHolderStrategy(delegate, listener);
+		SecurityContext context = new SecurityContextImpl();
+		given(delegate.getContext()).willReturn(context);
+		strategy.setContext(strategy.getContext());
+		verify(delegate).setContext(context);
+		verifyNoInteractions(listener);
+	}
+
+	@Test
+	public void constructorWhenNullDelegateThenIllegalArgument() {
+		assertThatExceptionOfType(IllegalArgumentException.class)
+				.isThrownBy(() -> new ListeningSecurityContextHolderStrategy(null, (event) -> {
+				}));
+	}
+
+	@Test
+	public void constructorWhenNullListenerThenIllegalArgument() {
+		assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(
+				() -> new ListeningSecurityContextHolderStrategy(new ThreadLocalSecurityContextHolderStrategy(),
+						(SecurityContextChangedListener) null));
+	}
+
+}

+ 11 - 12
core/src/test/java/org/springframework/security/core/context/SecurityContextHolderTests.java

@@ -23,9 +23,7 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
-import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 
 /**
@@ -63,16 +61,17 @@ public class SecurityContextHolderTests {
 	}
 
 	@Test
-	public void addListenerWhenInvokedThenListenersAreNotified() {
-		SecurityContextChangedListener one = mock(SecurityContextChangedListener.class);
-		SecurityContextChangedListener two = mock(SecurityContextChangedListener.class);
-		SecurityContextHolder.addListener(one);
-		SecurityContextHolder.addListener(two);
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
-		SecurityContextHolder.setContext(context);
-		SecurityContextHolder.clearContext();
-		verify(one, times(2)).securityContextChanged(any(SecurityContextChangedEvent.class));
-		verify(two, times(2)).securityContextChanged(any(SecurityContextChangedEvent.class));
+	public void setContextHolderStrategyWhenCalledThenUsed() {
+		SecurityContextHolderStrategy original = SecurityContextHolder.getContextHolderStrategy();
+		try {
+			SecurityContextHolderStrategy delegate = mock(SecurityContextHolderStrategy.class);
+			SecurityContextHolder.setContextHolderStrategy(delegate);
+			SecurityContextHolder.getContext();
+			verify(delegate).getContext();
+		}
+		finally {
+			SecurityContextHolder.setContextHolderStrategy(original);
+		}
 	}
 
 }