|
@@ -12,12 +12,24 @@ import org.springframework.http.server.ServerHttpRequest
|
|
|
import org.springframework.http.server.ServerHttpResponse
|
|
|
import org.springframework.messaging.handler.annotation.MessageMapping
|
|
|
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver
|
|
|
+import org.springframework.messaging.simp.SimpMessageType
|
|
|
+import org.springframework.mock.web.MockHttpServletRequest
|
|
|
+import org.springframework.mock.web.MockHttpServletResponse
|
|
|
import org.springframework.security.core.Authentication
|
|
|
import org.springframework.security.core.annotation.AuthenticationPrincipal
|
|
|
+import org.springframework.security.web.csrf.CsrfToken
|
|
|
+import org.springframework.security.web.csrf.DefaultCsrfToken
|
|
|
+import org.springframework.security.web.csrf.MissingCsrfTokenException
|
|
|
import org.springframework.stereotype.Controller
|
|
|
+import org.springframework.web.servlet.HandlerMapping
|
|
|
+import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping
|
|
|
import org.springframework.web.socket.WebSocketHandler
|
|
|
import org.springframework.web.socket.server.HandshakeFailureException
|
|
|
import org.springframework.web.socket.server.HandshakeHandler
|
|
|
+import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor
|
|
|
+import org.springframework.web.socket.server.support.WebSocketHttpRequestHandler
|
|
|
+import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler
|
|
|
+import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler
|
|
|
|
|
|
import static org.mockito.Mockito.*
|
|
|
|
|
@@ -37,6 +49,7 @@ import org.springframework.security.core.context.SecurityContextHolder
|
|
|
*/
|
|
|
class MessagesConfigTests extends AbstractXmlConfigTests {
|
|
|
Authentication messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
|
|
|
+ boolean useSockJS = false
|
|
|
|
|
|
def cleanup() {
|
|
|
SecurityContextHolder.clearContext()
|
|
@@ -93,6 +106,89 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|
|
controller.authenticationPrincipal == messageUser.name
|
|
|
}
|
|
|
|
|
|
+ def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor'() {
|
|
|
+ setup:
|
|
|
+ def id = 'authenticationController'
|
|
|
+ bean(id,MyController)
|
|
|
+ bean('inPostProcessor',InboundExecutorPostProcessor)
|
|
|
+ messages {
|
|
|
+ 'message-interceptor'(pattern:'/**',access:'permitAll')
|
|
|
+ }
|
|
|
+
|
|
|
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
|
|
|
+ Message<?> message = message(headers,'/authentication')
|
|
|
+ WebSocketHttpRequestHandler handler = appContext.getBean(WebSocketHttpRequestHandler)
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest()
|
|
|
+ String sessionAttr = "sessionAttr"
|
|
|
+ request.getSession().setAttribute(sessionAttr,"sessionValue")
|
|
|
+
|
|
|
+ CsrfToken token = new DefaultCsrfToken("header", "param", "token")
|
|
|
+ request.setAttribute(CsrfToken.name, token)
|
|
|
+
|
|
|
+ when:
|
|
|
+ handler.handleRequest(request , new MockHttpServletResponse())
|
|
|
+ TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
|
|
|
+
|
|
|
+ then: 'CsrfToken is populated'
|
|
|
+ handshakeHandler.attributes.get(CsrfToken.name) == token
|
|
|
+
|
|
|
+ and: 'Explicitly listed HandshakeInterceptor are not overridden'
|
|
|
+ handshakeHandler.attributes.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
|
|
|
+ }
|
|
|
+
|
|
|
+ def 'messages of type CONNECT use CsrfTokenHandshakeInterceptor with SockJS'() {
|
|
|
+ setup:
|
|
|
+ useSockJS = true
|
|
|
+ def id = 'authenticationController'
|
|
|
+ bean(id,MyController)
|
|
|
+ bean('inPostProcessor',InboundExecutorPostProcessor)
|
|
|
+ messages {
|
|
|
+ 'message-interceptor'(pattern:'/**',access:'permitAll')
|
|
|
+ }
|
|
|
+
|
|
|
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
|
|
|
+ Message<?> message = message(headers,'/authentication')
|
|
|
+ SockJsHttpRequestHandler handler = appContext.getBean(SockJsHttpRequestHandler)
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest()
|
|
|
+ String sessionAttr = "sessionAttr"
|
|
|
+ request.getSession().setAttribute(sessionAttr,"sessionValue")
|
|
|
+
|
|
|
+ CsrfToken token = new DefaultCsrfToken("header", "param", "token")
|
|
|
+ request.setAttribute(CsrfToken.name, token)
|
|
|
+
|
|
|
+ request.setMethod("GET")
|
|
|
+ request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket")
|
|
|
+
|
|
|
+ when:
|
|
|
+ handler.handleRequest(request , new MockHttpServletResponse())
|
|
|
+ TestHandshakeHandler handshakeHandler = appContext.getBean(TestHandshakeHandler)
|
|
|
+
|
|
|
+ then: 'CsrfToken is populated'
|
|
|
+ handshakeHandler.attributes?.get(CsrfToken.name) == token
|
|
|
+
|
|
|
+ and: 'Explicitly listed HandshakeInterceptor are not overridden'
|
|
|
+ handshakeHandler.attributes?.get(sessionAttr) == request.getSession().getAttribute(sessionAttr)
|
|
|
+ }
|
|
|
+
|
|
|
+ def 'messages of type CONNECT require valid CsrfToken'() {
|
|
|
+ setup:
|
|
|
+ def id = 'authenticationController'
|
|
|
+ bean(id,MyController)
|
|
|
+ bean('inPostProcessor',InboundExecutorPostProcessor)
|
|
|
+ messages {
|
|
|
+ 'message-interceptor'(pattern:'/**',access:'permitAll')
|
|
|
+ }
|
|
|
+
|
|
|
+ when: 'message of type CONNECTION is sent without CsrfTOken'
|
|
|
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT)
|
|
|
+ Message<?> message = message(headers,'/authentication')
|
|
|
+ clientInboundChannel.send(message)
|
|
|
+
|
|
|
+ then: 'CSRF Protection blocks the Message'
|
|
|
+ MessageDeliveryException expected = thrown()
|
|
|
+ expected.cause instanceof MissingCsrfTokenException
|
|
|
+ }
|
|
|
+
|
|
|
def 'messages with no id does not override customArgumentResolvers'() {
|
|
|
setup:
|
|
|
def id = 'authenticationController'
|
|
@@ -201,6 +297,12 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|
|
'websocket:transport' {}
|
|
|
'websocket:stomp-endpoint'(path:'/app') {
|
|
|
'websocket:handshake-handler'(ref:'testHandler') {}
|
|
|
+ 'websocket:handshake-interceptors' {
|
|
|
+ 'b:bean'('class':HttpSessionHandshakeInterceptor.name) {}
|
|
|
+ }
|
|
|
+ if(useSockJS) {
|
|
|
+ 'websocket:sockjs' {}
|
|
|
+ }
|
|
|
}
|
|
|
'websocket:simple-broker'(prefix:"/queue, /topic"){}
|
|
|
}
|
|
@@ -214,6 +316,11 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|
|
|
|
|
def message(String destination) {
|
|
|
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create()
|
|
|
+ message(headers, destination)
|
|
|
+ }
|
|
|
+
|
|
|
+ def message(SimpMessageHeaderAccessor headers, String destination) {
|
|
|
+ messageUser = new TestingAuthenticationToken('user','pass','ROLE_USER')
|
|
|
headers.sessionId = '123'
|
|
|
headers.sessionAttributes = [:]
|
|
|
headers.destination = destination
|
|
@@ -257,8 +364,15 @@ class MessagesConfigTests extends AbstractXmlConfigTests {
|
|
|
}
|
|
|
|
|
|
static class TestHandshakeHandler implements HandshakeHandler {
|
|
|
- @Override
|
|
|
+ Map<String, Object> attributes;
|
|
|
+
|
|
|
boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
|
|
|
+ this.attributes = attributes
|
|
|
+ if(wsHandler instanceof SockJsWebSocketHandler) {
|
|
|
+ // work around SPR-12716
|
|
|
+ SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
|
|
|
+ this.attributes = sockJs.sockJsSession.attributes
|
|
|
+ }
|
|
|
true
|
|
|
}
|
|
|
}
|