浏览代码

SEC-2864: Default Spring Security WebSocket PathMatcher

Rob Winch 10 年之前
父节点
当前提交
57b06fb0b5

+ 52 - 2
config/src/main/java/org/springframework/security/config/annotation/web/messaging/MessageSecurityMetadataSourceRegistry.java

@@ -47,7 +47,9 @@ public class MessageSecurityMetadataSourceRegistry {
 
 	private final LinkedHashMap<MatcherBuilder, String> matcherToExpression = new LinkedHashMap<MatcherBuilder, String>();
 
-	private PathMatcher pathMatcher = new AntPathMatcher();
+	private DelegatingPathMatcher pathMatcher = new DelegatingPathMatcher();
+
+	private boolean defaultPathMatcher = true;
 
 	/**
 	 * Maps any {@link Message} to a security expression.
@@ -169,10 +171,20 @@ public class MessageSecurityMetadataSourceRegistry {
 	public MessageSecurityMetadataSourceRegistry simpDestPathMatcher(
 			PathMatcher pathMatcher) {
 		Assert.notNull(pathMatcher, "pathMatcher cannot be null");
-		this.pathMatcher = pathMatcher;
+		this.pathMatcher.setPathMatcher(pathMatcher);
+		this.defaultPathMatcher = false;
 		return this;
 	}
 
+	/**
+	 * Determines if the {@link #simpDestPathMatcher(PathMatcher)} has been explicitly set.
+	 *
+	 * @return true if {@link #simpDestPathMatcher(PathMatcher)} has been explicitly set, else false.
+	 */
+	protected boolean isSimpDestPathMatcherConfigured() {
+		return !this.defaultPathMatcher;
+	}
+
 	/**
 	 * Maps a {@link List} of {@link MessageMatcher} instances to a security expression.
 	 *
@@ -439,4 +451,42 @@ public class MessageSecurityMetadataSourceRegistry {
 	private interface MatcherBuilder {
 		MessageMatcher<?> build();
 	}
+
+
+	static class DelegatingPathMatcher implements PathMatcher {
+
+		private PathMatcher delegate = new AntPathMatcher();
+
+		public boolean isPattern(String path) {
+			return delegate.isPattern(path);
+		}
+
+		public boolean match(String pattern, String path) {
+			return delegate.match(pattern, path);
+		}
+
+		public boolean matchStart(String pattern, String path) {
+			return delegate.matchStart(pattern, path);
+		}
+
+		public String extractPathWithinPattern(String pattern, String path) {
+			return delegate.extractPathWithinPattern(pattern, path);
+		}
+
+		public Map<String, String> extractUriTemplateVariables(String pattern, String path) {
+			return delegate.extractUriTemplateVariables(pattern, path);
+		}
+
+		public Comparator<String> getPatternComparator(String path) {
+			return delegate.getPatternComparator(path);
+		}
+
+		public String combine(String pattern1, String pattern2) {
+			return delegate.combine(pattern1, pattern2);
+		}
+
+		void setPathMatcher(PathMatcher pathMatcher) {
+			this.delegate = pathMatcher;
+		}
+	}
 }

+ 27 - 5
config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java

@@ -15,6 +15,11 @@
  */
 package org.springframework.security.config.annotation.web.socket;
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import org.springframework.beans.factory.NoSuchBeanDefinitionException;
 import org.springframework.beans.factory.SmartInitializingSingleton;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.context.ApplicationContext;
@@ -22,6 +27,7 @@ import org.springframework.context.annotation.Bean;
 import org.springframework.core.Ordered;
 import org.springframework.core.annotation.Order;
 import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
+import org.springframework.messaging.simp.annotation.support.SimpAnnotationMethodMessageHandler;
 import org.springframework.messaging.simp.config.ChannelRegistration;
 import org.springframework.security.access.AccessDecisionVoter;
 import org.springframework.security.access.vote.AffirmativeBased;
@@ -33,6 +39,8 @@ import org.springframework.security.messaging.context.AuthenticationPrincipalArg
 import org.springframework.security.messaging.context.SecurityContextChannelInterceptor;
 import org.springframework.security.messaging.web.csrf.CsrfChannelInterceptor;
 import org.springframework.security.messaging.web.socket.server.CsrfTokenHandshakeInterceptor;
+import org.springframework.util.AntPathMatcher;
+import org.springframework.util.PathMatcher;
 import org.springframework.web.servlet.handler.SimpleUrlHandlerMapping;
 import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer;
 import org.springframework.web.socket.config.annotation.StompEndpointRegistry;
@@ -42,10 +50,6 @@ import org.springframework.web.socket.sockjs.SockJsService;
 import org.springframework.web.socket.sockjs.support.SockJsHttpRequestHandler;
 import org.springframework.web.socket.sockjs.transport.TransportHandlingSockJsService;
 
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
 /**
  * Allows configuring WebSocket Authorization.
  *
@@ -57,7 +61,7 @@ import java.util.Map;
  * &#064;Configuration
  * public class WebSocketSecurityConfig extends
  * 		AbstractSecurityWebSocketMessageBrokerConfigurer {
- * 
+ *
  * 	&#064;Override
  * 	protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
  * 		messages.simpDestMatchers(&quot;/user/queue/errors&quot;).permitAll()
@@ -99,6 +103,14 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends
 		customizeClientInboundChannel(registration);
 	}
 
+	private PathMatcher getDefaultPathMatcher() {
+		try {
+			return context.getBean(SimpAnnotationMethodMessageHandler.class).getPathMatcher();
+		} catch(NoSuchBeanDefinitionException e) {
+			return new AntPathMatcher();
+		}
+	}
+
 	/**
 	 * <p>
 	 * Determines if a CSRF token is required for connecting. This protects against remote
@@ -169,6 +181,11 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends
 		protected boolean containsMapping() {
 			return super.containsMapping();
 		}
+
+		@Override
+		protected boolean isSimpDestPathMatcherConfigured() {
+			return super.isSimpDestPathMatcherConfigured();
+		}
 	}
 
 	@Autowired
@@ -225,5 +242,10 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends
 								+ object);
 			}
 		}
+
+		if (inboundRegistry.containsMapping() && !inboundRegistry.isSimpDestPathMatcherConfigured()) {
+			PathMatcher pathMatcher = getDefaultPathMatcher();
+			inboundRegistry.simpDestPathMatcher(pathMatcher);
+		}
 	}
 }

+ 169 - 4
config/src/test/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurerTests.java

@@ -1,5 +1,4 @@
 /*
- * Copyright 2002-2015 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License"); you may not
  * use this file except in compliance with the License. You may obtain a copy of
@@ -17,7 +16,6 @@ package org.springframework.security.config.annotation.web.socket;
 
 import org.junit.After;
 import org.junit.Before;
-
 import org.junit.Test;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
@@ -46,6 +44,7 @@ import org.springframework.security.web.csrf.DefaultCsrfToken;
 import org.springframework.security.web.csrf.MissingCsrfTokenException;
 import org.springframework.stereotype.Controller;
 import org.springframework.test.util.ReflectionTestUtils;
+import org.springframework.util.AntPathMatcher;
 import org.springframework.web.HttpRequestHandler;
 import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
 import org.springframework.web.servlet.HandlerMapping;
@@ -59,6 +58,7 @@ import org.springframework.web.socket.sockjs.transport.handler.SockJsWebSocketHa
 import org.springframework.web.socket.sockjs.transport.session.WebSocketServerSockJsSession;
 
 import javax.servlet.http.HttpServletRequest;
+
 import java.util.HashMap;
 import java.util.Map;
 
@@ -232,6 +232,163 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
 		assertHandshake(request);
 	}
 
+	@Test
+	public void msmsRegistryCustomPatternMatcher()
+			throws Exception {
+		loadConfig(MsmsRegistryCustomPatternMatcherConfig.class);
+
+		clientInboundChannel().send(message("/app/a.b"));
+
+		try {
+			clientInboundChannel().send(message("/app/a.b.c"));
+			fail("Expected Exception");
+		}
+		catch (MessageDeliveryException expected) {
+			assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
+		}
+	}
+
+	@Configuration
+	@EnableWebSocketMessageBroker
+	@Import(SyncExecutorConfig.class)
+	static class MsmsRegistryCustomPatternMatcherConfig extends
+			AbstractSecurityWebSocketMessageBrokerConfigurer {
+
+		// @formatter:off
+		public void registerStompEndpoints(StompEndpointRegistry registry) {
+			registry
+				.addEndpoint("/other")
+				.setHandshakeHandler(testHandshakeHandler());
+		}
+		// @formatter:on
+
+		// @formatter:off
+		@Override
+		protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
+			messages
+				.simpDestMatchers("/app/a.*").permitAll()
+				.anyMessage().denyAll();
+		}
+		// @formatter:on
+
+		@Override
+		public void configureMessageBroker(MessageBrokerRegistry registry) {
+			registry.setPathMatcher(new AntPathMatcher("."));
+			registry.enableSimpleBroker("/queue/", "/topic/");
+			registry.setApplicationDestinationPrefixes("/app");
+		}
+
+		@Bean
+		public TestHandshakeHandler testHandshakeHandler() {
+			return new TestHandshakeHandler();
+		}
+	}
+
+	@Test
+	public void overrideMsmsRegistryCustomPatternMatcher()
+			throws Exception {
+		loadConfig(OverrideMsmsRegistryCustomPatternMatcherConfig.class);
+
+		clientInboundChannel().send(message("/app/a/b"));
+
+		try {
+			clientInboundChannel().send(message("/app/a/b/c"));
+			fail("Expected Exception");
+		}
+		catch (MessageDeliveryException expected) {
+			assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
+		}
+	}
+
+	@Configuration
+	@EnableWebSocketMessageBroker
+	@Import(SyncExecutorConfig.class)
+	static class OverrideMsmsRegistryCustomPatternMatcherConfig extends
+			AbstractSecurityWebSocketMessageBrokerConfigurer {
+
+		// @formatter:off
+		public void registerStompEndpoints(StompEndpointRegistry registry) {
+			registry
+				.addEndpoint("/other")
+				.setHandshakeHandler(testHandshakeHandler());
+		}
+		// @formatter:on
+
+
+		// @formatter:off
+		@Override
+		protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
+			messages
+				.simpDestPathMatcher(new AntPathMatcher())
+				.simpDestMatchers("/app/a/*").permitAll()
+				.anyMessage().denyAll();
+		}
+		// @formatter:on
+
+		@Override
+		public void configureMessageBroker(MessageBrokerRegistry registry) {
+			registry.setPathMatcher(new AntPathMatcher("."));
+			registry.enableSimpleBroker("/queue/", "/topic/");
+			registry.setApplicationDestinationPrefixes("/app");
+		}
+
+		@Bean
+		public TestHandshakeHandler testHandshakeHandler() {
+			return new TestHandshakeHandler();
+		}
+	}
+
+	@Test
+	public void defaultPatternMatcher()
+			throws Exception {
+		loadConfig(DefaultPatternMatcherConfig.class);
+
+		clientInboundChannel().send(message("/app/a/b"));
+
+		try {
+			clientInboundChannel().send(message("/app/a/b/c"));
+			fail("Expected Exception");
+		}
+		catch (MessageDeliveryException expected) {
+			assertThat(expected.getCause()).isInstanceOf(AccessDeniedException.class);
+		}
+	}
+
+	@Configuration
+	@EnableWebSocketMessageBroker
+	@Import(SyncExecutorConfig.class)
+	static class DefaultPatternMatcherConfig extends
+			AbstractSecurityWebSocketMessageBrokerConfigurer {
+
+		// @formatter:off
+		public void registerStompEndpoints(StompEndpointRegistry registry) {
+			registry
+				.addEndpoint("/other")
+				.setHandshakeHandler(testHandshakeHandler());
+		}
+		// @formatter:on
+
+		// @formatter:off
+		@Override
+		protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
+			messages
+				.simpDestMatchers("/app/a/*").permitAll()
+				.anyMessage().denyAll();
+		}
+		// @formatter:on
+
+		@Override
+		public void configureMessageBroker(MessageBrokerRegistry registry) {
+			registry.enableSimpleBroker("/queue/", "/topic/");
+			registry.setApplicationDestinationPrefixes("/app");
+		}
+
+		@Bean
+		public TestHandshakeHandler testHandshakeHandler() {
+			return new TestHandshakeHandler();
+		}
+	}
+
 	private void assertHandshake(HttpServletRequest request) {
 		TestHandshakeHandler handshakeHandler = context
 				.getBean(TestHandshakeHandler.class);
@@ -358,10 +515,14 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
 					.withSockJS().setInterceptors(new HttpSessionHandshakeInterceptor());
 		}
 
+		// @formatter:off
 		@Override
 		protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
-			messages.simpDestMatchers("/permitAll/**").permitAll().anyMessage().denyAll();
+			messages
+				.simpDestMatchers("/permitAll/**").permitAll()
+				.anyMessage().denyAll();
 		}
+		// @formatter:on
 
 		@Override
 		public void configureMessageBroker(MessageBrokerRegistry registry) {
@@ -431,10 +592,14 @@ public class AbstractSecurityWebSocketMessageBrokerConfigurerTests {
 					.addInterceptors(new HttpSessionHandshakeInterceptor());
 		}
 
+		// @formatter:off
 		@Override
 		protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {
-			messages.simpDestMatchers("/permitAll/**").permitAll().anyMessage().denyAll();
+			messages
+				.simpDestMatchers("/permitAll/**").permitAll()
+				.anyMessage().denyAll();
 		}
+		// @formatter:on
 
 		@Bean
 		public TestHandshakeHandler testHandshakeHandler() {