|
@@ -3,6 +3,7 @@ package org.springframework.security.config.websocket
|
|
|
import org.springframework.beans.BeansException
|
|
|
import org.springframework.beans.factory.config.BeanDefinition
|
|
|
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory
|
|
|
+import org.springframework.beans.factory.parsing.BeanDefinitionParsingException
|
|
|
import org.springframework.beans.factory.support.BeanDefinitionRegistry
|
|
|
import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor
|
|
|
import org.springframework.beans.factory.support.RootBeanDefinition
|
|
@@ -17,8 +18,10 @@ import org.springframework.mock.web.MockHttpServletRequest
|
|
|
import org.springframework.mock.web.MockHttpServletResponse
|
|
|
import org.springframework.security.core.Authentication
|
|
|
import org.springframework.security.core.annotation.AuthenticationPrincipal
|
|
|
+import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher
|
|
|
import org.springframework.security.web.csrf.CsrfToken
|
|
|
import org.springframework.security.web.csrf.DefaultCsrfToken
|
|
|
+import org.springframework.security.web.csrf.InvalidCsrfTokenException
|
|
|
import org.springframework.security.web.csrf.MissingCsrfTokenException
|
|
|
import org.springframework.stereotype.Controller
|
|
|
import org.springframework.web.servlet.HandlerMapping
|
|
@@ -30,6 +33,7 @@ import org.springframework.web.socket.server.support.HttpSessionHandshakeInterce
|
|
|
import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler
|
|
|
import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler
|
|
|
import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler
|
|
|
+import spock.lang.Unroll
|
|
|
|
|
|
import static org.mockito.Mockito.*
|
|
|
|
|
@@ -50,6 +54,7 @@ import org.springframework.security.core.context.SecurityContextHolder
|
|
|
class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
|
|
|
Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
|
|
|
boolean useSockJS = false
|
|
|
+ CsrfToken csrfToken = new DefaultCsrfToken('headerName', 'paramName', 'token')
|
|
|
|
|
|
def cleanup() {
|
|
|
SecurityContextHolder.clearContext()
|
|
@@ -89,6 +94,75 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
|
|
|
noExceptionThrown()
|
|
|
}
|
|
|
|
|
|
+ @Unroll
|
|
|
+ def "message type - #type"(SimpMessageType type) {
|
|
|
+ setup:
|
|
|
+ websocket {
|
|
|
+ 'intercept-message'('type': type.toString(), access:'permitAll')
|
|
|
+ 'intercept-message'(pattern:'/**', access:'denyAll')
|
|
|
+ }
|
|
|
+ messageUser = null
|
|
|
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type)
|
|
|
+ if(SimpMessageType.CONNECT == type) {
|
|
|
+ headers.setNativeHeader(csrfToken.headerName, csrfToken.token)
|
|
|
+ }
|
|
|
+ Message message = message(headers, '/permitAll')
|
|
|
+
|
|
|
+ when: 'message is sent to the permitAll endpoint with no user'
|
|
|
+ clientInboundChannel.send(message)
|
|
|
+
|
|
|
+ then: 'access is granted'
|
|
|
+ noExceptionThrown()
|
|
|
+
|
|
|
+ where:
|
|
|
+ type << SimpMessageType.values()
|
|
|
+ }
|
|
|
+
|
|
|
+ @Unroll
|
|
|
+ def "pattern and message type - #type"(SimpMessageType type) {
|
|
|
+ setup:
|
|
|
+ websocket {
|
|
|
+ 'intercept-message'(pattern: '/permitAll', 'type': type.toString(), access:'permitAll')
|
|
|
+ 'intercept-message'(pattern:'/**', access:'denyAll')
|
|
|
+ }
|
|
|
+
|
|
|
+ when: 'message is sent to the permitAll endpoint with no user'
|
|
|
+ clientInboundChannel.send(message('/permitAll', type))
|
|
|
+
|
|
|
+ then: 'access is granted'
|
|
|
+ noExceptionThrown()
|
|
|
+
|
|
|
+ when: 'message sent to other message type'
|
|
|
+ clientInboundChannel.send(message('/permitAll', SimpMessageType.UNSUBSCRIBE))
|
|
|
+
|
|
|
+ then: 'does not match'
|
|
|
+ MessageDeliveryException e = thrown()
|
|
|
+ e.cause instanceof AccessDeniedException
|
|
|
+
|
|
|
+ when: 'message is sent to other pattern'
|
|
|
+ clientInboundChannel.send(message('/other', type))
|
|
|
+
|
|
|
+ then: 'does not match'
|
|
|
+ MessageDeliveryException eOther = thrown()
|
|
|
+ eOther.cause instanceof AccessDeniedException
|
|
|
+
|
|
|
+ where:
|
|
|
+ type << [SimpMessageType.MESSAGE, SimpMessageType.SUBSCRIBE]
|
|
|
+ }
|
|
|
+
|
|
|
+ @Unroll
|
|
|
+ def "intercept-message with invalid type and pattern - #type"(SimpMessageType type) {
|
|
|
+ when:
|
|
|
+ websocket {
|
|
|
+ 'intercept-message'(pattern : '/**', 'type': type.toString(), access:'permitAll')
|
|
|
+ }
|
|
|
+ then:
|
|
|
+ thrown(BeanDefinitionParsingException)
|
|
|
+
|
|
|
+ where:
|
|
|
+ type << [SimpMessageType.CONNECT, SimpMessageType.CONNECT_ACK, SimpMessageType.DISCONNECT, SimpMessageType.DISCONNECT_ACK, SimpMessageType.HEARTBEAT, SimpMessageType.OTHER, SimpMessageType.UNSUBSCRIBE ]
|
|
|
+ }
|
|
|
+
|
|
|
def 'messages with no id automatically adds Authentication argument resolver'() {
|
|
|
setup:
|
|
|
def id = 'authenticationController'
|
|
@@ -186,7 +260,7 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
|
|
|
|
|
|
then: 'CSRF Protection blocks the Message'
|
|
|
MessageDeliveryException expected = thrown()
|
|
|
- expected.cause instanceof MissingCsrfTokenException
|
|
|
+ expected.cause instanceof InvalidCsrfTokenException
|
|
|
}
|
|
|
|
|
|
def 'websocket with no id does not override customArgumentResolvers'() {
|
|
@@ -314,8 +388,8 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
|
|
|
appContext.getBean("clientInboundChannel")
|
|
|
}
|
|
|
|
|
|
- def message(String destination) {
|
|
|
- SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create()
|
|
|
+ def message(String destination, SimpMessageType type=SimpMessageType.MESSAGE) {
|
|
|
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(type)
|
|
|
message(headers, destination)
|
|
|
}
|
|
|
|
|
@@ -327,6 +401,9 @@ class WebSocketMessageBrokerConfigTests extends AbstractXmlConfigTests {
|
|
|
if(messageUser != null) {
|
|
|
headers.user = messageUser
|
|
|
}
|
|
|
+ if(csrfToken != null) {
|
|
|
+ headers.sessionAttributes[CsrfToken.name] = csrfToken
|
|
|
+ }
|
|
|
new GenericMessage<String>("hi",headers.messageHeaders)
|
|
|
}
|
|
|
|