Jelajahi Sumber

Support @ClientRegistrationId at Class Level
Closes gh-17806

Signed-off-by: Bernard Budano <bbudano@gmail.com>

Bernard Budano 4 minggu lalu
induk
melakukan
8e3cf9677c

+ 1 - 1
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/annotation/ClientRegistrationId.java

@@ -33,7 +33,7 @@ import org.springframework.core.annotation.AliasFor;
  * @since 7.0
  * @see org.springframework.security.oauth2.client.web.client.ClientRegistrationIdProcessor
  */
-@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE })
+@Target({ ElementType.METHOD, ElementType.TYPE })
 @Retention(RetentionPolicy.RUNTIME)
 @Documented
 public @interface ClientRegistrationId {

+ 8 - 4
oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/client/ClientRegistrationIdProcessor.java

@@ -17,6 +17,7 @@
 package org.springframework.security.oauth2.client.web.client;
 
 import java.lang.reflect.Method;
+import java.util.Optional;
 
 import org.jspecify.annotations.Nullable;
 
@@ -37,17 +38,20 @@ public final class ClientRegistrationIdProcessor implements HttpRequestValues.Pr
 
 	public static ClientRegistrationIdProcessor DEFAULT_INSTANCE = new ClientRegistrationIdProcessor();
 
+	private ClientRegistrationIdProcessor() {
+	}
+
 	@Override
 	public void process(Method method, MethodParameter[] parameters, @Nullable Object[] arguments,
 			HttpRequestValues.Builder builder) {
-		ClientRegistrationId registeredId = AnnotationUtils.findAnnotation(method, ClientRegistrationId.class);
+		ClientRegistrationId registeredId = Optional
+			.ofNullable(AnnotationUtils.findAnnotation(method, ClientRegistrationId.class))
+			.orElseGet(() -> AnnotationUtils.findAnnotation(method.getDeclaringClass(), ClientRegistrationId.class));
+
 		if (registeredId != null) {
 			String registrationId = registeredId.registrationId();
 			builder.configureAttributes(ClientAttributes.clientRegistrationId(registrationId));
 		}
 	}
 
-	private ClientRegistrationIdProcessor() {
-	}
-
 }

+ 28 - 9
oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/client/ClientRegistrationIdProcessorTests.java

@@ -39,6 +39,8 @@ import static org.assertj.core.api.Assertions.assertThat;
  */
 class ClientRegistrationIdProcessorTests {
 
+	private static final String REGISTRATION_ID = "registrationId";
+
 	ClientRegistrationIdProcessor processor = ClientRegistrationIdProcessor.DEFAULT_INSTANCE;
 
 	@Test
@@ -48,32 +50,42 @@ class ClientRegistrationIdProcessorTests {
 		this.processor.process(hasClientRegistrationId, null, null, builder);
 
 		String registrationId = ClientAttributes.resolveClientRegistrationId(builder.build().getAttributes());
-		assertThat(registrationId).isEqualTo(RestService.REGISTRATION_ID);
+		assertThat(registrationId).isEqualTo(REGISTRATION_ID);
 	}
 
 	@Test
 	void processWhenMetaClientRegistrationIdPresentThenSet() {
 		HttpRequestValues.Builder builder = HttpRequestValues.builder();
-		Method hasClientRegistrationId = ReflectionUtils.findMethod(RestService.class, "hasMetaClientRegistrationId");
-		this.processor.process(hasClientRegistrationId, null, null, builder);
+		Method hasMetaClientRegistrationId = ReflectionUtils.findMethod(RestService.class,
+				"hasMetaClientRegistrationId");
+		this.processor.process(hasMetaClientRegistrationId, null, null, builder);
 
 		String registrationId = ClientAttributes.resolveClientRegistrationId(builder.build().getAttributes());
-		assertThat(registrationId).isEqualTo(RestService.REGISTRATION_ID);
+		assertThat(registrationId).isEqualTo(REGISTRATION_ID);
 	}
 
 	@Test
 	void processWhenNoClientRegistrationIdPresentThenNull() {
 		HttpRequestValues.Builder builder = HttpRequestValues.builder();
-		Method hasClientRegistrationId = ReflectionUtils.findMethod(RestService.class, "noClientRegistrationId");
-		this.processor.process(hasClientRegistrationId, null, null, builder);
+		Method noClientRegistrationId = ReflectionUtils.findMethod(RestService.class, "noClientRegistrationId");
+		this.processor.process(noClientRegistrationId, null, null, builder);
 
 		String registrationId = ClientAttributes.resolveClientRegistrationId(builder.build().getAttributes());
 		assertThat(registrationId).isNull();
 	}
 
-	interface RestService {
+	@Test
+	void processWhenClientRegistrationIdPresentOnDeclaringClassThenSet() {
+		HttpRequestValues.Builder builder = HttpRequestValues.builder();
+		Method declaringClassHasClientRegistrationId = ReflectionUtils.findMethod(AnnotatedRestService.class,
+				"declaringClassHasClientRegistrationId");
+		this.processor.process(declaringClassHasClientRegistrationId, null, null, builder);
+
+		String registrationId = ClientAttributes.resolveClientRegistrationId(builder.build().getAttributes());
+		assertThat(registrationId).isEqualTo(REGISTRATION_ID);
+	}
 
-		String REGISTRATION_ID = "registrationId";
+	interface RestService {
 
 		@ClientRegistrationId(REGISTRATION_ID)
 		void hasClientRegistrationId();
@@ -86,9 +98,16 @@ class ClientRegistrationIdProcessorTests {
 	}
 
 	@Retention(RetentionPolicy.RUNTIME)
-	@ClientRegistrationId(RestService.REGISTRATION_ID)
+	@ClientRegistrationId(REGISTRATION_ID)
 	@interface MetaClientRegistrationId {
 
 	}
 
+	@ClientRegistrationId(REGISTRATION_ID)
+	interface AnnotatedRestService {
+
+		void declaringClassHasClientRegistrationId();
+
+	}
+
 }