Pārlūkot izejas kodu

SEC-2190: Support WebApplicationContext in ServletContext

Rob Winch 9 gadi atpakaļ
vecāks
revīzija
d467146e49

+ 5 - 4
taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java

@@ -42,6 +42,7 @@ import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.FilterInvocation;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
+import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 
@@ -312,8 +313,7 @@ public abstract class AbstractAuthorizeTag {
 
     @SuppressWarnings({ "unchecked", "rawtypes" })
     private SecurityExpressionHandler<FilterInvocation> getExpressionHandler() throws IOException {
-        ApplicationContext appContext = WebApplicationContextUtils
-                .getRequiredWebApplicationContext(getServletContext());
+        ApplicationContext appContext = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext());
         Map<String, SecurityExpressionHandler> handlers = appContext
                 .getBeansOfType(SecurityExpressionHandler.class);
 
@@ -335,8 +335,9 @@ public abstract class AbstractAuthorizeTag {
             return privEvaluatorFromRequest;
         }
 
-        ApplicationContext ctx = WebApplicationContextUtils.getRequiredWebApplicationContext(getServletContext());
-        Map<String, WebInvocationPrivilegeEvaluator> wipes = ctx.getBeansOfType(WebInvocationPrivilegeEvaluator.class);
+        ApplicationContext ctx = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext());
+        Map<String, WebInvocationPrivilegeEvaluator> wipes = ctx
+                .getBeansOfType(WebInvocationPrivilegeEvaluator.class);
 
         if (wipes.size() == 0) {
             throw new IOException(

+ 2 - 1
taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java

@@ -21,6 +21,7 @@ import org.springframework.security.access.PermissionEvaluator;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.taglibs.TagLibConfig;
+import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 
 import javax.servlet.ServletContext;
@@ -136,7 +137,7 @@ public class AccessControlListTag extends TagSupport {
     protected ApplicationContext getContext(PageContext pageContext) {
         ServletContext servletContext = pageContext.getServletContext();
 
-        return WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext);
+        return SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(servletContext);
     }
 
     public Object getDomainObject() {

+ 37 - 1
taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java

@@ -12,12 +12,16 @@
  */
 package org.springframework.security.taglibs.authz;
 
+import static org.fest.assertions.Assertions.*;
+
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 import java.io.IOException;
+import java.util.Collections;
 
 import javax.servlet.ServletContext;
 import javax.servlet.ServletRequest;
@@ -29,10 +33,14 @@ import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockServletContext;
+import org.springframework.security.access.expression.SecurityExpressionHandler;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
+import org.springframework.security.web.access.expression.DefaultWebSecurityExpressionHandler;
+import org.springframework.web.context.WebApplicationContext;
 
 /**
  *
@@ -63,13 +71,41 @@ public class AbstractAuthorizeTagTests {
         String uri = "/something";
         WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class);
         tag.setUrl(uri);
-        request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, expected);
+        request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE,
+                expected);
 
         tag.authorizeUsingUrlCheck();
 
         verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any(Authentication.class));
     }
 
+    @Test
+    public void privilegeEvaluatorFromChildContext() throws IOException {
+        String uri = "/something";
+        WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class);
+        tag.setUrl(uri);
+        WebApplicationContext wac = mock(WebApplicationContext.class);
+        when(wac.getBeansOfType(WebInvocationPrivilegeEvaluator.class)).thenReturn(Collections.singletonMap("wipe", expected));
+        servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
+
+        tag.authorizeUsingUrlCheck();
+
+        verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any(Authentication.class));
+    }
+
+    @Test
+    @SuppressWarnings("rawtypes")
+    public void expressionFromChildContext() throws IOException {
+        SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass","USER"));
+        DefaultWebSecurityExpressionHandler expected = new DefaultWebSecurityExpressionHandler();
+        tag.setAccess("permitAll");
+        WebApplicationContext wac = mock(WebApplicationContext.class);
+        when(wac.getBeansOfType(SecurityExpressionHandler.class)).thenReturn(Collections.<String,SecurityExpressionHandler>singletonMap("wipe", expected));
+        servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
+
+        assertThat(tag.authorize()).isTrue();
+    }
+
     private class AuthzTag extends AbstractAuthorizeTag {
 
         @Override

+ 32 - 8
taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java

@@ -26,6 +26,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.web.context.WebApplicationContext;
 
+import javax.servlet.ServletContext;
 import javax.servlet.jsp.tagext.Tag;
 import java.util.*;
 
@@ -40,7 +41,7 @@ public class AccessControlListTagTests {
     AccessControlListTag tag;
     PermissionEvaluator pe;
     MockPageContext pageContext;
-    Authentication bob = new TestingAuthenticationToken("bob","bobspass","A");
+    Authentication bob = new TestingAuthenticationToken("bob", "bobspass", "A");
 
     @Before
     @SuppressWarnings("rawtypes")
@@ -56,8 +57,10 @@ public class AccessControlListTagTests {
         when(ctx.getBeansOfType(PermissionEvaluator.class)).thenReturn(beanMap);
 
         MockServletContext servletCtx = new MockServletContext();
-        servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx);
-        pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), new MockHttpServletResponse());
+        servletCtx.setAttribute(
+                WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx);
+        pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(),
+                new MockHttpServletResponse());
         tag.setPageContext(pageContext);
     }
 
@@ -78,7 +81,28 @@ public class AccessControlListTagTests {
         assertEquals("READ", tag.getHasPermission());
 
         assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag());
-        assertTrue((Boolean)pageContext.getAttribute("allowed"));
+        assertTrue((Boolean) pageContext.getAttribute("allowed"));
+    }
+
+    @Test
+    public void childContext() throws Exception {
+        ServletContext servletContext = pageContext.getServletContext();
+        WebApplicationContext wac = (WebApplicationContext) servletContext
+                .getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE);
+        servletContext.removeAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE);
+        servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
+
+        Object domainObject = new Object();
+        when(pe.hasPermission(bob, domainObject, "READ")).thenReturn(true);
+
+        tag.setDomainObject(domainObject);
+        tag.setHasPermission("READ");
+        tag.setVar("allowed");
+        assertSame(domainObject, tag.getDomainObject());
+        assertEquals("READ", tag.getHasPermission());
+
+        assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag());
+        assertTrue((Boolean) pageContext.getAttribute("allowed"));
     }
 
     // SEC-2022
@@ -95,7 +119,7 @@ public class AccessControlListTagTests {
         assertEquals("READ,WRITE", tag.getHasPermission());
 
         assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag());
-        assertTrue((Boolean)pageContext.getAttribute("allowed"));
+        assertTrue((Boolean) pageContext.getAttribute("allowed"));
         verify(pe).hasPermission(bob, domainObject, "READ");
         verify(pe).hasPermission(bob, domainObject, "WRITE");
         verifyNoMoreInteractions(pe);
@@ -115,7 +139,7 @@ public class AccessControlListTagTests {
         assertEquals("1,2", tag.getHasPermission());
 
         assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag());
-        assertTrue((Boolean)pageContext.getAttribute("allowed"));
+        assertTrue((Boolean) pageContext.getAttribute("allowed"));
         verify(pe).hasPermission(bob, domainObject, 1);
         verify(pe).hasPermission(bob, domainObject, 2);
         verifyNoMoreInteractions(pe);
@@ -134,7 +158,7 @@ public class AccessControlListTagTests {
         assertEquals("1,WRITE", tag.getHasPermission());
 
         assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag());
-        assertTrue((Boolean)pageContext.getAttribute("allowed"));
+        assertTrue((Boolean) pageContext.getAttribute("allowed"));
         verify(pe).hasPermission(bob, domainObject, 1);
         verify(pe).hasPermission(bob, domainObject, "WRITE");
         verifyNoMoreInteractions(pe);
@@ -150,6 +174,6 @@ public class AccessControlListTagTests {
         tag.setVar("allowed");
 
         assertEquals(Tag.SKIP_BODY, tag.doStartTag());
-        assertFalse((Boolean)pageContext.getAttribute("allowed"));
+        assertFalse((Boolean) pageContext.getAttribute("allowed"));
     }
 }

+ 88 - 0
web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java

@@ -0,0 +1,88 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.context.support;
+
+import java.util.Enumeration;
+
+import javax.servlet.ServletContext;
+
+import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.context.support.WebApplicationContextUtils;
+
+/**
+ * Spring Security extension to Spring's {@link WebApplicationContextUtils}.
+ *
+ * @author Rob Winch
+ */
+public abstract class SecurityWebApplicationContextUtils extends WebApplicationContextUtils {
+
+    /**
+     * Find a unique {@code WebApplicationContext} for this web app: either the
+     * root web app context (preferred) or a unique {@code WebApplicationContext}
+     * among the registered {@code ServletContext} attributes (typically coming
+     * from a single {@code DispatcherServlet} in the current web application).
+     * <p>Note that {@code DispatcherServlet}'s exposure of its context can be
+     * controlled through its {@code publishContext} property, which is {@code true}
+     * by default but can be selectively switched to only publish a single context
+     * despite multiple {@code DispatcherServlet} registrations in the web app.
+     * @param sc ServletContext to find the web application context for
+     * @return the desired WebApplicationContext for this web app
+     * @see #getWebApplicationContext(ServletContext)
+     * @see ServletContext#getAttributeNames()
+     * @throws IllegalStateException if no WebApplicationContext can be found
+     */
+    public static WebApplicationContext findRequiredWebApplicationContext(ServletContext servletContext) {
+        WebApplicationContext wac = findWebApplicationContext(servletContext);
+        if (wac == null) {
+            throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?");
+        }
+        return wac;
+    }
+
+    /**
+     * Find a unique {@code WebApplicationContext} for this web app: either the
+     * root web app context (preferred) or a unique {@code WebApplicationContext}
+     * among the registered {@code ServletContext} attributes (typically coming
+     * from a single {@code DispatcherServlet} in the current web application).
+     * <p>Note that {@code DispatcherServlet}'s exposure of its context can be
+     * controlled through its {@code publishContext} property, which is {@code true}
+     * by default but can be selectively switched to only publish a single context
+     * despite multiple {@code DispatcherServlet} registrations in the web app.
+     * @param sc ServletContext to find the web application context for
+     * @return the desired WebApplicationContext for this web app, or {@code null} if none
+     * @see #getWebApplicationContext(ServletContext)
+     * @see ServletContext#getAttributeNames()
+     */
+    private static WebApplicationContext findWebApplicationContext(ServletContext sc) {
+        WebApplicationContext wac = getWebApplicationContext(sc);
+        if (wac == null) {
+            Enumeration<String> attrNames = sc.getAttributeNames();
+            while (attrNames.hasMoreElements()) {
+                String attrName = attrNames.nextElement();
+                Object attrValue = sc.getAttribute(attrName);
+                if (attrValue instanceof WebApplicationContext) {
+                    if (wac != null) {
+                        throw new IllegalStateException("No unique WebApplicationContext found: more than one " +
+                                "DispatcherServlet registered with publishContext=true?");
+                    }
+                    wac = (WebApplicationContext) attrValue;
+                }
+            }
+        }
+        return wac;
+    }
+
+}

+ 2 - 1
web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java

@@ -19,6 +19,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
 import org.springframework.context.ApplicationContext;
+import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 
 import org.springframework.web.context.support.WebApplicationContextUtils;
 
@@ -49,7 +50,7 @@ public class HttpSessionEventPublisher implements HttpSessionListener {
     //~ Methods ========================================================================================================
 
     ApplicationContext getContext(ServletContext servletContext) {
-        return WebApplicationContextUtils.getRequiredWebApplicationContext(servletContext);
+        return SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(servletContext);
     }
 
     /**

+ 39 - 1
web/src/test/java/org/springframework/security/web/session/HttpSessionEventPublisherTests.java

@@ -44,7 +44,45 @@ public class HttpSessionEventPublisherTests {
         StaticWebApplicationContext context = new StaticWebApplicationContext();
 
         MockServletContext servletContext = new MockServletContext();
-        servletContext.setAttribute(StaticWebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, context);
+        servletContext.setAttribute(
+                StaticWebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE,
+                context);
+
+        context.setServletContext(servletContext);
+        context.registerSingleton("listener", MockApplicationListener.class, null);
+        context.refresh();
+
+        MockHttpSession session = new MockHttpSession(servletContext);
+        MockApplicationListener listener = (MockApplicationListener) context
+                .getBean("listener");
+
+        HttpSessionEvent event = new HttpSessionEvent(session);
+
+        publisher.sessionCreated(event);
+
+        assertNotNull(listener.getCreatedEvent());
+        assertNull(listener.getDestroyedEvent());
+        assertEquals(session, listener.getCreatedEvent().getSession());
+
+        listener.setCreatedEvent(null);
+        listener.setDestroyedEvent(null);
+
+        publisher.sessionDestroyed(event);
+        assertNotNull(listener.getDestroyedEvent());
+        assertNull(listener.getCreatedEvent());
+        assertEquals(session, listener.getDestroyedEvent().getSession());
+    }
+
+    @Test
+    public void publishedEventIsReceivedbyListenerChildContext() {
+        HttpSessionEventPublisher publisher = new HttpSessionEventPublisher();
+
+        StaticWebApplicationContext context = new StaticWebApplicationContext();
+
+        MockServletContext servletContext = new MockServletContext();
+        servletContext.setAttribute(
+                "org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher",
+                context);
 
         context.setServletContext(servletContext);
         context.registerSingleton("listener", MockApplicationListener.class, null);