Forráskód Böngészése

Add SecurityContextHolderStrategy Java Configuration for Messaging

Issue gh-11061
Josh Cummings 3 éve
szülő
commit
484f35ca39

+ 17 - 2
config/src/main/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfiguration.java

@@ -32,6 +32,8 @@ import org.springframework.messaging.simp.config.ChannelRegistration;
 import org.springframework.messaging.support.ChannelInterceptor;
 import org.springframework.security.authorization.AuthorizationManager;
 import org.springframework.security.authorization.SpringAuthorizationEventPublisher;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor;
 import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
 import org.springframework.security.messaging.context.AuthenticationPrincipalArgumentResolver;
@@ -59,7 +61,10 @@ final class WebSocketMessageBrokerSecurityConfiguration
 	private static final AuthorizationManager<Message<?>> ANY_MESSAGE_AUTHENTICATED = MessageMatcherDelegatingAuthorizationManager
 			.builder().anyMessage().authenticated().build();
 
-	private final ChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
+	private final SecurityContextChannelInterceptor securityContextChannelInterceptor = new SecurityContextChannelInterceptor();
 
 	private final ChannelInterceptor csrfChannelInterceptor = new CsrfChannelInterceptor();
 
@@ -74,17 +79,27 @@ final class WebSocketMessageBrokerSecurityConfiguration
 
 	@Override
 	public void addArgumentResolvers(List<HandlerMethodArgumentResolver> argumentResolvers) {
-		argumentResolvers.add(new AuthenticationPrincipalArgumentResolver());
+		AuthenticationPrincipalArgumentResolver resolver = new AuthenticationPrincipalArgumentResolver();
+		resolver.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
+		argumentResolvers.add(resolver);
 	}
 
 	@Override
 	public void configureClientInboundChannel(ChannelRegistration registration) {
 		this.authorizationChannelInterceptor
 				.setAuthorizationEventPublisher(new SpringAuthorizationEventPublisher(this.context));
+		this.authorizationChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
+		this.securityContextChannelInterceptor.setSecurityContextHolderStrategy(this.securityContextHolderStrategy);
 		registration.interceptors(this.securityContextChannelInterceptor, this.csrfChannelInterceptor,
 				this.authorizationChannelInterceptor);
 	}
 
+	@Autowired(required = false)
+	void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		Assert.notNull(securityContextHolderStrategy, "securityContextHolderStrategy cannot be null");
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	@Autowired(required = false)
 	void setAuthorizationManager(AuthorizationManager<Message<?>> authorizationManager) {
 		this.authorizationChannelInterceptor = new AuthorizationChannelInterceptor(authorizationManager);

+ 17 - 0
config/src/test/java/org/springframework/security/config/annotation/web/socket/WebSocketMessageBrokerSecurityConfigurationTests.java

@@ -54,9 +54,11 @@ import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.authorization.AuthorizationDecision;
 import org.springframework.security.authorization.AuthorizationManager;
+import org.springframework.security.config.annotation.SecurityContextChangedListenerConfig;
 import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.annotation.AuthenticationPrincipal;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.messaging.access.intercept.AuthorizationChannelInterceptor;
 import org.springframework.security.messaging.access.intercept.MessageAuthorizationContext;
 import org.springframework.security.messaging.access.intercept.MessageMatcherDelegatingAuthorizationManager;
@@ -84,6 +86,8 @@ import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSo
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
 import static org.assertj.core.api.Assertions.fail;
+import static org.mockito.Mockito.atLeastOnce;
+import static org.mockito.Mockito.verify;
 
 public class WebSocketMessageBrokerSecurityConfigurationTests {
 
@@ -225,6 +229,18 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
 		assertHandshake(request);
 	}
 
+	@Test
+	public void messagesContextWebSocketUseSecurityContextHolderStrategy() {
+		loadConfig(WebSocketSecurityConfig.class, SecurityContextChangedListenerConfig.class);
+		SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT);
+		headers.setNativeHeader(this.token.getHeaderName(), this.token.getToken());
+		Message<?> message = message(headers, "/authenticated");
+		headers.getSessionAttributes().put(CsrfToken.class.getName(), this.token);
+		MessageChannel messageChannel = clientInboundChannel();
+		messageChannel.send(message);
+		verify(this.context.getBean(SecurityContextHolderStrategy.class), atLeastOnce()).getContext();
+	}
+
 	@Test
 	public void msmsRegistryCustomPatternMatcher() {
 		loadConfig(MsmsRegistryCustomPatternMatcherConfig.class);
@@ -691,6 +707,7 @@ public class WebSocketMessageBrokerSecurityConfigurationTests {
 			// @formatter:off
 			messages
 				.simpDestMatchers("/permitAll/**").permitAll()
+				.simpDestMatchers("/authenticated/**").authenticated()
 				.anyMessage().denyAll();
 			// @formatter:on
 			return messages.build();