|
@@ -0,0 +1,265 @@
|
|
|
|
+/*
|
|
|
|
+ * Copyright 2002-2015 the original author or authors.
|
|
|
|
+ *
|
|
|
|
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
|
|
|
+ * use this file except in compliance with the License. You may obtain a copy of
|
|
|
|
+ * the License at
|
|
|
|
+ *
|
|
|
|
+ * http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
+ *
|
|
|
|
+ * Unless required by applicable law or agreed to in writing, software
|
|
|
|
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
|
|
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
|
|
+ * License for the specific language governing permissions and limitations under
|
|
|
|
+ * the License.
|
|
|
|
+ */
|
|
|
|
+package org.springframework.security.config.annotation.web.socket;
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+import org.junit.After;
|
|
|
|
+import org.junit.Before;
|
|
|
|
+
|
|
|
|
+import org.junit.Test;
|
|
|
|
+import org.springframework.context.annotation.Bean;
|
|
|
|
+import org.springframework.context.annotation.Configuration;
|
|
|
|
+import org.springframework.context.annotation.Import;
|
|
|
|
+import org.springframework.core.MethodParameter;
|
|
|
|
+import org.springframework.http.server.ServerHttpRequest;
|
|
|
|
+import org.springframework.http.server.ServerHttpResponse;
|
|
|
|
+import org.springframework.messaging.Message;
|
|
|
|
+import org.springframework.messaging.MessageChannel;
|
|
|
|
+import org.springframework.messaging.MessageDeliveryException;
|
|
|
|
+import org.springframework.messaging.handler.annotation.MessageMapping;
|
|
|
|
+import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
|
|
|
|
+import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
|
|
|
|
+import org.springframework.messaging.simp.SimpMessageType;
|
|
|
|
+import org.springframework.messaging.simp.config.MessageBrokerRegistry;
|
|
|
|
+import org.springframework.messaging.support.GenericMessage;
|
|
|
|
+import org.springframework.mock.web.MockHttpServletRequest;
|
|
|
|
+import org.springframework.mock.web.MockHttpServletResponse;
|
|
|
|
+import org.springframework.mock.web.MockServletConfig;
|
|
|
|
+import org.springframework.security.access.AccessDeniedException;
|
|
|
|
+import org.springframework.security.authentication.TestingAuthenticationToken;
|
|
|
|
+import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry;
|
|
|
|
+import org.springframework.security.core.annotation.AuthenticationPrincipal;
|
|
|
|
+import org.springframework.security.web.csrf.CsrfToken;
|
|
|
|
+import org.springframework.security.web.csrf.DefaultCsrfToken;
|
|
|
|
+import org.springframework.security.web.csrf.MissingCsrfTokenException;
|
|
|
|
+import org.springframework.stereotype.Controller;
|
|
|
|
+import org.springframework.test.util.ReflectionTestUtils;
|
|
|
|
+import org.springframework.web.HttpRequestHandler;
|
|
|
|
+import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
|
|
|
|
+import org.springframework.web.servlet.HandlerMapping;
|
|
|
|
+import org.springframework.web.socket.WebSocketHandler;
|
|
|
|
+import org.springframework.web.socket.config.annotation.EnableWebSocketMessageBroker;
|
|
|
|
+import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
|
|
|
|
+import org.springframework.web.socket.server.HandshakeFailureException;
|
|
|
|
+import org.springframework.web.socket.server.HandshakeHandler;
|
|
|
|
+import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor;
|
|
|
|
+import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHandler;
|
|
|
|
+import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession;
|
|
|
|
+
|
|
|
|
+import javax.servlet.http.HttpServletRequest;
|
|
|
|
+import java.util.HashMap;
|
|
|
|
+import java.util.Map;
|
|
|
|
+
|
|
|
|
+import static org.fest.assertions.Assertions.assertThat;
|
|
|
|
+import static org.junit.Assert.fail;
|
|
|
|
+
|
|
|
|
+public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
|
|
|
|
+ AnnotationConfigWebApplicationContext context;
|
|
|
|
+
|
|
|
|
+ TestingAuthenticationToken messageUser;
|
|
|
|
+
|
|
|
|
+ CsrfToken token;
|
|
|
|
+
|
|
|
|
+ String sessionAttr;
|
|
|
|
+
|
|
|
|
+ @Before
|
|
|
|
+ public void setup() {
|
|
|
|
+ token = new DefaultCsrfToken("header", "param", "token");
|
|
|
|
+ sessionAttr = "sessionAttr";
|
|
|
|
+ messageUser = new TestingAuthenticationToken("user","pass","ROLE_USER");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @After
|
|
|
|
+ public void cleanup() {
|
|
|
|
+ if(context != null) {
|
|
|
|
+ context.close();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void simpleRegistryMappings() {
|
|
|
|
+ loadConfig(SockJsSecurityConfig.class);
|
|
|
|
+
|
|
|
|
+ clientInboundChannel().send(message("/permitAll"));
|
|
|
|
+
|
|
|
|
+ try {
|
|
|
|
+ clientInboundChannel().send(message("/denyAll"));
|
|
|
|
+ fail("Expected Exception");
|
|
|
|
+ } catch(MessageDeliveryException expected) {
|
|
|
|
+ assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void annonymousSupported() {
|
|
|
|
+ loadConfig(SockJsSecurityConfig.class);
|
|
|
|
+
|
|
|
|
+ messageUser = null;
|
|
|
|
+ clientInboundChannel().send(message("/permitAll"));
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Test
|
|
|
|
+ public void addsAuthenticationPrincipalResolver() throws InterruptedException {
|
|
|
|
+ loadConfig(SockJsSecurityConfig.class);
|
|
|
|
+
|
|
|
|
+ MessageChannel messageChannel = clientInboundChannel();
|
|
|
|
+ Message<String> message = message("/permitAll/authentication");
|
|
|
|
+ messageChannel.send(message);
|
|
|
|
+
|
|
|
|
+ assertThat(context.getBean(MyController.class).authenticationPrincipal).isEqualTo((String) messageUser.getPrincipal());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private MockHttpServletRequest sockjsHttpRequest(String mapping) {
|
|
|
|
+ MockHttpServletRequest request = new MockHttpServletRequest();
|
|
|
|
+ request.setMethod("GET");
|
|
|
|
+ request.setAttribute(HandlerMapping.PATH_WITHIN_HANDLER_MAPPING_ATTRIBUTE, "/289/tpyx6mde/websocket");
|
|
|
|
+ request.setRequestURI(mapping + "/289/tpyx6mde/websocket");
|
|
|
|
+ request.getSession().setAttribute(sessionAttr,"sessionValue");
|
|
|
|
+
|
|
|
|
+ request.setAttribute(CsrfToken.class.getName(), token);
|
|
|
|
+ return request;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private Message<String> message(String destination) {
|
|
|
|
+ SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.create();
|
|
|
|
+ return message(headers, destination);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private Message<String> message(SimpMessageHeaderAccessor headers, String destination) {
|
|
|
|
+ headers.setSessionId("123");
|
|
|
|
+ headers.setSessionAttributes(new HashMap<String, Object>());
|
|
|
|
+ if(destination != null) {
|
|
|
|
+ headers.setDestination(destination);
|
|
|
|
+ }
|
|
|
|
+ if(messageUser != null) {
|
|
|
|
+ headers.setUser(messageUser);
|
|
|
|
+ }
|
|
|
|
+ return new GenericMessage<String>("hi",headers.getMessageHeaders());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private MessageChannel clientInboundChannel() {
|
|
|
|
+ return context.getBean("clientInboundChannel", MessageChannel.class);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ private void loadConfig(Class<?>... configs) {
|
|
|
|
+ context = new AnnotationConfigWebApplicationContext();
|
|
|
|
+ context.register(configs);
|
|
|
|
+ context.setServletConfig(new MockServletConfig());
|
|
|
|
+ context.refresh();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ @Controller
|
|
|
|
+ static class MyController {
|
|
|
|
+
|
|
|
|
+ String authenticationPrincipal;
|
|
|
|
+ MyCustomArgument myCustomArgument;
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ @MessageMapping("/authentication")
|
|
|
|
+ public void authentication(@AuthenticationPrincipal String un) {
|
|
|
|
+ this.authenticationPrincipal = un;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @MessageMapping("/myCustom")
|
|
|
|
+ public void myCustom(MyCustomArgument myCustomArgument) {
|
|
|
|
+ this.myCustomArgument = myCustomArgument;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ static class MyCustomArgument {
|
|
|
|
+ MyCustomArgument(String notDefaultConstr) {}
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ static class MyCustomArgumentResolver implements HandlerMethodArgumentResolver {
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public boolean supportsParameter(MethodParameter parameter) {
|
|
|
|
+ return parameter.getParameterType().isAssignableFrom(MyCustomArgument.class);
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
|
|
|
|
+ return new MyCustomArgument("");
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ static class TestHandshakeHandler implements HandshakeHandler {
|
|
|
|
+ Map<String, Object> attributes;
|
|
|
|
+
|
|
|
|
+ public boolean doHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, Map<String, Object> attributes) throws HandshakeFailureException {
|
|
|
|
+ this.attributes = attributes;
|
|
|
|
+ if(wsHandler instanceof SockJsWebSocketHandler) {
|
|
|
|
+ // work around SPR-12716
|
|
|
|
+ SockJsWebSocketHandler sockJs = (SockJsWebSocketHandler) wsHandler;
|
|
|
|
+ WebSocketServerSockJsSession session = (WebSocketServerSockJsSession) ReflectionTestUtils.getField(sockJs, "sockJsSession");
|
|
|
|
+ this.attributes = session.getAttributes();
|
|
|
|
+ }
|
|
|
|
+ return true;
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Configuration
|
|
|
|
+ @EnableWebSocketMessageBroker
|
|
|
|
+ @Import(SyncExecutorConfig.class)
|
|
|
|
+ static class SockJsSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer {
|
|
|
|
+
|
|
|
|
+ public void registerStompEndpoints(StompEndpointRegistry registry) {
|
|
|
|
+ registry
|
|
|
|
+ .addEndpoint("/other")
|
|
|
|
+ .setHandshakeHandler(testHandshakeHandler())
|
|
|
|
+ .withSockJS()
|
|
|
|
+ .setInterceptors(new HttpSessionHandshakeInterceptor());
|
|
|
|
+
|
|
|
|
+ registry
|
|
|
|
+ .addEndpoint("/chat")
|
|
|
|
+ .setHandshakeHandler(testHandshakeHandler())
|
|
|
|
+ .withSockJS()
|
|
|
|
+ .setInterceptors(new HttpSessionHandshakeInterceptor());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
|
|
|
|
+ messages
|
|
|
|
+ .simpDestMatchers("/permitAll/**").permitAll()
|
|
|
|
+ .anyMessage().denyAll();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Override
|
|
|
|
+ public void configureMessageBroker(MessageBrokerRegistry registry) {
|
|
|
|
+ registry.enableSimpleBroker("/queue/", "/topic/");
|
|
|
|
+ registry.setApplicationDestinationPrefixes("/permitAll", "/denyAll");
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Bean
|
|
|
|
+ public MyController myController() {
|
|
|
|
+ return new MyController();
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Bean
|
|
|
|
+ public TestHandshakeHandler testHandshakeHandler() {
|
|
|
|
+ return new TestHandshakeHandler();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ @Configuration
|
|
|
|
+ static class SyncExecutorConfig {
|
|
|
|
+ @Bean
|
|
|
|
+ public static SyncExecutorSubscribableChannelPostProcessor postProcessor() {
|
|
|
|
+ return new SyncExecutorSubscribableChannelPostProcessor();
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+}
|