Parcourir la source

Add SecurityContextHolderStrategy Test Support

Issue gh-11061
Issue gh-11444
Josh Cummings il y a 3 ans
Parent
commit
f86992a0af

+ 45 - 0
test/src/main/java/org/springframework/security/test/context/TestSecurityContextHolderStrategyAdapter.java

@@ -0,0 +1,45 @@
+/*
+ * Copyright 2002-2022 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.test.context;
+
+import org.springframework.security.core.context.SecurityContext;
+import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
+
+public final class TestSecurityContextHolderStrategyAdapter implements SecurityContextHolderStrategy {
+
+	@Override
+	public void clearContext() {
+		TestSecurityContextHolder.clearContext();
+	}
+
+	@Override
+	public SecurityContext getContext() {
+		return TestSecurityContextHolder.getContext();
+	}
+
+	@Override
+	public void setContext(SecurityContext context) {
+		TestSecurityContextHolder.setContext(context);
+	}
+
+	@Override
+	public SecurityContext createEmptyContext() {
+		return SecurityContextHolder.createEmptyContext();
+	}
+
+}

+ 12 - 2
test/src/main/java/org/springframework/security/test/context/support/WithAnonymousUserSecurityContextFactory.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2014 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,12 +18,14 @@ package org.springframework.security.test.context.support;
 
 import java.util.List;
 
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.security.authentication.AnonymousAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 
 /**
  * A {@link WithAnonymousUserSecurityContextFactory} that runs with an
@@ -35,13 +37,21 @@ import org.springframework.security.core.context.SecurityContextHolder;
  */
 final class WithAnonymousUserSecurityContextFactory implements WithSecurityContextFactory<WithAnonymousUser> {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	@Override
 	public SecurityContext createSecurityContext(WithAnonymousUser withUser) {
 		List<GrantedAuthority> authorities = AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS");
 		Authentication authentication = new AnonymousAuthenticationToken("key", "anonymous", authorities);
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
 		context.setAuthentication(authentication);
 		return context;
 	}
 
+	@Autowired(required = false)
+	void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 }

+ 11 - 1
test/src/main/java/org/springframework/security/test/context/support/WithMockUserSecurityContextFactory.java

@@ -20,12 +20,14 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.GrantedAuthority;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
@@ -39,6 +41,9 @@ import org.springframework.util.StringUtils;
  */
 final class WithMockUserSecurityContextFactory implements WithSecurityContextFactory<WithMockUser> {
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	@Override
 	public SecurityContext createSecurityContext(WithMockUser withUser) {
 		String username = StringUtils.hasLength(withUser.username()) ? withUser.username() : withUser.value();
@@ -60,9 +65,14 @@ final class WithMockUserSecurityContextFactory implements WithSecurityContextFac
 		User principal = new User(username, withUser.password(), true, true, true, true, grantedAuthorities);
 		Authentication authentication = UsernamePasswordAuthenticationToken.authenticated(principal,
 				principal.getPassword(), principal.getAuthorities());
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
 		context.setAuthentication(authentication);
 		return context;
 	}
 
+	@Autowired(required = false)
+	void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 }

+ 21 - 4
test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2018 the original author or authors.
+ * Copyright 2002-2022 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -21,12 +21,16 @@ import java.lang.reflect.AnnotatedElement;
 import java.util.function.Supplier;
 
 import org.springframework.beans.BeanUtils;
+import org.springframework.context.ApplicationContext;
 import org.springframework.core.GenericTypeResolver;
 import org.springframework.core.annotation.AnnotatedElementUtils;
 import org.springframework.core.annotation.AnnotationUtils;
+import org.springframework.core.convert.converter.Converter;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.test.context.TestSecurityContextHolder;
+import org.springframework.security.test.context.TestSecurityContextHolderStrategyAdapter;
 import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors;
 import org.springframework.test.context.TestContext;
 import org.springframework.test.context.TestContextAnnotationUtils;
@@ -53,6 +57,19 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
 	static final String SECURITY_CONTEXT_ATTR_NAME = WithSecurityContextTestExecutionListener.class.getName()
 			.concat(".SECURITY_CONTEXT");
 
