Browse Source

Update DefaultWebInvocationPrivilegeEvaluator to use current ServletContext

Closes gh-10208
Marcus Da Coregio 4 years ago
parent
commit
89db1c37a3

+ 24 - 2
web/src/main/java/org/springframework/security/web/FilterInvocation.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -29,6 +29,7 @@ import java.util.LinkedHashMap;
 import java.util.Map;
 
 import javax.servlet.FilterChain;
+import javax.servlet.ServletContext;
 import javax.servlet.ServletRequest;
 import javax.servlet.ServletResponse;
 import javax.servlet.http.HttpServletRequest;
@@ -78,10 +79,19 @@ public class FilterInvocation {
 	}
 
 	public FilterInvocation(String contextPath, String servletPath, String method) {
-		this(contextPath, servletPath, null, null, method);
+		this(contextPath, servletPath, method, null);
+	}
+
+	public FilterInvocation(String contextPath, String servletPath, String method, ServletContext servletContext) {
+		this(contextPath, servletPath, null, null, method, servletContext);
 	}
 
 	public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method) {
+		this(contextPath, servletPath, pathInfo, query, method, null);
+	}
+
+	public FilterInvocation(String contextPath, String servletPath, String pathInfo, String query, String method,
+			ServletContext servletContext) {
 		DummyRequest request = new DummyRequest();
 		contextPath = (contextPath != null) ? contextPath : "/cp";
 		request.setContextPath(contextPath);
@@ -90,6 +100,7 @@ public class FilterInvocation {
 		request.setPathInfo(pathInfo);
 		request.setQueryString(query);
 		request.setMethod(method);
+		request.setServletContext(servletContext);
 		this.request = request;
 	}
 
@@ -160,6 +171,8 @@ public class FilterInvocation {
 
 		private String method;
 
+		private ServletContext servletContext;
+
 		private final HttpHeaders headers = new HttpHeaders();
 
 		private final Map<String, String[]> parameters = new LinkedHashMap<>();
@@ -290,6 +303,15 @@ public class FilterInvocation {
 			this.parameters.put(name, values);
 		}
 
+		@Override
+		public ServletContext getServletContext() {
+			return this.servletContext;
+		}
+
+		void setServletContext(ServletContext servletContext) {
+			this.servletContext = servletContext;
+		}
+
 	}
 
 	static final class UnsupportedOperationExceptionInvocationHandler implements InvocationHandler {

+ 13 - 3
web/src/main/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluator.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ package org.springframework.security.web.access;
 
 import java.util.Collection;
 
+import javax.servlet.ServletContext;
+
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 
@@ -28,6 +30,7 @@ import org.springframework.security.access.intercept.AbstractSecurityInterceptor
 import org.springframework.security.core.Authentication;
 import org.springframework.security.web.FilterInvocation;
 import org.springframework.util.Assert;
+import org.springframework.web.context.ServletContextAware;
 
 /**
  * Allows users to determine whether they have privileges for a given web URI.
@@ -36,12 +39,14 @@ import org.springframework.util.Assert;
  * @author Luke Taylor
  * @since 3.0
  */
-public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator {
+public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPrivilegeEvaluator, ServletContextAware {
 
 	protected static final Log logger = LogFactory.getLog(DefaultWebInvocationPrivilegeEvaluator.class);
 
 	private final AbstractSecurityInterceptor securityInterceptor;
 
+	private ServletContext servletContext;
+
 	public DefaultWebInvocationPrivilegeEvaluator(AbstractSecurityInterceptor securityInterceptor) {
 		Assert.notNull(securityInterceptor, "SecurityInterceptor cannot be null");
 		Assert.isTrue(FilterInvocation.class.equals(securityInterceptor.getSecureObjectClass()),
@@ -82,7 +87,7 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv
 	@Override
 	public boolean isAllowed(String contextPath, String uri, String method, Authentication authentication) {
 		Assert.notNull(uri, "uri parameter is required");
-		FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method);
+		FilterInvocation filterInvocation = new FilterInvocation(contextPath, uri, method, this.servletContext);
 		Collection<ConfigAttribute> attributes = this.securityInterceptor.obtainSecurityMetadataSource()
 				.getAttributes(filterInvocation);
 		if (attributes == null) {
@@ -101,4 +106,9 @@ public class DefaultWebInvocationPrivilegeEvaluator implements WebInvocationPriv
 		}
 	}
 
+	@Override
+	public void setServletContext(ServletContext servletContext) {
+		this.servletContext = servletContext;
+	}
+
 }

+ 12 - 1
web/src/test/java/org/springframework/security/web/FilterInvocationTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -24,6 +24,7 @@ import org.junit.Test;
 
 import org.springframework.mock.web.MockHttpServletRequest;
 import org.springframework.mock.web.MockHttpServletResponse;
+import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.web.FilterInvocation.DummyRequest;
 import org.springframework.security.web.util.UrlUtils;
 
@@ -131,4 +132,14 @@ public class FilterInvocationTests {
 		UrlUtils.buildRequestUrl(request);
 	}
 
+	@Test
+	public void constructorWhenServletContextProvidedThenSetServletContextInRequest() {
+		String contextPath = "";
+		String servletPath = "/path";
+		String method = "";
+		MockServletContext mockServletContext = new MockServletContext();
+		FilterInvocation filterInvocation = new FilterInvocation(contextPath, servletPath, method, mockServletContext);
+		assertThat(filterInvocation.getRequest().getServletContext()).isSameAs(mockServletContext);
+	}
+
 }

+ 19 - 1
web/src/test/java/org/springframework/security/web/access/DefaultWebInvocationPrivilegeEvaluatorTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2004, 2005, 2006 Acegi Technology Pty Limited
+ * Copyright 2002-2021 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,8 +18,10 @@ package org.springframework.security.web.access;
 
 import org.junit.Before;
 import org.junit.Test;
+import org.mockito.ArgumentCaptor;
 
 import org.springframework.context.ApplicationEventPublisher;
+import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.access.AccessDecisionManager;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.access.intercept.RunAsManager;
@@ -27,6 +29,7 @@ import org.springframework.security.authentication.AuthenticationManager;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.web.FilterInvocation;
 import org.springframework.security.web.access.intercept.FilterInvocationSecurityMetadataSource;
 import org.springframework.security.web.access.intercept.FilterSecurityInterceptor;
 
@@ -34,9 +37,11 @@ import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyList;
 import static org.mockito.ArgumentMatchers.anyObject;
+import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
 import static org.mockito.BDDMockito.willThrow;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
 /**
  * Tests
@@ -106,4 +111,17 @@ public class DefaultWebInvocationPrivilegeEvaluatorTests {
 		assertThat(wipe.isAllowed("/foo/index.jsp", token)).isFalse();
 	}
 
+	@Test
+	public void isAllowedWhenServletContextIsSetThenPassedFilterInvocationHasServletContext() {
+		Authentication token = new TestingAuthenticationToken("test", "Password", "MOCK_INDEX");
+		MockServletContext servletContext = new MockServletContext();
+		ArgumentCaptor<FilterInvocation> filterInvocationArgumentCaptor = ArgumentCaptor
+				.forClass(FilterInvocation.class);
+		DefaultWebInvocationPrivilegeEvaluator wipe = new DefaultWebInvocationPrivilegeEvaluator(this.interceptor);
+		wipe.setServletContext(servletContext);
+		wipe.isAllowed("/foo/index.jsp", token);
+		verify(this.adm).decide(eq(token), filterInvocationArgumentCaptor.capture(), any());
+		assertThat(filterInvocationArgumentCaptor.getValue().getRequest().getServletContext()).isNotNull();
+	}
+
 }