|
@@ -17,6 +17,7 @@ package org.springframework.security.test.context.support;
|
|
|
|
|
|
import java.lang.annotation.Annotation;
|
|
|
import java.lang.reflect.AnnotatedElement;
|
|
|
+import java.util.function.Supplier;
|
|
|
|
|
|
import org.springframework.beans.BeanUtils;
|
|
|
import org.springframework.core.GenericTypeResolver;
|
|
@@ -69,11 +70,12 @@ public class WithSecurityContextTestExecutionListener
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- SecurityContext securityContext = testSecurityContext.securityContext;
|
|
|
+ Supplier<SecurityContext> supplier = testSecurityContext
|
|
|
+ .getSecurityContextSupplier();
|
|
|
if (testSecurityContext.getTestExecutionEvent() == TestExecutionEvent.TEST_METHOD) {
|
|
|
- TestSecurityContextHolder.setContext(securityContext);
|
|
|
+ TestSecurityContextHolder.setContext(supplier.get());
|
|
|
} else {
|
|
|
- testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, securityContext);
|
|
|
+ testContext.setAttribute(SECURITY_CONTEXT_ATTR_NAME, supplier);
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -83,9 +85,10 @@ public class WithSecurityContextTestExecutionListener
|
|
|
*/
|
|
|
@Override
|
|
|
public void beforeTestExecution(TestContext testContext) {
|
|
|
- SecurityContext securityContext = (SecurityContext) testContext.removeAttribute(SECURITY_CONTEXT_ATTR_NAME);
|
|
|
- if (securityContext != null) {
|
|
|
- TestSecurityContextHolder.setContext(securityContext);
|
|
|
+ Supplier<SecurityContext> supplier = (Supplier<SecurityContext>) testContext
|
|
|
+ .removeAttribute(SECURITY_CONTEXT_ATTR_NAME);
|
|
|
+ if (supplier != null) {
|
|
|
+ TestSecurityContextHolder.setContext(supplier.get());
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -118,14 +121,16 @@ public class WithSecurityContextTestExecutionListener
|
|
|
.resolveTypeArgument(factory.getClass(),
|
|
|
WithSecurityContextFactory.class);
|
|
|
Annotation annotation = findAnnotation(annotated, type);
|
|
|
+ Supplier<SecurityContext> supplier = () -> {
|
|
|
+ try {
|
|
|
+ return factory.createSecurityContext(annotation);
|
|
|
+ } catch (RuntimeException e) {
|
|
|
+ throw new IllegalStateException(
|
|
|
+ "Unable to create SecurityContext using " + annotation, e);
|
|
|
+ }
|
|
|
+ };
|
|
|
TestExecutionEvent initialize = withSecurityContext.setupBefore();
|
|
|
- try {
|
|
|
- return new TestSecurityContext(factory.createSecurityContext(annotation), initialize);
|
|
|
- }
|
|
|
- catch (RuntimeException e) {
|
|
|
- throw new IllegalStateException(
|
|
|
- "Unable to create SecurityContext using " + annotation, e);
|
|
|
- }
|
|
|
+ return new TestSecurityContext(supplier, initialize);
|
|
|
}
|
|
|
|
|
|
private Annotation findAnnotation(AnnotatedElement annotated,
|
|
@@ -179,16 +184,17 @@ public class WithSecurityContextTestExecutionListener
|
|
|
}
|
|
|
|
|
|
static class TestSecurityContext {
|
|
|
- private final SecurityContext securityContext;
|
|
|
+ private final Supplier<SecurityContext> securityContextSupplier;
|
|
|
private final TestExecutionEvent testExecutionEvent;
|
|
|
|
|
|
- TestSecurityContext(SecurityContext securityContext, TestExecutionEvent testExecutionEvent) {
|
|
|
- this.securityContext = securityContext;
|
|
|
+ TestSecurityContext(Supplier<SecurityContext> securityContextSupplier,
|
|
|
+ TestExecutionEvent testExecutionEvent) {
|
|
|
+ this.securityContextSupplier = securityContextSupplier;
|
|
|
this.testExecutionEvent = testExecutionEvent;
|
|
|
}
|
|
|
|
|
|
- public SecurityContext getSecurityContext() {
|
|
|
- return this.securityContext;
|
|
|
+ public Supplier<SecurityContext> getSecurityContextSupplier() {
|
|
|
+ return this.securityContextSupplier;
|
|
|
}
|
|
|
|
|
|
public TestExecutionEvent getTestExecutionEvent() {
|