+	static final SecurityContextHolderStrategy DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY = new TestSecurityContextHolderStrategyAdapter();
+
+	Converter<TestContext, SecurityContextHolderStrategy> securityContextHolderStrategyConverter = (testContext) -> {
+		if (!testContext.hasApplicationContext()) {
+			return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
+		}
+		ApplicationContext context = testContext.getApplicationContext();
+		if (context.getBeanNamesForType(SecurityContextHolderStrategy.class).length == 0) {
+			return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
+		}
+		return context.getBean(SecurityContextHolderStrategy.class);
+	};
+
 	/**
 	 * Sets up the {@link SecurityContext} for each test method. First the specific method
 	 * is inspected for a {@link WithSecurityContext} or {@link Annotation} that has
@@ -70,7 +87,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
 		}
 		Supplier<SecurityContext> supplier = testSecurityContext.getSecurityContextSupplier();
 		if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) {
-			TestSecurityContextHolder.setContext(supplier.get());
+			this.securityContextHolderStrategyConverter.convert(testContext).setContext(supplier.get());
 		}
 		else {
 			testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier);
@@ -86,7 +103,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
 		Supplier<SecurityContext> supplier = (Supplier<SecurityContext>) testContext
 				.removeAttribute(SECURITY_CONTEXT_ATTR_NAME);
 		if (supplier != null) {
-			TestSecurityContextHolder.setContext(supplier.get());
+			this.securityContextHolderStrategyConverter.convert(testContext).setContext(supplier.get());
 		}
 	}
 
@@ -166,7 +183,7 @@ public class WithSecurityContextTestExecutionListener extends AbstractTestExecut
 	 */
 	@Override
 	public void afterTestMethod(TestContext testContext) {
-		TestSecurityContextHolder.clearContext();
+		this.securityContextHolderStrategyConverter.convert(testContext).clearContext();
 	}
 
 	/**

+ 10 - 1
test/src/main/java/org/springframework/security/test/context/support/WithUserDetailsSecurityContextFactory.java

@@ -24,6 +24,7 @@ import org.springframework.security.authentication.UsernamePasswordAuthenticatio
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.userdetails.ReactiveUserDetailsService;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.core.userdetails.UserDetailsService;
@@ -45,6 +46,9 @@ final class WithUserDetailsSecurityContextFactory implements WithSecurityContext
 	private static final boolean reactorPresent = ClassUtils.isPresent("reactor.core.publisher.Mono",
 			WithUserDetailsSecurityContextFactory.class.getClassLoader());
 
+	private SecurityContextHolderStrategy securityContextHolderStrategy = SecurityContextHolder
+			.getContextHolderStrategy();
+
 	private BeanFactory beans;
 
 	@Autowired
@@ -61,11 +65,16 @@ final class WithUserDetailsSecurityContextFactory implements WithSecurityContext
 		UserDetails principal = userDetailsService.loadUserByUsername(username);
 		Authentication authentication = UsernamePasswordAuthenticationToken.authenticated(principal,
 				principal.getPassword(), principal.getAuthorities());
-		SecurityContext context = SecurityContextHolder.createEmptyContext();
+		SecurityContext context = this.securityContextHolderStrategy.createEmptyContext();
 		context.setAuthentication(authentication);
 		return context;
 	}
 
+	@Autowired(required = false)
+	void setSecurityContextHolderStrategy(SecurityContextHolderStrategy securityContextHolderStrategy) {
+		this.securityContextHolderStrategy = securityContextHolderStrategy;
+	}
+
 	private UserDetailsService findUserDetailsService(String beanName) {
 		if (reactorPresent) {
 			UserDetailsService reactive = findAndAdaptReactiveUserDetailsService(beanName);

+ 23 - 7
test/src/main/java/org/springframework/security/test/web/servlet/request/SecurityMockMvcRequestPostProcessors.java

@@ -55,6 +55,7 @@ import org.springframework.security.core.authority.AuthorityUtils;
 import org.springframework.security.core.authority.SimpleGrantedAuthority;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.core.context.SecurityContextHolderStrategy;
 import org.springframework.security.core.userdetails.User;
 import org.springframework.security.core.userdetails.UserDetails;
 import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
@@ -85,6 +86,7 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
 import org.springframework.security.oauth2.server.resource.authentication.JwtGrantedAuthoritiesConverter;
 import org.springframework.security.oauth2.server.resource.introspection.OAuth2IntrospectionAuthenticatedPrincipal;
 import org.springframework.security.test.context.TestSecurityContextHolder;
+import org.springframework.security.test.context.TestSecurityContextHolderStrategyAdapter;
 import org.springframework.security.test.web.servlet.setup.SecurityMockMvcConfigurers;
 import org.springframework.security.test.web.support.WebTestUtils;
 import org.springframework.security.web.context.HttpRequestResponseHolder;
@@ -115,6 +117,8 @@ import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandl
  */
 public final class SecurityMockMvcRequestPostProcessors {
 
+	private static final SecurityContextHolderStrategy DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY = new TestSecurityContextHolderStrategyAdapter();
+
 	private SecurityMockMvcRequestPostProcessors() {
 	}
 
@@ -455,6 +459,18 @@ public final class SecurityMockMvcRequestPostProcessors {
 		return new OAuth2ClientRequestPostProcessor(registrationId);
 	}
 
+	private static SecurityContextHolderStrategy getSecurityContextHolderStrategy(HttpServletRequest request) {
+		WebApplicationContext context = WebApplicationContextUtils
+				.findWebApplicationContext(request.getServletContext());
+		if (context == null) {
+			return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
+		}
+		if (context.getBeanNamesForType(SecurityContextHolderStrategy.class).length == 0) {
+			return DEFAULT_SECURITY_CONTEXT_HOLDER_STRATEGY;
+		}
+		return context.getBean(SecurityContextHolderStrategy.class);
+	}
+
 	/**
 	 * Populates the X509Certificate instances onto the request
 	 */
@@ -710,7 +726,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 		 * @param request the {@link HttpServletRequest} to use
 		 */
 		final void save(Authentication authentication, HttpServletRequest request) {
-			SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
+			SecurityContext securityContext = getSecurityContextHolderStrategy(request).createEmptyContext();
 			securityContext.setAuthentication(authentication);
 			save(securityContext, request);
 		}
@@ -790,8 +806,6 @@ public final class SecurityMockMvcRequestPostProcessors {
 	private static final class TestSecurityContextHolderPostProcessor extends SecurityContextRequestPostProcessorSupport
 			implements RequestPostProcessor {
 
-		private SecurityContext EMPTY = SecurityContextHolder.createEmptyContext();
-
 		@Override
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
 			// TestSecurityContextHolder is only a default value
@@ -799,8 +813,10 @@ public final class SecurityMockMvcRequestPostProcessors {
 			if (existingContext != null) {
 				return request;
 			}
-			SecurityContext context = TestSecurityContextHolder.getContext();
-			if (!this.EMPTY.equals(context)) {
+			SecurityContextHolderStrategy strategy = getSecurityContextHolderStrategy(request);
+			SecurityContext empty = strategy.createEmptyContext();
+			SecurityContext context = strategy.getContext();
+			if (!empty.equals(context)) {
 				save(context, request);
 			}
 			return request;
@@ -851,7 +867,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 
 		@Override
 		public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request) {
-			SecurityContext context = SecurityContextHolder.createEmptyContext();
+			SecurityContext context = getSecurityContextHolderStrategy(request).createEmptyContext();
 			context.setAuthentication(this.authentication);
 			save(this.authentication, request);
 			return request;
@@ -869,7 +885,7 @@ public final class SecurityMockMvcRequestPostProcessors {
 	 */
 	private static final class UserDetailsRequestPostProcessor implements RequestPostProcessor {
 
-		private final RequestPostProcessor delegate;
+		private final AuthenticationRequestPostProcessor delegate;
 
 		UserDetailsRequestPostProcessor(UserDetails user) {
 			Authentication token = UsernamePasswordAuthenticationToken.authenticated(user, user.getPassword(),