瀏覽代碼

SEC-2703: ChannelSecurityInterceptor use ThreadLocal for InterceptorStatusToken

Rob Winch 11 年之前
父節點
當前提交
b6fcde880a

+ 11 - 32
messaging/src/main/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptor.java

@@ -17,7 +17,6 @@ package org.springframework.security.messaging.access.intercept;
 
 import org.springframework.messaging.Message;
 import org.springframework.messaging.MessageChannel;
-import org.springframework.messaging.MessageHeaders;
 import org.springframework.messaging.support.ChannelInterceptor;
 import org.springframework.security.access.SecurityMetadataSource;
 import org.springframework.security.access.intercept.AbstractSecurityInterceptor;
@@ -39,6 +38,7 @@ import org.springframework.util.Assert;
  * @author Rob Winch
  */
 public final class ChannelSecurityInterceptor extends AbstractSecurityInterceptor implements ChannelInterceptor {
+    private static final ThreadLocal<InterceptorStatusToken> tokenHolder = new ThreadLocal<InterceptorStatusToken>();
 
     private final MessageSecurityMetadataSource metadataSource;
 
@@ -67,24 +67,19 @@ public final class ChannelSecurityInterceptor extends AbstractSecurityIntercepto
 
     public Message<?> preSend(Message<?> message, MessageChannel channel) {
         InterceptorStatusToken token = beforeInvocation(message);
-        return token == null ? message : new TokenMessage(message,token);
+        if(token != null) {
+            tokenHolder.set(token);
+        }
+        return message;
     }
 
     public void postSend(Message<?> message, MessageChannel channel, boolean sent) {
-        if(!(message instanceof TokenMessage)) {
-            // TODO What if other classes return another instance too?
-            return;
-        }
-        InterceptorStatusToken token = ((TokenMessage)message).getToken();
+        InterceptorStatusToken token = clearToken();
         afterInvocation(token, null);
     }
 
     public void afterSendCompletion(Message<?> message, MessageChannel channel, boolean sent, Exception ex) {
-        if(!(message instanceof TokenMessage)) {
-            // TODO What if other classes return another instance too?
-            return;
-        }
-        InterceptorStatusToken token = ((TokenMessage)message).getToken();
+        InterceptorStatusToken token = clearToken();
         finallyInvocation(token);
     }
 
@@ -99,25 +94,9 @@ public final class ChannelSecurityInterceptor extends AbstractSecurityIntercepto
     public void afterReceiveCompletion(Message<?> message, MessageChannel channel, Exception ex) {
     }
 
-    static final class TokenMessage implements Message {
-        private final Message delegate;
-        private final InterceptorStatusToken token;
-
-        TokenMessage(Message delegate, InterceptorStatusToken token) {
-            this.delegate = delegate;
-            this.token = token;
-        }
-
-        public InterceptorStatusToken getToken() {
-            return token;
-        }
-
-        public MessageHeaders getHeaders() {
-            return delegate.getHeaders();
-        }
-
-        public Object getPayload() {
-            return delegate.getPayload();
-        }
+    private InterceptorStatusToken clearToken() {
+        InterceptorStatusToken token = tokenHolder.get();
+        tokenHolder.remove();
+        return token;
     }
 }

+ 32 - 21
messaging/src/test/java/org/springframework/security/messaging/access/intercept/ChannelSecurityInterceptorTests.java

@@ -28,12 +28,14 @@ import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.access.ConfigAttribute;
 import org.springframework.security.access.SecurityConfig;
 import org.springframework.security.access.intercept.InterceptorStatusToken;
+import org.springframework.security.access.intercept.RunAsManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 
 import static org.fest.assertions.Assertions.assertThat;
@@ -45,13 +47,20 @@ import static org.mockito.Mockito.when;
 @RunWith(MockitoJUnitRunner.class)
 public class ChannelSecurityInterceptorTests {
     @Mock
-    Message message;
+    Message<Object> message;
     @Mock
     MessageChannel channel;
     @Mock
     MessageSecurityMetadataSource source;
     @Mock
     AccessDecisionManager accessDecisionManager;
+    @Mock
+    RunAsManager runAsManager;
+    @Mock
+    Authentication runAs;
+
+    Authentication originalAuth;
+
     List<ConfigAttribute> attrs;
 
     ChannelSecurityInterceptor interceptor;
@@ -61,8 +70,10 @@ public class ChannelSecurityInterceptorTests {
         attrs = Arrays.<ConfigAttribute>asList(new SecurityConfig("ROLE_USER"));
         interceptor = new ChannelSecurityInterceptor(source);
         interceptor.setAccessDecisionManager(accessDecisionManager);
+        interceptor.setRunAsManager(runAsManager);
 
-        SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass", "ROLE_USER"));
+        originalAuth = new TestingAuthenticationToken("user", "pass", "ROLE_USER");
+        SecurityContextHolder.getContext().setAuthentication(originalAuth);
     }
 
     @After
@@ -96,11 +107,7 @@ public class ChannelSecurityInterceptorTests {
 
         Message<?> result = interceptor.preSend(message, channel);
 
-        assertThat(result).isInstanceOf(ChannelSecurityInterceptor.TokenMessage.class);
-        ChannelSecurityInterceptor.TokenMessage tm = (ChannelSecurityInterceptor.TokenMessage) result;
-        assertThat(tm.getHeaders()).isSameAs(message.getHeaders());
-        assertThat(tm.getPayload()).isSameAs(message.getPayload());
-        assertThat(tm.getToken()).isNotNull();
+        assertThat(result).isSameAs(message);
     }
 
     @Test(expected = AccessDeniedException.class)
@@ -111,19 +118,19 @@ public class ChannelSecurityInterceptorTests {
         interceptor.preSend(message, channel);
     }
 
-    @Test
-    public void postSendNotTokenMessageNoExceptionThrown() throws Exception {
-        interceptor.postSend(message, channel, true);
-    }
 
     @Test
-    public void postSendTokenMessage() throws Exception {
-        InterceptorStatusToken token = new InterceptorStatusToken(SecurityContextHolder.createEmptyContext(),true,attrs,message);
-        ChannelSecurityInterceptor.TokenMessage tokenMessage = new ChannelSecurityInterceptor.TokenMessage(message, token);
+    public void preSendPostSendRunAs() throws Exception {
+        when(source.getAttributes(message)).thenReturn(attrs);
+        when(runAsManager.buildRunAs(any(Authentication.class), any(), any(Collection.class))).thenReturn(runAs);
+
+        Message<?> preSend = interceptor.preSend(message,channel);
+
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(runAs);
 
-        interceptor.postSend(tokenMessage, channel, true);
+        interceptor.postSend(preSend, channel, true);
 
-        assertThat(SecurityContextHolder.getContext()).isSameAs(token.getSecurityContext());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(originalAuth);
     }
 
     @Test
@@ -132,13 +139,17 @@ public class ChannelSecurityInterceptorTests {
     }
 
     @Test
-    public void afterSendCompletionTokenMessage() throws Exception {
-        InterceptorStatusToken token = new InterceptorStatusToken(SecurityContextHolder.createEmptyContext(),true,attrs,message);
-        ChannelSecurityInterceptor.TokenMessage tokenMessage = new ChannelSecurityInterceptor.TokenMessage(message, token);
+    public void preSendFinallySendRunAs() throws Exception {
+        when(source.getAttributes(message)).thenReturn(attrs);
+        when(runAsManager.buildRunAs(any(Authentication.class), any(), any(Collection.class))).thenReturn(runAs);
+
+        Message<?> preSend = interceptor.preSend(message,channel);
+
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(runAs);
 
-        interceptor.afterSendCompletion(tokenMessage, channel, true, null);
+        interceptor.afterSendCompletion(preSend, channel, true, new RuntimeException());
 
-        assertThat(SecurityContextHolder.getContext()).isSameAs(token.getSecurityContext());
+        assertThat(SecurityContextHolder.getContext().getAuthentication()).isSameAs(originalAuth);
     }
 
     @Test