Jelajahi Sumber

Support errorOnInvalidType for Reactive AuthenticationPrincipal

Fixes: gh-5096
Rob Winch 7 tahun lalu
induk
melakukan
d21338d212

+ 3 - 2
docs/manual/src/docs/asciidoc/index.adoc

@@ -412,8 +412,9 @@ Below are the highlights of the release.
 For example, `@WithMockUser(setupBefore = TestExecutionEvent.TEST_EXECUTION)` will setup a user after JUnit's `@Before` and before the test executes.
 ** `@WithUserDetails` now works with `ReactiveUserDetailsService`
 * <<jackson>> - added support for `BadCredentialsException`
-* <<mvc-authentication-principal>> - Supports resolving beans in WebFlux (was already supported in Spring MVC).
-
+* <<mvc-authentication-principal>>
+** Supports resolving beans in WebFlux (was already supported in Spring MVC)
+** Supports resolving `errorOnInvalidType` in WebFlux (was already supported in Spring MVC)
 
 [[samples]]
 == Samples and Guides (Start Here)

+ 33 - 2
web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/AuthenticationPrincipalArgumentResolver.java

@@ -15,9 +15,13 @@
  */
 package org.springframework.security.web.reactive.result.method.annotation;
 
+import java.lang.annotation.Annotation;
+
+import org.reactivestreams.Publisher;
 import org.springframework.core.MethodParameter;
 import org.springframework.core.ReactiveAdapter;
 import org.springframework.core.ReactiveAdapterRegistry;
+import org.springframework.core.ResolvableType;
 import org.springframework.core.annotation.AnnotationUtils;
 import org.springframework.expression.BeanResolver;
 import org.springframework.expression.Expression;
@@ -30,9 +34,8 @@ import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.BindingContext;
 import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolverSupport;
 import org.springframework.web.server.ServerWebExchange;
-import reactor.core.publisher.Mono;
 
-import java.lang.annotation.Annotation;
+import reactor.core.publisher.Mono;
 
 /**
  * Resolves the Authentication
@@ -90,9 +93,37 @@ public class AuthenticationPrincipalArgumentResolver extends HandlerMethodArgume
 			principal = expression.getValue(context);
 		}
 
+		if (isInvalidType(parameter, principal)) {
+
+			if (authPrincipal.errorOnInvalidType()) {
+				throw new ClassCastException(principal + " is not assignable to "
+					+ parameter.getParameterType());
+			}
+			else {
+				return null;
+			}
+		}
+
 		return principal;
 	}
 
+	private boolean isInvalidType(MethodParameter parameter, Object principal) {
+		if (principal == null) {
+			return false;
+		}
+		Class<?> typeToCheck = parameter.getParameterType();
+		boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType());
+		if (isParameterPublisher) {
+			ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter);
+			Class<?> genericType = resolvableType.resolveGeneric(0);
+			if (genericType == null) {
+				return false;
+			}
+			typeToCheck = genericType;
+		}
+		return !typeToCheck.isAssignableFrom(principal.getClass());
+	}
+
 	/**
 	 * Obtains the specified {@link Annotation} on the specified {@link MethodParameter}.
 	 *

+ 52 - 0
web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/AuthenticationPrincipalArgumentResolverTests.java

@@ -34,6 +34,7 @@ import reactor.core.publisher.Mono;
 import java.lang.annotation.*;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.Mockito.when;
@@ -121,6 +122,17 @@ public class AuthenticationPrincipalArgumentResolverTests {
 		assertThat(argument.cast(Mono.class).block().block()).isEqualTo(authentication.getPrincipal());
 	}
 
+	@Test
+	public void resolveArgumentWhenMonoIsAuthenticationAndNoGenericThenObtainsPrincipal() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("authenticationPrincipalNoGeneric").build().arg(Mono.class);
+		when(authentication.getPrincipal()).thenReturn("user");
+		when(exchange.getPrincipal()).thenReturn(Mono.just(authentication));
+
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+
+		assertThat(argument.cast(Mono.class).block().block()).isEqualTo(authentication.getPrincipal());
+	}
+
 	@Test
 	public void resolveArgumentWhenSpelThenObtainsPrincipal() throws Exception {
 		MyUser user = new MyUser(3L);
@@ -157,15 +169,55 @@ public class AuthenticationPrincipalArgumentResolverTests {
 		assertThat(argument.block()).isEqualTo("user");
 	}
 
+	@Test
+	public void resolveArgumentWhenErrorOnInvalidTypeImplicit() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("errorOnInvalidTypeWhenImplicit").build().arg(Integer.class);
+		when(authentication.getPrincipal()).thenReturn("user");
+		when(exchange.getPrincipal()).thenReturn(Mono.just(authentication));
+
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+
+		assertThat(argument.block()).isNull();
+	}
+
+	@Test
+	public void resolveArgumentWhenErrorOnInvalidTypeExplicitFalse() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("errorOnInvalidTypeWhenExplicitFalse").build().arg(Integer.class);
+		when(authentication.getPrincipal()).thenReturn("user");
+		when(exchange.getPrincipal()).thenReturn(Mono.just(authentication));
+
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+
+		assertThat(argument.block()).isNull();
+	}
+
+	@Test
+	public void resolveArgumentWhenErrorOnInvalidTypeExplicitTrue() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("errorOnInvalidTypeWhenExplicitTrue").build().arg(Integer.class);
+		when(authentication.getPrincipal()).thenReturn("user");
+		when(exchange.getPrincipal()).thenReturn(Mono.just(authentication));
+
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+
+		assertThatThrownBy(() -> argument.block()).isInstanceOf(ClassCastException.class);
+	}
 
 	void authenticationPrincipal(@AuthenticationPrincipal String principal, @AuthenticationPrincipal Mono<String> monoPrincipal) {}
 
+	void authenticationPrincipalNoGeneric(@AuthenticationPrincipal Mono monoPrincipal) {}
+
 	void spel(@AuthenticationPrincipal(expression = "id") Long id) {}
 
 	void bean(@AuthenticationPrincipal(expression = "@beanName.methodName(#this)") Long id) {}
 
 	void meta(@CurrentUser String principal) {}
 
+	void errorOnInvalidTypeWhenImplicit(@AuthenticationPrincipal Integer implicit) {}
+
+	void errorOnInvalidTypeWhenExplicitFalse(@AuthenticationPrincipal(errorOnInvalidType = false) Integer implicit) {}
+
+	void errorOnInvalidTypeWhenExplicitTrue(@AuthenticationPrincipal(errorOnInvalidType = true) Integer implicit) {}
+
 	static class Bean {
 		public Long methodName(MyUser user) {
 			return user.getId();