2
0
Эх сурвалжийг харах

SEC-2190: Support WebApplicationContext in ServletContext attribute

Rob Winch 9 жил өмнө
parent
commit
2bbe70501b

+ 3 - 4
taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java

@@ -37,6 +37,7 @@ import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.FilterInvocation;
 import org.springframework.security.web.FilterInvocation;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
+import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.util.StringUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 
 
@@ -201,8 +202,7 @@ public abstract class AbstractAuthorizeTag {
 	@SuppressWarnings({ "unchecked", "rawtypes" })
 	@SuppressWarnings({ "unchecked", "rawtypes" })
 	private SecurityExpressionHandler<FilterInvocation> getExpressionHandler()
 	private SecurityExpressionHandler<FilterInvocation> getExpressionHandler()
 			throws IOException {
 			throws IOException {
-		ApplicationContext appContext = WebApplicationContextUtils
-				.getRequiredWebApplicationContext(getServletContext());
+		ApplicationContext appContext = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext());
 		Map<String, SecurityExpressionHandler> handlers = appContext
 		Map<String, SecurityExpressionHandler> handlers = appContext
 				.getBeansOfType(SecurityExpressionHandler.class);
 				.getBeansOfType(SecurityExpressionHandler.class);
 
 
@@ -225,8 +225,7 @@ public abstract class AbstractAuthorizeTag {
 			return privEvaluatorFromRequest;
 			return privEvaluatorFromRequest;
 		}
 		}
 
 
-		ApplicationContext ctx = WebApplicationContextUtils
-				.getRequiredWebApplicationContext(getServletContext());
+		ApplicationContext ctx = SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(getServletContext());
 		Map<String, WebInvocationPrivilegeEvaluator> wipes = ctx
 		Map<String, WebInvocationPrivilegeEvaluator> wipes = ctx
 				.getBeansOfType(WebInvocationPrivilegeEvaluator.class);
 				.getBeansOfType(WebInvocationPrivilegeEvaluator.class);
 
 

+ 12 - 9
taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java

@@ -14,6 +14,16 @@
  */
  */
 package org.springframework.security.taglibs.authz;
 package org.springframework.security.taglibs.authz;
 
 
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+
+import javax.servlet.ServletContext;
+import javax.servlet.jsp.JspException;
+import javax.servlet.jsp.PageContext;
+import javax.servlet.jsp.tagext.Tag;
+import javax.servlet.jsp.tagext.TagSupport;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.commons.logging.LogFactory;
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.ApplicationContext;
@@ -21,15 +31,9 @@ import org.springframework.security.access.PermissionEvaluator;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.taglibs.TagLibConfig;
 import org.springframework.security.taglibs.TagLibConfig;
+import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 
 
-import javax.servlet.ServletContext;
-import javax.servlet.jsp.JspException;
-import javax.servlet.jsp.PageContext;
-import javax.servlet.jsp.tagext.Tag;
-import javax.servlet.jsp.tagext.TagSupport;
-import java.util.*;
-
 /**
 /**
  * An implementation of {@link Tag} that allows its body through if all authorizations are
  * An implementation of {@link Tag} that allows its body through if all authorizations are
  * granted to the request's principal.
  * granted to the request's principal.
@@ -142,8 +146,7 @@ public class AccessControlListTag extends TagSupport {
 	protected ApplicationContext getContext(PageContext pageContext) {
 	protected ApplicationContext getContext(PageContext pageContext) {
 		ServletContext servletContext = pageContext.getServletContext();
 		ServletContext servletContext = pageContext.getServletContext();
 
 
-		return WebApplicationContextUtils
-				.getRequiredWebApplicationContext(servletContext);
+		return SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(servletContext);
 	}
 	}
 
 
 	public Object getDomainObject() {
 	public Object getDomainObject() {

+ 35 - 0
taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java

@@ -12,12 +12,16 @@
  */
  */
 package org.springframework.security.taglibs.authz;
 package org.springframework.security.taglibs.authz;
 
 
+import static org.fest.assertions.Assertions.*;
+
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Matchers.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
 
 
 import java.io.IOException;
 import java.io.IOException;
+import java.util.Collections;
 
 
 import javax.servlet.ServletContext;
 import javax.servlet.ServletContext;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletRequest;
@@ -29,10 +33,14 @@ import org.junit.Test;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockServletContext;
 import org.springframework.mock.web.MockServletContext;
+import org.springframework.security.access.expression.SecurityExpressionHandler;
+import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
+import org.springframework.security.web.access.expression.DefaultWebSecurityExpressionHandler;
+import org.springframework.web.context.WebApplicationContext;
 
 
 /**
 /**
  *
  *
@@ -71,6 +79,33 @@ public class AbstractAuthorizeTagTests {
 		verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any(Authentication.class));
 		verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any(Authentication.class));
 	}
 	}
 
 
+	@Test
+	public void privilegeEvaluatorFromChildContext() throws IOException {
+		String uri = "/something";
+		WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class);
+		tag.setUrl(uri);
+		WebApplicationContext wac = mock(WebApplicationContext.class);
+		when(wac.getBeansOfType(WebInvocationPrivilegeEvaluator.class)).thenReturn(Collections.singletonMap("wipe", expected));
+		servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
+
+		tag.authorizeUsingUrlCheck();
+
+		verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any(Authentication.class));
+	}
+
+	@Test
+	@SuppressWarnings("rawtypes")
+	public void expressionFromChildContext() throws IOException {
+		SecurityContextHolder.getContext().setAuthentication(new TestingAuthenticationToken("user", "pass","USER"));
+		DefaultWebSecurityExpressionHandler expected = new DefaultWebSecurityExpressionHandler();
+		tag.setAccess("permitAll");
+		WebApplicationContext wac = mock(WebApplicationContext.class);
+		when(wac.getBeansOfType(SecurityExpressionHandler.class)).thenReturn(Collections.<String,SecurityExpressionHandler>singletonMap("wipe", expected));
+		servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
+
+		assertThat(tag.authorize()).isTrue();
+	}
+
 	private class AuthzTag extends AbstractAuthorizeTag {
 	private class AuthzTag extends AbstractAuthorizeTag {
 
 
 		@Override
 		@Override

+ 22 - 0
taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java

@@ -26,6 +26,7 @@ import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.web.context.WebApplicationContext;
 import org.springframework.web.context.WebApplicationContext;
 
 
+import javax.servlet.ServletContext;
 import javax.servlet.jsp.tagext.Tag;
 import javax.servlet.jsp.tagext.Tag;
 import java.util.*;
 import java.util.*;
 
 
@@ -83,6 +84,27 @@ public class AccessControlListTagTests {
 		assertTrue((Boolean) pageContext.getAttribute("allowed"));
 		assertTrue((Boolean) pageContext.getAttribute("allowed"));
 	}
 	}
 
 
+	@Test
+	public void childContext() throws Exception {
+		ServletContext servletContext = pageContext.getServletContext();
+		WebApplicationContext wac = (WebApplicationContext) servletContext
+				.getAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE);
+		servletContext.removeAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE);
+		servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
+
+		Object domainObject = new Object();
+		when(pe.hasPermission(bob, domainObject, "READ")).thenReturn(true);
+
+		tag.setDomainObject(domainObject);
+		tag.setHasPermission("READ");
+		tag.setVar("allowed");
+		assertSame(domainObject, tag.getDomainObject());
+		assertEquals("READ", tag.getHasPermission());
+
+		assertEquals(Tag.EVAL_BODY_INCLUDE, tag.doStartTag());
+		assertTrue((Boolean) pageContext.getAttribute("allowed"));
+	}
+
 	// SEC-2022
 	// SEC-2022
 	@Test
 	@Test
 	public void multiHasPermissionsAreSplit() throws Exception {
 	public void multiHasPermissionsAreSplit() throws Exception {

+ 52 - 0
web/src/main/java/org/springframework/security/web/context/support/SecurityWebApplicationContextUtils.java

@@ -0,0 +1,52 @@
+/*
+ * Copyright 2002-2015 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.security.web.context.support;
+
+import javax.servlet.ServletContext;
+
+import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.context.support.WebApplicationContextUtils;
+
+/**
+ * Spring Security extension to Spring's {@link WebApplicationContextUtils}.
+ *
+ * @author Rob Winch
+ */
+public abstract class SecurityWebApplicationContextUtils extends WebApplicationContextUtils {
+
+	/**
+	 * Find a unique {@code WebApplicationContext} for this web app: either the
+	 * root web app context (preferred) or a unique {@code WebApplicationContext}
+	 * among the registered {@code ServletContext} attributes (typically coming
+	 * from a single {@code DispatcherServlet} in the current web application).
+	 * <p>Note that {@code DispatcherServlet}'s exposure of its context can be
+	 * controlled through its {@code publishContext} property, which is {@code true}
+	 * by default but can be selectively switched to only publish a single context
+	 * despite multiple {@code DispatcherServlet} registrations in the web app.
+	 * @param sc ServletContext to find the web application context for
+	 * @return the desired WebApplicationContext for this web app
+	 * @see #getWebApplicationContext(ServletContext)
+	 * @see ServletContext#getAttributeNames()
+	 * @throws IllegalStateException if no WebApplicationContext can be found
+	 */
+	public static WebApplicationContext findRequiredWebApplicationContext(ServletContext servletContext) {
+		WebApplicationContext wac = findWebApplicationContext(servletContext);
+		if (wac == null) {
+			throw new IllegalStateException("No WebApplicationContext found: no ContextLoaderListener registered?");
+		}
+		return wac;
+	}
+}

+ 3 - 4
web/src/main/java/org/springframework/security/web/session/HttpSessionEventPublisher.java

@@ -19,7 +19,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.commons.logging.LogFactory;
 
 
 import org.springframework.context.ApplicationContext;
 import org.springframework.context.ApplicationContext;
-
+import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 import org.springframework.web.context.support.WebApplicationContextUtils;
 
 
 import javax.servlet.ServletContext;
 import javax.servlet.ServletContext;
@@ -28,7 +28,7 @@ import javax.servlet.http.HttpSessionListener;
 
 
 /**
 /**
  * Declared in web.xml as
  * Declared in web.xml as
- * 
+ *
  * <pre>
  * <pre>
  * &lt;listener&gt;
  * &lt;listener&gt;
  *     &lt;listener-class&gt;org.springframework.security.web.session.HttpSessionEventPublisher&lt;/listener-class&gt;
  *     &lt;listener-class&gt;org.springframework.security.web.session.HttpSessionEventPublisher&lt;/listener-class&gt;
@@ -53,8 +53,7 @@ public class HttpSessionEventPublisher implements HttpSessionListener {
 	// ========================================================================================================
 	// ========================================================================================================
 
 
 	ApplicationContext getContext(ServletContext servletContext) {
 	ApplicationContext getContext(ServletContext servletContext) {
-		return WebApplicationContextUtils
-				.getRequiredWebApplicationContext(servletContext);
+		return SecurityWebApplicationContextUtils.findRequiredWebApplicationContext(servletContext);
 	}
 	}
 
 
 	/**
 	/**

+ 36 - 0
web/src/test/java/org/springframework/security/web/session/HttpSessionEventPublisherTests.java

@@ -73,6 +73,42 @@ public class HttpSessionEventPublisherTests {
 		assertEquals(session, listener.getDestroyedEvent().getSession());
 		assertEquals(session, listener.getDestroyedEvent().getSession());
 	}
 	}
 
 
+	@Test
+	public void publishedEventIsReceivedbyListenerChildContext() {
+		HttpSessionEventPublisher publisher = new HttpSessionEventPublisher();
+
+		StaticWebApplicationContext context = new StaticWebApplicationContext();
+
+		MockServletContext servletContext = new MockServletContext();
+		servletContext.setAttribute(
+				"org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher",
+				context);
+
+		context.setServletContext(servletContext);
+		context.registerSingleton("listener", MockApplicationListener.class, null);
+		context.refresh();
+
+		MockHttpSession session = new MockHttpSession(servletContext);
+		MockApplicationListener listener = (MockApplicationListener) context
+				.getBean("listener");
+
+		HttpSessionEvent event = new HttpSessionEvent(session);
+
+		publisher.sessionCreated(event);
+
+		assertNotNull(listener.getCreatedEvent());
+		assertNull(listener.getDestroyedEvent());
+		assertEquals(session, listener.getCreatedEvent().getSession());
+
+		listener.setCreatedEvent(null);
+		listener.setDestroyedEvent(null);
+
+		publisher.sessionDestroyed(event);
+		assertNotNull(listener.getDestroyedEvent());
+		assertNull(listener.getCreatedEvent());
+		assertEquals(session, listener.getDestroyedEvent().getSession());
+	}
+
 	// SEC-2599
 	// SEC-2599
 	@Test(expected = IllegalStateException.class)
 	@Test(expected = IllegalStateException.class)
 	public void sessionCreatedNullApplicationContext() {
 	public void sessionCreatedNullApplicationContext() {