Parcourir la source

Fix SecurityContext creation for TEST_EXECUTION

Currently, there is support for setting up a SecurityContext after @Before by
using TestExecutionEvent.TEST_EXECUTION. The current implementation, however,
already creates the SecurityContext in @Before and just does not set it yet.
This leads to issues like #6591. For the case of @WithUserDetails, the
creation of the SecurityContext already looks up a user from the repository.
If the user was inserted in @Before, the user is not found despite using
TestExecutionEvent.TEST_EXECUTION. This commit changes the creation of the
SecurityContext to happen after @Before if using
TestExecutionEvent.TEST_EXECUTION.

Closes gh-6591
Markus Gabriel il y a 5 ans
Parent
commit
97ee6d66f1

+ 24 - 18
test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java

@@ -17,6 +17,7 @@ package org.springframework.security.test.context.support;
 
 import java.lang.annotation.Annotation;
 import java.lang.reflect.AnnotatedElement;
+import java.util.function.Supplier;
 
 import org.springframework.beans.BeanUtils;
 import org.springframework.core.GenericTypeResolver;
@@ -69,11 +70,12 @@ public class WithSecurityContextTestExecutionListener
 			return;
 		}
 
-		SecurityContext securityContext = testSecurityContext.securityContext;
+		Supplier<SecurityContext> supplier = testSecurityContext
+				.getSecurityContextSupplier();
 		if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) {
-			TestSecurityContextHolder.setContext(securityContext);
+			TestSecurityContextHolder.setContext(supplier.get());
 		} else {
-			testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, securityContext);
+			testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier);
 		}
 	}
 
@@ -83,9 +85,10 @@ public class WithSecurityContextTestExecutionListener
 	 */
 	@Override
 	public void beforeTestExecution(TestContext testContext) {
-		SecurityContext securityContext = (SecurityContext) testContext.removeAttribute(SECURITY_CONTEXT_ATTR_NAME);
-		if (securityContext != null) {
-			TestSecurityContextHolder.setContext(securityContext);
+		Supplier<SecurityContext> supplier = (Supplier<SecurityContext>) testContext
+				.removeAttribute(SECURITY_CONTEXT_ATTR_NAME);
+		if (supplier != null) {
+			TestSecurityContextHolder.setContext(supplier.get());
 		}
 	}
 
@@ -118,14 +121,16 @@ public class WithSecurityContextTestExecutionListener
 				.resolveTypeArgument(factory.getClass(),
 						WithSecurityContextFactory.class);
 		Annotation annotation = findAnnotation(annotated, type);
+		Supplier<SecurityContext> supplier = () -> {
+			try {
+				return factory.createSecurityContext(annotation);
+			} catch (RuntimeException e) {
+				throw new IllegalStateException(
+						"Unable to create SecurityContext using " + annotation, e);
+			}
+		};
 		TestExecutionEvent initialize = withSecurityContext.setupBefore();
-		try {
-			return new TestSecurityContext(factory.createSecurityContext(annotation), initialize);
-		}
-		catch (RuntimeException e) {
-			throw new IllegalStateException(
-					"Unable to create SecurityContext using " + annotation, e);
-		}
+		return new TestSecurityContext(supplier, initialize);
 	}
 
 	private Annotation findAnnotation(AnnotatedElement annotated,
@@ -179,16 +184,17 @@ public class WithSecurityContextTestExecutionListener
 	}
 
 	static class TestSecurityContext {
-		private final SecurityContext securityContext;
+		private final Supplier<SecurityContext> securityContextSupplier;
 		private final TestExecutionEvent testExecutionEvent;
 
-		TestSecurityContext(SecurityContext securityContext, TestExecutionEvent testExecutionEvent) {
-			this.securityContext = securityContext;
+		TestSecurityContext(Supplier<SecurityContext> securityContextSupplier,
+				TestExecutionEvent testExecutionEvent) {
+			this.securityContextSupplier = securityContextSupplier;
 			this.testExecutionEvent = testExecutionEvent;
 		}
 
-		public SecurityContext getSecurityContext() {
-			return this.securityContext;
+		public Supplier<SecurityContext> getSecurityContextSupplier() {
+			return this.securityContextSupplier;
 		}
 
 		public TestExecutionEvent getTestExecutionEvent() {

+ 22 - 2
test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java

@@ -21,6 +21,8 @@ import org.junit.ClassRule;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.runner.RunWith;
+import org.mockito.ArgumentCaptor;
+import org.mockito.ArgumentMatchers;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
 import org.springframework.beans.factory.annotation.Autowired;
@@ -36,6 +38,7 @@ import org.springframework.test.context.junit4.rules.SpringClassRule;
 import org.springframework.test.context.junit4.rules.SpringMethodRule;
 
 import java.lang.reflect.Method;
+import java.util.function.Supplier;
 
 import static org.assertj.core.api.Assertions.*;
 import static org.mockito.ArgumentMatchers.any;
@@ -102,7 +105,23 @@ public class WithSecurityContextTestExecutionListenerTests {
 		this.listener.beforeTestMethod(this.testContext);
 
 		assertThat(TestSecurityContextHolder.getContext().getAuthentication()).isNull();
-		verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME), any(SecurityContext.class));
+		verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)
+				, ArgumentMatchers.<Supplier<SecurityContext>>any());
+	}
+
+	@Test
+	@SuppressWarnings("unchecked")
+	public void beforeTestMethodWhenWithMockUserTestExecutionThenTestContextSupplierOk() throws Exception {
+		Method testMethod = TheTest.class.getMethod("withMockUserTestExecution");
+		when(this.testContext.getApplicationContext()).thenReturn(this.applicationContext);
+		when(this.testContext.getTestMethod()).thenReturn(testMethod);
+
+		this.listener.beforeTestMethod(this.testContext);
+
+		ArgumentCaptor<Supplier<SecurityContext>> supplierCaptor = ArgumentCaptor.forClass(Supplier.class);
+		verify(this.testContext).setAttribute(eq(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME),
+				supplierCaptor.capture());
+		assertThat(supplierCaptor.getValue().get().getAuthentication()).isNotNull();
 	}
 
 	@Test
@@ -116,7 +135,8 @@ public class WithSecurityContextTestExecutionListenerTests {
 	public void beforeTestExecutionWhenTestContextNotNullThenSecurityContextSet() {
 		SecurityContextImpl securityContext = new SecurityContextImpl();
 		securityContext.setAuthentication(new TestingAuthenticationToken("user", "passsword", "ROLE_USER"));
-		when(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)).thenReturn(securityContext);
+		Supplier<SecurityContext> supplier = () -> securityContext;
+		when(this.testContext.removeAttribute(WithSecurityContextTestExecutionListener.SECURITY_CONTEXT_ATTR_NAME)).thenReturn(supplier);
 
 		this.listener.beforeTestExecution(this.testContext);