|
@@ -1,5 +1,5 @@
|
|
|
/*
|
|
|
- * Copyright 2002-2016 the original author or authors.
|
|
|
+ * Copyright 2002-2019 the original author or authors.
|
|
|
*
|
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
* you may not use this file except in compliance with the License.
|
|
@@ -26,6 +26,8 @@ import javax.servlet.http.HttpServletRequest;
|
|
|
import org.junit.After;
|
|
|
import org.junit.Before;
|
|
|
import org.junit.Test;
|
|
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
+import org.springframework.context.ApplicationContext;
|
|
|
import org.springframework.context.annotation.Bean;
|
|
|
import org.springframework.context.annotation.Configuration;
|
|
|
import org.springframework.context.annotation.Import;
|
|
@@ -40,6 +42,7 @@ import org.springframework.messaging.handler.invocation.HandlerMethodArgumentRes
|
|
|
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
|
|
import org.springframework.messaging.simp.SimpMessageType;
|
|
|
import org.springframework.messaging.simp.config.MessageBrokerRegistry;
|
|
|
+import org.springframework.messaging.support.AbstractMessageChannel;
|
|
|
import org.springframework.messaging.support.GenericMessage;
|
|
|
import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
import org.springframework.mock.web.MockHttpServletResponse;
|
|
@@ -53,6 +56,10 @@ import org.springframework.security.core.Authentication;
|
|
|
import org.springframework.security.core.annotation.AuthenticationPrincipal;
|
|
|
import org.springframework.security.messaging.access.expression.DefaultMessageSecurityExpressionHandler;
|
|
|
import org.springframework.security.messaging.access.expression.MessageSecurityExpressionRoot;
|
|
|
+import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor;
|
|
|
+import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource;
|
|
|
+import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
|
|
|
+import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
|
|
|
import org.springframework.security.web.csrf.CsrfToken;
|
|
|
import org.springframework.security.web.csrf.DefaultCsrfToken;
|
|
|
import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
|
@@ -199,6 +206,16 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|
|
messageChannel.send(message);
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void csrfProtectionDefinedByBean() {
|
|
|
+ loadConfig(SockJsProxylessSecurityConfig.class);
|
|
|
+
|
|
|
+ MessageChannel messageChannel = clientInboundChannel();
|
|
|
+ CsrfChannelInterceptor csrfChannelInterceptor = context.getBean(CsrfChannelInterceptor.class);
|
|
|
+
|
|
|
+ assertThat(((AbstractMessageChannel) messageChannel).getInterceptors()).contains(csrfChannelInterceptor);
|
|
|
+ }
|
|
|
+
|
|
|
@Test
|
|
|
public void messagesConnectUseCsrfTokenHandshakeInterceptor() throws Exception {
|
|
|
|
|
@@ -421,6 +438,41 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @Test
|
|
|
+ public void channelSecurityInterceptorUsesMetadataSourceBeanWhenProxyingDisabled() {
|
|
|
+
|
|
|
+ loadConfig(SockJsProxylessSecurityConfig.class);
|
|
|
+
|
|
|
+ ChannelSecurityInterceptor channelSecurityInterceptor = context.getBean(ChannelSecurityInterceptor.class);
|
|
|
+ MessageSecurityMetadataSource messageSecurityMetadataSource =
|
|
|
+ context.getBean(MessageSecurityMetadataSource.class);
|
|
|
+
|
|
|
+ assertThat(channelSecurityInterceptor.obtainSecurityMetadataSource()).isSameAs(messageSecurityMetadataSource);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void securityContextChannelInterceptorDefinedByBean() {
|
|
|
+ loadConfig(SockJsProxylessSecurityConfig.class);
|
|
|
+
|
|
|
+ MessageChannel messageChannel = clientInboundChannel();
|
|
|
+ SecurityContextChannelInterceptor securityContextChannelInterceptor =
|
|
|
+ context.getBean(SecurityContextChannelInterceptor.class);
|
|
|
+
|
|
|
+ assertThat(((AbstractMessageChannel) messageChannel).getInterceptors())
|
|
|
+ .contains(securityContextChannelInterceptor);
|
|
|
+ }
|
|
|
+
|
|
|
+ @Test
|
|
|
+ public void inboundChannelSecurityDefinedByBean() {
|
|
|
+ loadConfig(SockJsProxylessSecurityConfig.class);
|
|
|
+
|
|
|
+ MessageChannel messageChannel = clientInboundChannel();
|
|
|
+ ChannelSecurityInterceptor inboundChannelSecurity = context.getBean(ChannelSecurityInterceptor.class);
|
|
|
+
|
|
|
+ assertThat(((AbstractMessageChannel) messageChannel).getInterceptors())
|
|
|
+ .contains(inboundChannelSecurity);
|
|
|
+ }
|
|
|
+
|
|
|
@Configuration
|
|
|
@EnableWebSocketMessageBroker
|
|
|
@Import(SyncExecutorConfig.class)
|
|
@@ -706,6 +758,38 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ @Configuration(proxyBeanMethods = false)
|
|
|
+ @EnableWebSocketMessageBroker
|
|
|
+ @Import(SyncExecutorConfig.class)
|
|
|
+ static class SockJsProxylessSecurityConfig extends
|
|
|
+ AbstractSecurityWebSocketMessageBrokerConfigurer {
|
|
|
+ private ApplicationContext context;
|
|
|
+
|
|
|
+ public void registerStompEndpoints(StompEndpointRegistry registry) {
|
|
|
+ registry.addEndpoint("/chat")
|
|
|
+ .setHandshakeHandler(context.getBean(TestHandshakeHandler.class))
|
|
|
+ .withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor());
|
|
|
+ }
|
|
|
+
|
|
|
+ @Autowired
|
|
|
+ public void setContext(ApplicationContext context) {
|
|
|
+ this.context = context;
|
|
|
+ }
|
|
|
+
|
|
|
+ // @formatter:off
|
|
|
+ @Override
|
|
|
+ protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
|
|
|
+ messages
|
|
|
+ .anyMessage().denyAll();
|
|
|
+ }
|
|
|
+ // @formatter:on
|
|
|
+
|
|
|
+ @Bean
|
|
|
+ public TestHandshakeHandler testHandshakeHandler() {
|
|
|
+ return new TestHandshakeHandler();
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
@Configuration
|
|
|
static class SyncExecutorConfig {
|
|
|
@Bean
|