浏览代码

Add argument resolver for SecurityContext

Closes gh-13425
Nermin Karapandzic 1 年之前
父节点
当前提交
6e1bcfed11

+ 23 - 1
messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java

@@ -118,7 +118,21 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu
 
 	@Override
 	public boolean supportsParameter(MethodParameter parameter) {
-		return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
+		return isMonoSecurityContext(parameter)
+				|| findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
+	}
+
+	private boolean isMonoSecurityContext(MethodParameter parameter) {
+		boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType());
+		if (isParameterPublisher) {
+			ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter);
+			Class<?> genericType = resolvableType.resolveGeneric(0);
+			if (genericType == null) {
+				return false;
+			}
+			return SecurityContext.class.isAssignableFrom(genericType);
+		}
+		return false;
 	}
 
 	@Override
@@ -136,6 +150,14 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu
 
 	private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) {
 		CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter);
+		if (contextAnno != null) {
+			return resolveSecurityContextFromAnnotation(contextAnno, parameter, securityContext);
+		}
+		return securityContext;
+	}
+
+	private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext contextAnno, MethodParameter parameter,
+			Object securityContext) {
 		String expressionToParse = contextAnno.expression();
 		if (StringUtils.hasLength(expressionToParse)) {
 			StandardEvaluationContext context = new StandardEvaluationContext();

+ 81 - 0
messaging/src/test/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolverTests.java

@@ -46,6 +46,24 @@ public class CurrentSecurityContextArgumentResolverTests {
 		assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContext"))).isTrue();
 	}
 
+	@Test
+	public void supportsParameterWhenMonoSecurityContextNoAnnotationThenTrue() {
+		assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoSecurityContextNoAnnotation")))
+			.isTrue();
+	}
+
+	@Test
+	public void supportsParameterWhenMonoCustomSecurityContextNoAnnotationThenTrue() {
+		assertThat(
+				this.resolver.supportsParameter(arg0("currentCustomSecurityContextOnMonoSecurityContextNoAnnotation")))
+			.isTrue();
+	}
+
+	@Test
+	public void supportsParameterWhenNoSecurityContextNoAnnotationThenFalse() {
+		assertThat(this.resolver.supportsParameter(arg0("currentSecurityContextOnMonoStringNoAnnotation"))).isFalse();
+	}
+
 	@Test
 	public void resolveArgumentWhenAuthenticationPrincipalAndEmptyContextThenNull() {
 		Object result = this.resolver.resolveArgument(arg0("currentSecurityContextOnMonoSecurityContext"), null)
@@ -67,6 +85,18 @@ public class CurrentSecurityContextArgumentResolverTests {
 	private void currentSecurityContextOnMonoSecurityContext(@CurrentSecurityContext Mono<SecurityContext> context) {
 	}
 
+	@SuppressWarnings("unused")
+	private void currentSecurityContextOnMonoSecurityContextNoAnnotation(Mono<SecurityContext> context) {
+	}
+
+	@SuppressWarnings("unused")
+	private void currentCustomSecurityContextOnMonoSecurityContextNoAnnotation(Mono<CustomSecurityContext> context) {
+	}
+
+	@SuppressWarnings("unused")
+	private void currentSecurityContextOnMonoStringNoAnnotation(Mono<String> context) {
+	}
+
 	@Test
 	public void supportsParameterWhenCurrentUserThenTrue() {
 		assertThat(this.resolver.supportsParameter(arg0("currentUserOnMonoUserDetails"))).isTrue();
@@ -110,6 +140,41 @@ public class CurrentSecurityContextArgumentResolverTests {
 	private void monoUserDetails(Mono<UserDetails> user) {
 	}
 
+	@Test
+	public void supportsParameterWhenSecurityContextNotAnnotatedThenTrue() {
+		assertThat(this.resolver.supportsParameter(arg0("monoSecurityContext"))).isTrue();
+	}
+
+	@Test
+	public void resolveArgumentWhenMonoSecurityContextNoAnnotationThenFound() {
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		Mono<SecurityContext> result = (Mono<SecurityContext>) this.resolver
+			.resolveArgument(arg0("monoSecurityContext"), null)
+			.contextWrite(ReactiveSecurityContextHolder.withAuthentication(authentication))
+			.block();
+		assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal());
+	}
+
+	@SuppressWarnings("unused")
+	private void monoSecurityContext(Mono<SecurityContext> securityContext) {
+	}
+
+	@Test
+	public void resolveArgumentWhenMonoCustomSecurityContextNoAnnotationThenFound() {
+		Authentication authentication = TestAuthentication.authenticatedUser();
+		CustomSecurityContext securityContext = new CustomSecurityContext();
+		securityContext.setAuthentication(authentication);
+		Mono<CustomSecurityContext> result = (Mono<CustomSecurityContext>) this.resolver
+			.resolveArgument(arg0("monoCustomSecurityContext"), null)
+			.contextWrite(ReactiveSecurityContextHolder.withSecurityContext(Mono.just(securityContext)))
+			.block();
+		assertThat(result.block().getAuthentication().getPrincipal()).isEqualTo(authentication.getPrincipal());
+	}
+
+	@SuppressWarnings("unused")
+	private void monoCustomSecurityContext(Mono<CustomSecurityContext> securityContext) {
+	}
+
 	private MethodParameter arg0(String methodName) {
 		ResolvableMethod method = ResolvableMethod.on(getClass()).named(methodName).method();
 		return new SynthesizingMethodParameter(method.method(), 0);
@@ -121,4 +186,20 @@ public class CurrentSecurityContextArgumentResolverTests {
 
 	}
 
+	static class CustomSecurityContext implements SecurityContext {
+
+		private Authentication authentication;
+
+		@Override
+		public Authentication getAuthentication() {
+			return this.authentication;
+		}
+
+		@Override
+		public void setAuthentication(Authentication authentication) {
+			this.authentication = authentication;
+		}
+
+	}
+
 }

+ 29 - 19
web/src/main/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolver.java

@@ -85,7 +85,8 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth
 
 	@Override
 	public boolean supportsParameter(MethodParameter parameter) {
-		return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
+		return SecurityContext.class.isAssignableFrom(parameter.getParameterType())
+				|| findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
 	}
 
 	@Override
@@ -95,26 +96,12 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth
 		if (securityContext == null) {
 			return null;
 		}
-		Object securityContextResult = securityContext;
 		CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter);
-		String expressionToParse = annotation.expression();
-		if (StringUtils.hasLength(expressionToParse)) {
-			StandardEvaluationContext context = new StandardEvaluationContext();
-			context.setRootObject(securityContext);
-			context.setVariable("this", securityContext);
-			context.setBeanResolver(this.beanResolver);
-			Expression expression = this.parser.parseExpression(expressionToParse);
-			securityContextResult = expression.getValue(context);
-		}
-		if (securityContextResult != null
-				&& !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) {
-			if (annotation.errorOnInvalidType()) {
-				throw new ClassCastException(
-						securityContextResult + " is not assignable to " + parameter.getParameterType());
-			}
-			return null;
+		if (annotation != null) {
+			return resolveSecurityContextFromAnnotation(parameter, annotation, securityContext);
 		}
-		return securityContextResult;
+
+		return securityContext;
 	}
 
 	/**
@@ -137,6 +124,29 @@ public final class CurrentSecurityContextArgumentResolver implements HandlerMeth
 		this.beanResolver = beanResolver;
 	}
 
+	private Object resolveSecurityContextFromAnnotation(MethodParameter parameter, CurrentSecurityContext annotation,
+			SecurityContext securityContext) {
+		Object securityContextResult = securityContext;
+		String expressionToParse = annotation.expression();
+		if (StringUtils.hasLength(expressionToParse)) {
+			StandardEvaluationContext context = new StandardEvaluationContext();
+			context.setRootObject(securityContext);
+			context.setVariable("this", securityContext);
+			context.setBeanResolver(this.beanResolver);
+			Expression expression = this.parser.parseExpression(expressionToParse);
+			securityContextResult = expression.getValue(context);
+		}
+		if (securityContextResult != null
+				&& !parameter.getParameterType().isAssignableFrom(securityContextResult.getClass())) {
+			if (annotation.errorOnInvalidType()) {
+				throw new ClassCastException(
+						securityContextResult + " is not assignable to " + parameter.getParameterType());
+			}
+			return null;
+		}
+		return securityContextResult;
+	}
+
 	/**
 	 * Obtain the specified {@link Annotation} on the specified {@link MethodParameter}.
 	 * @param annotationClass the class of the {@link Annotation} to find on the

+ 23 - 1
web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java

@@ -67,7 +67,21 @@ public class CurrentSecurityContextArgumentResolver extends HandlerMethodArgumen
 
 	@Override
 	public boolean supportsParameter(MethodParameter parameter) {
-		return findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
+		return isMonoSecurityContext(parameter)
+				|| findMethodAnnotation(CurrentSecurityContext.class, parameter) != null;
+	}
+
+	private boolean isMonoSecurityContext(MethodParameter parameter) {
+		boolean isParameterPublisher = Publisher.class.isAssignableFrom(parameter.getParameterType());
+		if (isParameterPublisher) {
+			ResolvableType resolvableType = ResolvableType.forMethodParameter(parameter);
+			Class<?> genericType = resolvableType.resolveGeneric(0);
+			if (genericType == null) {
+				return false;
+			}
+			return SecurityContext.class.isAssignableFrom(genericType);
+		}
+		return false;
 	}
 
 	@Override
@@ -95,6 +109,14 @@ public class CurrentSecurityContextArgumentResolver extends HandlerMethodArgumen
 	 */
 	private Object resolveSecurityContext(MethodParameter parameter, SecurityContext securityContext) {
 		CurrentSecurityContext annotation = findMethodAnnotation(CurrentSecurityContext.class, parameter);
+		if (annotation != null) {
+			return resolveSecurityContextFromAnnotation(annotation, parameter, securityContext);
+		}
+		return securityContext;
+	}
+
+	private Object resolveSecurityContextFromAnnotation(CurrentSecurityContext annotation, MethodParameter parameter,
+			Object securityContext) {
 		Object securityContextResult = securityContext;
 		String expressionToParse = annotation.expression();
 		if (StringUtils.hasLength(expressionToParse)) {

+ 52 - 2
web/src/test/java/org/springframework/security/web/method/annotation/CurrentSecurityContextArgumentResolverTests.java

@@ -69,9 +69,26 @@ public class CurrentSecurityContextArgumentResolverTests {
 		SecurityContextHolder.clearContext();
 	}
 
+	@Test
+	public void supportsParameterNoAnnotationWrongType() {
+		assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotationTypeMismatch())).isFalse();
+	}
+
 	@Test
 	public void supportsParameterNoAnnotation() {
-		assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isFalse();
+		assertThat(this.resolver.supportsParameter(showSecurityContextNoAnnotation())).isTrue();
+	}
+
+	@Test
+	public void supportsParameterCustomSecurityContextNoAnnotation() {
+		assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation()))
+			.isTrue();
+	}
+
+	@Test
+	public void supportsParameterNoAnnotationCustomType() {
+		assertThat(this.resolver.supportsParameter(showSecurityContextWithCustomSecurityContextNoAnnotation()))
+			.isTrue();
 	}
 
 	@Test
@@ -88,6 +105,24 @@ public class CurrentSecurityContextArgumentResolverTests {
 		assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal);
 	}
 
+	@Test
+	public void resolveArgumentWithCustomSecurityContextNoAnnotation() {
+		String principal = "custom_security_context";
+		setAuthenticationPrincipalWithCustomSecurityContext(principal);
+		CustomSecurityContext customSecurityContext = (CustomSecurityContext) this.resolver
+			.resolveArgument(showSecurityContextWithCustomSecurityContextNoAnnotation(), null, null, null);
+		assertThat(customSecurityContext.getAuthentication().getPrincipal()).isEqualTo(principal);
+	}
+
+	@Test
+	public void resolveArgumentWithNoAnnotation() {
+		String principal = "custom_security_context";
+		setAuthenticationPrincipal(principal);
+		SecurityContext securityContext = (SecurityContext) this.resolver
+			.resolveArgument(showSecurityContextNoAnnotation(), null, null, null);
+		assertThat(securityContext.getAuthentication().getPrincipal()).isEqualTo(principal);
+	}
+
 	@Test
 	public void resolveArgumentWithCustomSecurityContextTypeMatch() {
 		String principal = "custom_security_context_type_match";
@@ -212,10 +247,14 @@ public class CurrentSecurityContextArgumentResolverTests {
 			.resolveArgument(showCurrentSecurityWithErrorOnInvalidTypeMisMatch(), null, null, null));
 	}
 
-	private MethodParameter showSecurityContextNoAnnotation() {
+	private MethodParameter showSecurityContextNoAnnotationTypeMismatch() {
 		return getMethodParameter("showSecurityContextNoAnnotation", String.class);
 	}
 
+	private MethodParameter showSecurityContextNoAnnotation() {
+		return getMethodParameter("showSecurityContextNoAnnotation", SecurityContext.class);
+	}
+
 	private MethodParameter showSecurityContextAnnotation() {
 		return getMethodParameter("showSecurityContextAnnotation", SecurityContext.class);
 	}
@@ -276,6 +315,11 @@ public class CurrentSecurityContextArgumentResolverTests {
 		return getMethodParameter("showCurrentSecurityWithErrorOnInvalidTypeMisMatch", String.class);
 	}
 
+	public MethodParameter showSecurityContextWithCustomSecurityContextNoAnnotation() {
+		return getMethodParameter("showSecurityContextWithCustomSecurityContextNoAnnotation",
+				CustomSecurityContext.class);
+	}
+
 	private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
 		Method method = ReflectionUtils.findMethod(TestController.class, methodName, paramTypes);
 		return new MethodParameter(method, 0);
@@ -358,6 +402,12 @@ public class CurrentSecurityContextArgumentResolverTests {
 				@CurrentSecurityWithErrorOnInvalidType String typeMisMatch) {
 		}
 
+		public void showSecurityContextNoAnnotation(SecurityContext context) {
+		}
+
+		public void showSecurityContextWithCustomSecurityContextNoAnnotation(CustomSecurityContext context) {
+		}
+
 	}
 
 	static class CustomSecurityContext implements SecurityContext {

+ 61 - 0
web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java

@@ -69,6 +69,14 @@ public class CurrentSecurityContextArgumentResolverTests {
 
 	ResolvableMethod securityContextMethod = ResolvableMethod.on(getClass()).named("securityContext").build();
 
+	ResolvableMethod securityContextNoAnnotationMethod = ResolvableMethod.on(getClass())
+		.named("securityContextNoAnnotation")
+		.build();
+
+	ResolvableMethod customSecurityContextNoAnnotationMethod = ResolvableMethod.on(getClass())
+		.named("customSecurityContextNoAnnotation")
+		.build();
+
 	ResolvableMethod securityContextWithAuthentication = ResolvableMethod.on(getClass())
 		.named("securityContextWithAuthentication")
 		.build();
@@ -87,6 +95,19 @@ public class CurrentSecurityContextArgumentResolverTests {
 			.isTrue();
 	}
 
+	@Test
+	public void supportsParameterCurrentSecurityContextNoAnnotation() {
+		assertThat(this.resolver
+			.supportsParameter(this.securityContextNoAnnotationMethod.arg(Mono.class, SecurityContext.class))).isTrue();
+	}
+
+	@Test
+	public void supportsParameterCurrentCustomSecurityContextNoAnnotation() {
+		assertThat(this.resolver.supportsParameter(
+				this.customSecurityContextNoAnnotationMethod.arg(Mono.class, CustomSecurityContext.class)))
+			.isTrue();
+	}
+
 	@Test
 	public void supportsParameterWithAuthentication() {
 		assertThat(this.resolver
@@ -123,6 +144,40 @@ public class CurrentSecurityContextArgumentResolverTests {
 		ReactiveSecurityContextHolder.clearContext();
 	}
 
+	@Test
+	public void resolveArgumentWithSecurityContextNoAnnotation() {
+		MethodParameter parameter = ResolvableMethod.on(getClass())
+			.named("securityContextNoAnnotation")
+			.build()
+			.arg(Mono.class, SecurityContext.class);
+		Authentication auth = buildAuthenticationWithPrincipal("hello");
+		Context context = ReactiveSecurityContextHolder.withAuthentication(auth);
+		Mono<Object> argument = this.resolver.resolveArgument(parameter, this.bindingContext, this.exchange);
+		SecurityContext securityContext = (SecurityContext) argument.contextWrite(context)
+			.cast(Mono.class)
+			.block()
+			.block();
+		assertThat(securityContext.getAuthentication()).isSameAs(auth);
+		ReactiveSecurityContextHolder.clearContext();
+	}
+
+	@Test
+	public void resolveArgumentWithCustomSecurityContextNoAnnotation() {
+		MethodParameter parameter = ResolvableMethod.on(getClass())
+			.named("customSecurityContextNoAnnotation")
+			.build()
+			.arg(Mono.class, CustomSecurityContext.class);
+		Authentication auth = buildAuthenticationWithPrincipal("hello");
+		Context context = ReactiveSecurityContextHolder.withSecurityContext(Mono.just(new CustomSecurityContext(auth)));
+		Mono<Object> argument = this.resolver.resolveArgument(parameter, this.bindingContext, this.exchange);
+		CustomSecurityContext securityContext = (CustomSecurityContext) argument.contextWrite(context)
+			.cast(Mono.class)
+			.block()
+			.block();
+		assertThat(securityContext.getAuthentication()).isSameAs(auth);
+		ReactiveSecurityContextHolder.clearContext();
+	}
+
 	@Test
 	public void resolveArgumentWithCustomSecurityContext() {
 		MethodParameter parameter = ResolvableMethod.on(getClass())
@@ -350,6 +405,12 @@ public class CurrentSecurityContextArgumentResolverTests {
 	void securityContext(@CurrentSecurityContext Mono<SecurityContext> monoSecurityContext) {
 	}
 
+	void securityContextNoAnnotation(Mono<SecurityContext> securityContextMono) {
+	}
+
+	void customSecurityContextNoAnnotation(Mono<CustomSecurityContext> securityContextMono) {
+	}
+
 	void customSecurityContext(@CurrentSecurityContext Mono<SecurityContext> monoSecurityContext) {
 	}