ソースを参照

Use SecurityContextHolderStrategy for Taglibs

Issue gh-11060
Josh Cummings 3 年 前
コミット
237a31c69b

+ 17 - 4
taglibs/src/main/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTag.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2004-2010 the original author or authors.
+ * Copyright 2004-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,7 +32,9 @@ import org.springframework.expression.ParseException;
 import org.springframework.security.access.expression.ExpressionUtils;
 import org.springframework.security.access.expression.SecurityExpressionHandler;
 import org.springframework.security.core.Authentication;
+import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.web.FilterInvocation;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
@@ -110,7 +112,7 @@ public abstract class AbstractAuthorizeTag {
 	 * @throws IOException
 	 */
 	public boolean authorizeUsingAccessExpression() throws IOException {
-		if (SecurityContextHolder.getContext().getAuthentication() == null) {
+		if (getContext().getAuthentication() == null) {
 			return false;
 		}
 		SecurityExpressionHandler<FilterInvocation> handler = getExpressionHandler();
@@ -131,7 +133,7 @@ public abstract class AbstractAuthorizeTag {
 		FilterInvocation f = new FilterInvocation(getRequest(), getResponse(), (request, response) -> {
 			throw new UnsupportedOperationException();
 		});
-		return handler.createEvaluationContext(SecurityContextHolder.getContext().getAuthentication(), f);
+		return handler.createEvaluationContext(getContext().getAuthentication(), f);
 	}
 
 	/**
@@ -142,7 +144,7 @@ public abstract class AbstractAuthorizeTag {
 	 */
 	public boolean authorizeUsingUrlCheck() throws IOException {
 		String contextPath = ((HttpServletRequest) getRequest()).getContextPath();
-		Authentication currentUser = SecurityContextHolder.getContext().getAuthentication();
+		Authentication currentUser = getContext().getAuthentication();
 		return getPrivilegeEvaluator().isAllowed(contextPath, getUrl(), getMethod(), currentUser);
 	}
 
@@ -170,6 +172,17 @@ public abstract class AbstractAuthorizeTag {
 		this.method = (method != null) ? method.toUpperCase() : null;
 	}
 
+	private SecurityContext getContext() {
+		ApplicationContext appContext = SecurityWebApplicationContextUtils
+				.findRequiredWebApplicationContext(getServletContext());
+		String[] names = appContext.getBeanNamesForType(SecurityContextHolderStrategy.class);
+		if (names.length == 1) {
+			SecurityContextHolderStrategy strategy = appContext.getBean(SecurityContextHolderStrategy.class);
+			return strategy.getContext();
+		}
+		return SecurityContextHolder.getContext();
+	}
+
 	@SuppressWarnings({ "unchecked", "rawtypes" })
 	private SecurityExpressionHandler<FilterInvocation> getExpressionHandler() throws IOException {
 		ApplicationContext appContext = SecurityWebApplicationContextUtils

+ 11 - 1
taglibs/src/main/java/org/springframework/security/taglibs/authz/AccessControlListTag.java

@@ -33,6 +33,7 @@ import org.springframework.context.ApplicationContext;
 import org.springframework.security.access.PermissionEvaluator;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.taglibs.TagLibConfig;
 import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 
@@ -57,6 +58,9 @@ public class AccessControlListTag extends TagSupport {
 
 	protected static final Log logger = LogFactory.getLog(AccessControlListTag.class);
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private ApplicationContext applicationContext;
 
 	private Object domainObject;
@@ -78,7 +82,7 @@ public class AccessControlListTag extends TagSupport {
 			// Of course they have access to a null object!
 			return evalBody();
 		}
-		Authentication authentication = SecurityContextHolder.getContext().getAuthentication();
+		Authentication authentication = this.securityContextHolderStrategy.getContext().getAuthentication();
 		if (authentication == null) {
 			logger.debug("SecurityContextHolder did not return a non-null Authentication object, so skipping tag body");
 			return skipBody();
@@ -146,6 +150,12 @@ public class AccessControlListTag extends TagSupport {
 		}
 		this.applicationContext = getContext(this.pageContext);
 		this.permissionEvaluator = getBeanOfType(PermissionEvaluator.class);
+		String[] names = this.applicationContext.getBeanNamesForType(SecurityContextHolderStrategy.class);
+		if (names.length == 1) {
+			SecurityContextHolderStrategy strategy = this.applicationContext
+					.getBean(SecurityContextHolderStrategy.class);
+			this.securityContextHolderStrategy = strategy;
+		}
 	}
 
 	private <T> T getBeanOfType(Class<T> type) throws JspException {

+ 22 - 4
taglibs/src/main/java/org/springframework/security/taglibs/authz/AuthenticationTag.java

@@ -18,6 +18,7 @@ package org.springframework.security.taglibs.authz;
 
 import java.io.IOException;
 
+import javax.servlet.ServletContext;
 import javax.servlet.jsp.JspException;
 import javax.servlet.jsp.PageContext;
 import javax.servlet.jsp.tagext.Tag;
@@ -25,9 +26,12 @@ import javax.servlet.jsp.tagext.TagSupport;
 
 import org.springframework.beans.BeanWrapperImpl;
 import org.springframework.beans.BeansException;
+import org.springframework.context.ApplicationContext;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.web.context.support.SecurityWebApplicationContextUtils;
 import org.springframework.security.web.util.TextEscapeUtils;
 import org.springframework.web.util.TagUtils;
 
@@ -42,6 +46,9 @@ import org.springframework.web.util.TagUtils;
  */
 public class AuthenticationTag extends TagSupport {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private String var;
 
 	private String property;
@@ -76,6 +83,18 @@ public class AuthenticationTag extends TagSupport {
 		this.scopeSpecified = true;
 	}
 
+	public void setPageContext(PageContext pageContext) {
+		super.setPageContext(pageContext);
+		ServletContext servletContext = pageContext.getServletContext();
+		ApplicationContext context = SecurityWebApplicationContextUtils
+				.findRequiredWebApplicationContext(servletContext);
+		String[] names = context.getBeanNamesForType(SecurityContextHolderStrategy.class);
+		if (names.length == 1) {
+			SecurityContextHolderStrategy strategy = context.getBean(SecurityContextHolderStrategy.class);
+			this.securityContextHolderStrategy = strategy;
+		}
+	}
+
 	@Override
 	public int doStartTag() throws JspException {
 		return super.doStartTag();
@@ -86,12 +105,11 @@ public class AuthenticationTag extends TagSupport {
 		Object result = null;
 		// determine the value by...
 		if (this.property != null) {
-			if ((SecurityContextHolder.getContext() == null)
-					|| !(SecurityContextHolder.getContext() instanceof SecurityContext)
-					|| (SecurityContextHolder.getContext().getAuthentication() == null)) {
+			SecurityContext context = this.securityContextHolderStrategy.getContext();
+			if ((context == null) || !(context instanceof SecurityContext) || (context.getAuthentication() == null)) {
 				return Tag.EVAL_PAGE;
 			}
-			Authentication auth = SecurityContextHolder.getContext().getAuthentication();
+			Authentication auth = context.getAuthentication();
 			if (auth.getPrincipal() == null) {
 				return Tag.EVAL_PAGE;
 			}

+ 28 - 1
taglibs/src/test/java/org/springframework/security/taglibs/authz/AbstractAuthorizeTagTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,11 +32,15 @@ import org.springframework.mock.web.MockHttpServletResponse;
 import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.access.expression.SecurityExpressionHandler;
 import org.springframework.security.authentication.TestingAuthenticationToken;
+import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.web.WebAttributes;
 import org.springframework.security.web.access.WebInvocationPrivilegeEvaluator;
 import org.springframework.security.web.access.expression.DefaultWebSecurityExpressionHandler;
 import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.context.support.GenericWebApplicationContext;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.ArgumentMatchers.any;
@@ -74,12 +78,33 @@ public class AbstractAuthorizeTagTests {
 
 	@Test
 	public void privilegeEvaluatorFromRequest() throws IOException {
+		WebApplicationContext wac = mock(WebApplicationContext.class);
+		this.servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac);
+		given(wac.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]);
+		String uri = "/something";
+		WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class);
+		this.tag.setUrl(uri);
+		this.request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, expected);
+		this.tag.authorizeUsingUrlCheck();
+		verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any());
+	}
+
+	@Test
+	public void privilegeEvaluatorFromRequestUsesSecurityContextHolderStrategy() throws IOException {
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.getContext()).willReturn(new SecurityContextImpl(
+				new TestingAuthenticationToken("user", "password", AuthorityUtils.NO_AUTHORITIES)));
+		GenericWebApplicationContext wac = new GenericWebApplicationContext();
+		wac.registerBean(SecurityContextHolderStrategy.class, () -> strategy);
+		wac.refresh();
+		this.servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, wac);
 		String uri = "/something";
 		WebInvocationPrivilegeEvaluator expected = mock(WebInvocationPrivilegeEvaluator.class);
 		this.tag.setUrl(uri);
 		this.request.setAttribute(WebAttributes.WEB_INVOCATION_PRIVILEGE_EVALUATOR_ATTRIBUTE, expected);
 		this.tag.authorizeUsingUrlCheck();
 		verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any());
+		verify(strategy).getContext();
 	}
 
 	@Test
@@ -90,6 +115,7 @@ public class AbstractAuthorizeTagTests {
 		WebApplicationContext wac = mock(WebApplicationContext.class);
 		given(wac.getBeansOfType(WebInvocationPrivilegeEvaluator.class))
 				.willReturn(Collections.singletonMap("wipe", expected));
+		given(wac.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]);
 		this.servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
 		this.tag.authorizeUsingUrlCheck();
 		verify(expected).isAllowed(eq(""), eq(uri), eq("GET"), any());
@@ -104,6 +130,7 @@ public class AbstractAuthorizeTagTests {
 		WebApplicationContext wac = mock(WebApplicationContext.class);
 		given(wac.getBeansOfType(SecurityExpressionHandler.class))
 				.willReturn(Collections.<String, SecurityExpressionHandler>singletonMap("wipe", expected));
+		given(wac.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]);
 		this.servletContext.setAttribute("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher", wac);
 		assertThat(this.tag.authorize()).isTrue();
 	}

+ 29 - 1
taglibs/src/test/java/org/springframework/security/taglibs/authz/AccessControlListTagTests.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2016 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -34,7 +34,10 @@ import org.springframework.security.access.PermissionEvaluator;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.context.support.GenericWebApplicationContext;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.mockito.BDDMockito.given;
@@ -68,6 +71,7 @@ public class AccessControlListTagTests {
 		Map beanMap = new HashMap();
 		beanMap.put("pe", this.pe);
 		given(ctx.getBeansOfType(PermissionEvaluator.class)).willReturn(beanMap);
+		given(ctx.getBeanNamesForType(SecurityContextHolderStrategy.class)).willReturn(new String[0]);
 		MockServletContext servletCtx = new MockServletContext();
 		servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, ctx);
 		this.pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), new MockHttpServletResponse());
@@ -92,6 +96,30 @@ public class AccessControlListTagTests {
 		assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue();
 	}
 
+	@Test
+	public void securityContextHolderStrategyIsUsedIfConfigured() throws Exception {
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.getContext()).willReturn(new SecurityContextImpl(this.bob));
+		GenericWebApplicationContext context = new GenericWebApplicationContext();
+		context.registerBean(SecurityContextHolderStrategy.class, () -> strategy);
+		context.registerBean(PermissionEvaluator.class, () -> this.pe);
+		context.refresh();
+		MockServletContext servletCtx = new MockServletContext();
+		servletCtx.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, context);
+		this.pageContext = new MockPageContext(servletCtx, new MockHttpServletRequest(), new MockHttpServletResponse());
+		this.tag.setPageContext(this.pageContext);
+		Object domainObject = new Object();
+		given(this.pe.hasPermission(this.bob, domainObject, "READ")).willReturn(true);
+		this.tag.setDomainObject(domainObject);
+		this.tag.setHasPermission("READ");
+		this.tag.setVar("allowed");
+		assertThat(this.tag.getDomainObject()).isSameAs(domainObject);
+		assertThat(this.tag.getHasPermission()).isEqualTo("READ");
+		assertThat(this.tag.doStartTag()).isEqualTo(Tag.EVAL_BODY_INCLUDE);
+		assertThat((Boolean) this.pageContext.getAttribute("allowed")).isTrue();
+		verify(strategy).getContext();
+	}
+
 	@Test
 	public void childContext() throws Exception {
 		ServletContext servletContext = this.pageContext.getServletContext();

+ 27 - 0
taglibs/src/test/java/org/springframework/security/taglibs/authz/AuthenticationTagTests.java

@@ -22,14 +22,23 @@ import javax.servlet.jsp.tagext.Tag;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.Test;
 
+import org.springframework.mock.web.MockPageContext;
+import org.springframework.mock.web.MockServletContext;
 import org.springframework.security.authentication.TestingAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+import org.springframework.security.core.context.SecurityContextImpl;
 import org.springframework.security.core.userdetails.User;
+import org.springframework.web.context.WebApplicationContext;
+import org.springframework.web.context.support.GenericWebApplicationContext;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
 
 /**
  * Tests {@link AuthenticationTag}.
@@ -131,6 +140,24 @@ public class AuthenticationTagTests {
 		assertThat(this.authenticationTag.getLastMessage()).isEqualTo("<>& ");
 	}
 
+	@Test
+	public void setSecurityContextHolderStrategyThenUses() throws Exception {
+		SecurityContextHolderStrategy strategy = mock(SecurityContextHolderStrategy.class);
+		given(strategy.getContext()).willReturn(new SecurityContextImpl(
+				new TestingAuthenticationToken("rodAsString", "koala", AuthorityUtils.NO_AUTHORITIES)));
+		MockServletContext servletContext = new MockServletContext();
+		GenericWebApplicationContext applicationContext = new GenericWebApplicationContext();
+		applicationContext.registerBean(SecurityContextHolderStrategy.class, () -> strategy);
+		applicationContext.refresh();
+		servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, applicationContext);
+		this.authenticationTag.setPageContext(new MockPageContext(servletContext));
+		this.authenticationTag.setProperty("principal");
+		assertThat(this.authenticationTag.doStartTag()).isEqualTo(Tag.SKIP_BODY);
+		assertThat(this.authenticationTag.doEndTag()).isEqualTo(Tag.EVAL_PAGE);
+		assertThat(this.authenticationTag.getLastMessage()).isEqualTo("rodAsString");
+		verify(strategy).getContext();
+	}
+
 	private class MyAuthenticationTag extends AuthenticationTag {
 
 		String lastMessage = null;