浏览代码

Use SecurityContextHolderStrategy for Messaging

Issue gh-11060
Josh Cummings 3 年之前
父节点
当前提交
1e498df39b

+ 24 - 10
messaging/src/main/java/org/springframework/security/messaging/access/intercept/AuthorizationChannelInterceptor.java

@@ -32,6 +32,7 @@ import org.springframework.security.authorization.AuthorizationEventPublisher;
 import org.springframework.security.authorization.AuthorizationManager;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.util.Assert;
 
 /**
@@ -42,14 +43,8 @@ import org.springframework.util.Assert;
  */
 public final class AuthorizationChannelInterceptor implements ChannelInterceptor {
 
-	static final Supplier<Authentication> AUTHENTICATION_SUPPLIER = () -> {
-		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
-		if (authentication == null) {
-			throw new AuthenticationCredentialsNotFoundException(
-					"An Authentication object was not found in the SecurityContext");
-		}
-		return authentication;
-	};
+	private Supplier<Authentication> authentication = getAuthentication(
+			SecurityContextHolder.getContextHolderStrategy());
 
 	private final Log logger = LogFactory.getLog(this.getClass());
 
@@ -71,8 +66,8 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
 	@Override
 	public Message<?> preSend(Message<?> message, MessageChannel channel) {
 		this.logger.debug(LogMessage.of(() -> "Authorizing message send"));
-		AuthorizationDecision decision = this.preSendAuthorizationManager.check(AUTHENTICATION_SUPPLIER, message);
-		this.eventPublisher.publishAuthorizationEvent(AUTHENTICATION_SUPPLIER, message, decision);
+		AuthorizationDecision decision = this.preSendAuthorizationManager.check(this.authentication, message);
+		this.eventPublisher.publishAuthorizationEvent(this.authentication, message, decision);
 		if (decision == null || !decision.isGranted()) { // default deny
 			this.logger.debug(LogMessage.of(() -> "Failed to authorize message with authorization manager "
 					+ this.preSendAuthorizationManager + " and decision " + decision));
@@ -82,6 +77,14 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
 		return message;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		this.authentication = getAuthentication(securityContextHolderStrategy);
+	}
+
 	/**
 	 * Use this {@link AuthorizationEventPublisher} to publish the
 	 * {@link AuthorizationManager} result.
@@ -92,6 +95,17 @@ public final class AuthorizationChannelInterceptor implements ChannelInterceptor
 		this.eventPublisher = eventPublisher;
 	}
 
+	private Supplier<Authentication> getAuthentication(SecurityContextHolderStrategy strategy) {
+		return () -> {
+			Authentication authentication = strategy.getContext().getAuthentication();
+			if (authentication == null) {
+				throw new AuthenticationCredentialsNotFoundException(
+						"An Authentication object was not found in the SecurityContext");
+			}
+			return authentication;
+		};
+	}
+
 	private static class NoopAuthorizationEventPublisher implements AuthorizationEventPublisher {
 
 		@Override

+ 18 - 2
messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2021 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,7 +29,9 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.annotation.AuthenticationPrincipal;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.stereotype.Controller;
+import org.springframework.util.Assert;
 import org.springframework.util.ClassUtils;
 import org.springframework.util.StringUtils;
 
@@ -85,6 +87,9 @@ import org.springframework.util.StringUtils;
  */
 public final class AuthenticationPrincipalArgumentResolver implements HandlerMethodArgumentResolver {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private ExpressionParser parser = new SpelExpressionParser();
 
 	@Override
@@ -94,7 +99,7 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
 
 	@Override
 	public Object resolveArgument(MethodParameter parameter, Message<?> message) {
-		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
 		if (authentication == null) {
 			return null;
 		}
@@ -117,6 +122,17 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
 		return principal;
 	}
 
+	/**
+	 * Sets the {@link SecurityContextHolderStrategy} to use. The default action is to use
+	 * the {@link SecurityContextHolderStrategy} stored in {@link SecurityContextHolder}.
+	 *
+	 * @since 5.8
+	 */
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	/**
 	 * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}.
 	 * @param annotationClass the class of the {@link Annotation} to find on the

+ 20 - 11
messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.util.Assert;
 
 /**
@@ -42,10 +43,13 @@ import org.springframework.util.Assert;
  */
 public final class SecurityContextChannelInterceptor implements ExecutorChannelInterceptor, ChannelInterceptor {
 
-	private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
-
 	private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
+	private SecurityContext empty = this.securityContextHolderStrategy.createEmptyContext();
+
 	private final String authenticationHeaderName;
 
 	private Authentication anonymous = new AnonymousAuthenticationToken("key", "anonymous",
@@ -107,8 +111,13 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI
 		cleanup();
 	}
 
+	public void setSecurityContextHolderStrategy(SecurityContextHolderStrategy strategy) {
+		this.securityContextHolderStrategy = strategy;
+		this.empty = this.securityContextHolderStrategy.createEmptyContext();
+	}
+
 	private void setup(Message<?> message) {
-		SecurityContext currentContext = SecurityContextHolder.getContext();
+		SecurityContext currentContext = this.securityContextHolderStrategy.getContext();
 		Stack<SecurityContext> contextStack = originalContext.get();
 		if (contextStack == null) {
 			contextStack = new Stack<>();
@@ -117,9 +126,9 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI
 		contextStack.push(currentContext);
 		Object user = message.getHeaders().get(this.authenticationHeaderName);
 		Authentication authentication = getAuthentication(user);
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
 		context.setAuthentication(authentication);
-		SecurityContextHolder.setContext(context);
+		this.securityContextHolderStrategy.setContext(context);
 	}
 
 	private Authentication getAuthentication(Object user) {
@@ -132,22 +141,22 @@ public final class SecurityContextChannelInterceptor implements ExecutorChannelI
 	private void cleanup() {
 		Stack<SecurityContext> contextStack = originalContext.get();
 		if (contextStack == null || contextStack.isEmpty()) {
-			SecurityContextHolder.clearContext();
+			this.securityContextHolderStrategy.clearContext();
 			originalContext.remove();
 			return;
 		}
 		SecurityContext context = contextStack.pop();
 		try {
-			if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) {
-				SecurityContextHolder.clearContext();
+			if (SecurityContextChannelInterceptor.this.empty.equals(context)) {
+				this.securityContextHolderStrategy.clearContext();
 				originalContext.remove();
 			}
 			else {
-				SecurityContextHolder.setContext(context);
+				this.securityContextHolderStrategy.setContext(context);
 			}
 		}
 		catch (Throwable ex) {
-			SecurityContextHolder.clearContext();
+			this.securityContextHolderStrategy.clearContext();
 		}
 	}
 

+ 46 - 1
messaging/src/test/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptorTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,9 +34,13 @@ import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
+import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 
 @ExtendWith(MockitoExtension.class)
 public class SecurityContextChannelInterceptorTests {
@@ -94,6 +98,17 @@ public class SecurityContextChannelInterceptorTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
 	}
 
+	@Test
+	public void preSendWhenCustomSecurityContextHolderStrategyThenUserSet() {
+		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
+		strategy.setContext(new SecurityContextImpl(this.authentication));
+		this.interceptor.setSecurityContextHolderStrategy(strategy);
+		this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
+		this.interceptor.preSend(this.messageBuilder.build(), this.channel);
+		verify(strategy).getContext();
+		assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication);
+	}
+
 	@Test
 	public void setAnonymousAuthenticationNull() {
 		assertThatIllegalArgumentException().isThrownBy(() -> this.interceptor.setAnonymousAuthentication(null));
@@ -143,6 +158,16 @@ public class SecurityContextChannelInterceptorTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
 	}
 
+	@Test
+	public void afterSendCompletionWhenCustomSecurityContextHolderStrategyThenNullAuthentication() {
+		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
+		strategy.setContext(new SecurityContextImpl(this.authentication));
+		this.interceptor.setSecurityContextHolderStrategy(strategy);
+		this.interceptor.afterSendCompletion(this.messageBuilder.build(), this.channel, true, null);
+		verify(strategy).clearContext();
+		assertThat(strategy.getContext().getAuthentication()).isNull();
+	}
+
 	@Test
 	public void beforeHandleUserSet() {
 		this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
@@ -150,6 +175,17 @@ public class SecurityContextChannelInterceptorTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(this.authentication);
 	}
 
+	@Test
+	public void beforeHandleWhenCustomSecurityContextHolderStrategyThenUserSet() {
+		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
+		strategy.setContext(new SecurityContextImpl(this.authentication));
+		this.interceptor.setSecurityContextHolderStrategy(strategy);
+		this.messageBuilder.setHeader(SimpMessageHeaderAccessor.USER_HEADER, this.authentication);
+		this.interceptor.beforeHandle(this.messageBuilder.build(), this.channel, this.handler);
+		verify(strategy).getContext();
+		assertThat(strategy.getContext().getAuthentication()).isSameAs(this.authentication);
+	}
+
 	// SEC-2845
 	@Test
 	public void beforeHandleUserNotAuthentication() {
@@ -178,6 +214,15 @@ public class SecurityContextChannelInterceptorTests {
 		assertThat(SecurityContextHolder.getContext().getAuthentication()).isNull();
 	}
 
+	@Test
+	public void afterMessageHandledWhenCustomSecurityContextHolderStrategyThenUses() {
+		SecurityContextHolderStrategy strategy = spy(SecurityContextHolder.getContextHolderStrategy());
+		strategy.setContext(new SecurityContextImpl(this.authentication));
+		this.interceptor.setSecurityContextHolderStrategy(strategy);
+		this.interceptor.afterMessageHandled(this.messageBuilder.build(), this.channel, this.handler, null);
+		verify(strategy).clearContext();
+	}
+
 	// SEC-2829
 	@Test
 	public void restoresOriginalContext() {