Pārlūkot izejas kodu

Register hints for @WithSecurityContext on class level

Issue gh-12215
Marcus Da Coregio 2 gadi atpakaļ
vecāks
revīzija
1648151dd2

+ 11 - 2
test/src/main/java/org/springframework/security/test/aot/hint/WithSecurityContextTestRuntimeHints.java

@@ -17,6 +17,7 @@
 package org.springframework.security.test.aot.hint;
 
 import java.util.Arrays;
+import java.util.stream.Stream;
 
 import org.springframework.aot.hint.MemberCategory;
 import org.springframework.aot.hint.RuntimeHints;
@@ -38,13 +39,21 @@ class WithSecurityContextTestRuntimeHints implements TestRuntimeHintsRegistrar {
 
 	@Override
 	public void registerHints(RuntimeHints hints, Class<?> testClass, ClassLoader classLoader) {
-		Arrays.stream(testClass.getDeclaredMethods())
-				.map((method) -> MergedAnnotations.from(method, SUPERCLASS).get(WithSecurityContext.class))
+		Stream.concat(getClassAnnotations(testClass), getMethodAnnotations(testClass))
 				.filter(MergedAnnotation::isPresent)
 				.map((withSecurityContext) -> withSecurityContext.getClass("factory"))
 				.forEach((factory) -> registerDeclaredConstructors(hints, factory));
 	}
 
+	private Stream<MergedAnnotation<WithSecurityContext>> getClassAnnotations(Class<?> testClass) {
+		return MergedAnnotations.search(SUPERCLASS).from(testClass).stream(WithSecurityContext.class);
+	}
+
+	private Stream<MergedAnnotation<WithSecurityContext>> getMethodAnnotations(Class<?> testClass) {
+		return Arrays.stream(testClass.getDeclaredMethods())
+				.map((method) -> MergedAnnotations.from(method, SUPERCLASS).get(WithSecurityContext.class));
+	}
+
 	private void registerDeclaredConstructors(RuntimeHints hints, Class<?> factory) {
 		hints.reflection().registerType(factory, MemberCategory.INVOKE_DECLARED_CONSTRUCTORS);
 	}

+ 9 - 0
test/src/test/java/org/springframework/security/test/aot/hint/WithSecurityContextTestRuntimeHintsTests.java

@@ -28,6 +28,8 @@ import org.springframework.aot.hint.TypeReference;
 import org.springframework.aot.hint.predicate.RuntimeHintsPredicates;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
+import org.springframework.security.test.context.showcase.WithMockCustomUser;
+import org.springframework.security.test.context.showcase.WithMockCustomUserSecurityContextFactory;
 import org.springframework.security.test.context.support.WithAnonymousUser;
 import org.springframework.security.test.context.support.WithMockUser;
 import org.springframework.security.test.context.support.WithSecurityContext;
@@ -39,6 +41,7 @@ import static org.assertj.core.api.Assertions.assertThat;
 /**
  * Tests for {@link WithSecurityContextTestRuntimeHints}.
  */
+@WithMockCustomUser
 class WithSecurityContextTestRuntimeHintsTests {
 
 	private final RuntimeHints hints = new RuntimeHints();
@@ -85,6 +88,12 @@ class WithSecurityContextTestRuntimeHintsTests {
 				.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.hints);
 	}
 
+	@Test
+	void withMockCustomUserOnClassHasHints() {
+		assertThat(RuntimeHintsPredicates.reflection().onType(WithMockCustomUserSecurityContextFactory.class)
+				.withMemberCategory(MemberCategory.INVOKE_DECLARED_CONSTRUCTORS)).accepts(this.hints);
+	}
+
 	@Retention(RetentionPolicy.RUNTIME)
 	@WithSecurityContext(factory = WithMockTestUserSecurityContextFactory.class)
 	@interface WithMockTestUser {

+ 4 - 0
test/src/test/java/org/springframework/security/test/context/showcase/WithMockCustomUser.java

@@ -16,11 +16,15 @@
 
 package org.springframework.security.test.context.showcase;
 
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+
 import org.springframework.security.test.context.support.WithSecurityContext;
 
 /**
  * @author Rob Winch
  */
+@Retention(RetentionPolicy.RUNTIME)
 @WithSecurityContext(factory = WithMockCustomUserSecurityContextFactory.class)
 public @interface WithMockCustomUser {