Kaynağa Gözat

review phase2

Dan Zheng 6 yıl önce
ebeveyn
işleme
22c8f63390

+ 11 - 3
core/src/main/java/org/springframework/security/core/annotation/CurrentSecurityContext.java

@@ -5,7 +5,7 @@
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *      https://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -28,11 +28,19 @@ import java.lang.annotation.Target;
  * @author Dan Zheng
  * @since 5.2
  *
+ * <p>
  * See: <a href=
  * "{@docRoot}/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolver.html"
- * > CurrentSecurityContextArgumentResolver </a>
+ * > CurrentSecurityContextArgumentResolver</a> For Servlet
+ * </p>
+ *
+ * <p>
+ * See: <a href=
+ * "{@docRoot}/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.html"
+ * > CurrentSecurityContextArgumentResolver</a> For WebFlux
+ * </p>
  */
-@Target({ ElementType.PARAMETER })
+@Target({ ElementType.PARAMETER, ElementType.ANNOTATION_TYPE })
 @Retention(RetentionPolicy.RUNTIME)
 @Documented
 public @interface CurrentSecurityContext {

+ 3 - 1
web/src/main/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolver.java

@@ -5,7 +5,7 @@
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *      https://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -28,6 +28,7 @@ import org.springframework.security.core.annotation.CurrentSecurityContext;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextHolder;
 import org.springframework.stereotype.Controller;
+import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.bind.support.WebDataBinderFactory;
 import org.springframework.web.context.request.NativeWebRequest;
@@ -142,6 +143,7 @@ public final class CurrentSecurityContextArgumentResolver
 	 * @param beanResolver the {@link BeanResolver} to use
 	 */
 	public void setBeanResolver(BeanResolver beanResolver) {
+		Assert.notNull(beanResolver, "beanResolver cannot be null");
 		this.beanResolver = beanResolver;
 	}
 

+ 3 - 1
web/src/main/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolver.java

@@ -5,7 +5,7 @@
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *      https://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -29,6 +29,7 @@ import org.springframework.expression.spel.support.StandardEvaluationContext;
 import org.springframework.security.core.annotation.CurrentSecurityContext;
 import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 import org.springframework.security.core.context.SecurityContext;
+import org.springframework.util.Assert;
 import org.springframework.util.StringUtils;
 import org.springframework.web.reactive.BindingContext;
 import org.springframework.web.reactive.result.method.HandlerMethodArgumentResolverSupport;
@@ -57,6 +58,7 @@ public class CurrentSecurityContextArgumentResolver extends HandlerMethodArgumen
 	 * @param beanResolver the {@link BeanResolver} to use
 	 */
 	public void setBeanResolver(BeanResolver beanResolver) {
+		Assert.notNull(beanResolver, "beanResolver cannot be null");
 		this.beanResolver = beanResolver;
 	}
 

+ 81 - 1
web/src/test/java/org/springframework/security/web/bind/support/CurrentSecurityContextArgumentResolverTests.java

@@ -5,7 +5,7 @@
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *      https://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,6 +15,10 @@
  */
 package org.springframework.security.web.bind.support;
 
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
 import java.lang.reflect.Method;
 
 import org.junit.After;
@@ -162,6 +166,32 @@ public class CurrentSecurityContextArgumentResolverTests {
 				null, null));
 	}
 
+	@Test
+	public void metaAnnotationWhenCurrentCustomSecurityContextThenInjectSecurityContext() throws Exception {
+		assertThat(resolver.resolveArgument(showCurrentCustomSecurityContext(), null, null, null))
+				.isNotNull();
+	}
+
+	@Test
+	public void metaAnnotationWhenCurrentAuthenticationThenInjectAuthentication() throws Exception {
+		String principal = "current_authentcation";
+		setAuthenticationPrincipal(principal);
+		Authentication auth1 = (Authentication) resolver.resolveArgument(showCurrentAuthentication(), null, null, null);
+		assertThat(auth1.getPrincipal()).isEqualTo(principal);
+	}
+
+	@Test
+	public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenInjectSecurityContext() throws Exception {
+		assertThat(resolver.resolveArgument(showCurrentSecurityWithErrorOnInvalidType(), null, null, null))
+				.isNotNull();
+	}
+
+	@Test
+	public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenMisMatch() throws Exception {
+		assertThatExceptionOfType(ClassCastException.class).isThrownBy(() -> resolver.resolveArgument(showCurrentSecurityWithErrorOnInvalidTypeMisMatch(), null,
+				null, null));
+	}
+
 	private MethodParameter showSecurityContextNoAnnotation() {
 		return getMethodParameter("showSecurityContextNoAnnotation", String.class);
 	}
@@ -206,6 +236,22 @@ public class CurrentSecurityContextArgumentResolverTests {
 		return getMethodParameter("showSecurityContextErrorOnInvalidTypeTrue", String.class);
 	}
 
+	public MethodParameter showCurrentCustomSecurityContext() {
+		return getMethodParameter("showCurrentCustomSecurityContext", SecurityContext.class);
+	}
+
+	public MethodParameter showCurrentAuthentication() {
+		return getMethodParameter("showCurrentAuthentication", Authentication.class);
+	}
+
+	public MethodParameter showCurrentSecurityWithErrorOnInvalidType() {
+		return getMethodParameter("showCurrentSecurityWithErrorOnInvalidType", SecurityContext.class);
+	}
+
+	public MethodParameter showCurrentSecurityWithErrorOnInvalidTypeMisMatch() {
+		return getMethodParameter("showCurrentSecurityWithErrorOnInvalidTypeMisMatch", String.class);
+	}
+
 	private MethodParameter getMethodParameter(String methodName, Class<?>... paramTypes) {
 		Method method = ReflectionUtils.findMethod(TestController.class, methodName,
 				paramTypes);
@@ -248,6 +294,22 @@ public class CurrentSecurityContextArgumentResolverTests {
 		public void showSecurityContextErrorOnInvalidTypeTrue(
 				@CurrentSecurityContext(errorOnInvalidType = true) String implicit) {
 		}
+
+		public void showCurrentCustomSecurityContext(
+				@CurrentCustomSecurityContext SecurityContext context) {
+		}
+
+		public void showCurrentAuthentication(
+				@CurrentAuthentication Authentication authentication) {
+		}
+
+		public void showCurrentSecurityWithErrorOnInvalidType(
+				@CurrentSecurityWithErrorOnInvalidType SecurityContext context) {
+		}
+
+		public void showCurrentSecurityWithErrorOnInvalidTypeMisMatch(
+				@CurrentSecurityWithErrorOnInvalidType String typeMisMatch) {
+		}
 	}
 
 	private void setAuthenticationPrincipal(Object principal) {
@@ -277,6 +339,24 @@ public class CurrentSecurityContextArgumentResolverTests {
 		}
 	}
 
+	@Target({ ElementType.PARAMETER })
+	@Retention(RetentionPolicy.RUNTIME)
+	@CurrentSecurityContext
+	static @interface CurrentCustomSecurityContext {
+	}
+
+	@Target({ ElementType.PARAMETER })
+	@Retention(RetentionPolicy.RUNTIME)
+	@CurrentSecurityContext(expression = "authentication")
+	static @interface CurrentAuthentication {
+	}
+
+	@Target({ ElementType.PARAMETER })
+	@Retention(RetentionPolicy.RUNTIME)
+	@CurrentSecurityContext(errorOnInvalidType = true)
+	static @interface CurrentSecurityWithErrorOnInvalidType {
+	}
+
 	private void setAuthenticationDetail(Object detail) {
 		TestingAuthenticationToken tat = new TestingAuthenticationToken("user", "password",
 				"ROLE_USER");

+ 82 - 15
web/src/test/java/org/springframework/security/web/reactive/result/method/annotation/CurrentSecurityContextArgumentResolverTests.java

@@ -5,7 +5,7 @@
  * you may not use this file except in compliance with the License.
  * You may obtain a copy of the License at
  *
- *      http://www.apache.org/licenses/LICENSE-2.0
+ *      https://www.apache.org/licenses/LICENSE-2.0
  *
  * Unless required by applicable law or agreed to in writing, software
  * distributed under the License is distributed on an "AS IS" BASIS,
@@ -16,13 +16,19 @@
 
 package org.springframework.security.web.reactive.result.method.annotation;
 
+import static org.assertj.core.api.Assertions.assertThat;
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
+import java.lang.annotation.ElementType;
+import java.lang.annotation.Retention;
+import java.lang.annotation.RetentionPolicy;
+import java.lang.annotation.Target;
+
 import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.mockito.Mock;
 import org.mockito.junit.MockitoJUnitRunner;
-import reactor.core.publisher.Mono;
-import reactor.util.context.Context;
 
 import org.springframework.core.MethodParameter;
 import org.springframework.core.ReactiveAdapterRegistry;
@@ -37,8 +43,8 @@ import org.springframework.security.web.method.ResolvableMethod;
 import org.springframework.web.reactive.BindingContext;
 import org.springframework.web.server.ServerWebExchange;
 
-import static org.assertj.core.api.Assertions.assertThat;
-import static org.junit.Assert.fail;
+import reactor.core.publisher.Mono;
+import reactor.util.context.Context;
 
 
 /**
@@ -172,11 +178,7 @@ public class CurrentSecurityContextArgumentResolverTests {
 		Authentication auth = null;
 		Context context = ReactiveSecurityContextHolder.withAuthentication(auth);
 		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
-		try {
-			Mono<Object> obj = (Mono<Object>) argument.subscriberContext(context).block();
-			fail("should not reach here");
-		} catch(SpelEvaluationException e) {
-		}
+		assertThatExceptionOfType(SpelEvaluationException.class).isThrownBy(() -> argument.subscriberContext(context).block());
 		ReactiveSecurityContextHolder.clearContext();
 	}
 
@@ -219,11 +221,50 @@ public class CurrentSecurityContextArgumentResolverTests {
 		Authentication auth = buildAuthenticationWithPrincipal("error_on_invalid_type_explicit_true");
 		Context context = ReactiveSecurityContextHolder.withAuthentication(auth);
 		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
-		try {
-			Mono<String> obj = (Mono<String>) argument.subscriberContext(context).block();
-			fail("should not reach here");
-		} catch(ClassCastException ex) {
-		}
+		assertThatExceptionOfType(ClassCastException.class).isThrownBy(() -> argument.subscriberContext(context).block());
+		ReactiveSecurityContextHolder.clearContext();
+	}
+
+	@Test
+	public void metaAnnotationWhenDefaultSecurityContextThenInjectSecurityContext() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("currentCustomSecurityContext").build().arg(Mono.class, SecurityContext.class);
+		Authentication auth = buildAuthenticationWithPrincipal("current_custom_security_context");
+		Context context = ReactiveSecurityContextHolder.withAuthentication(auth);
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+		SecurityContext securityContext = (SecurityContext) argument.subscriberContext(context).cast(Mono.class).block().block();
+		assertThat(securityContext.getAuthentication()).isSameAs(auth);
+		ReactiveSecurityContextHolder.clearContext();
+	}
+
+	@Test
+	public void metaAnnotationWhenCurrentAuthenticationThenInjectAuthentication() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("currentAuthentication").build().arg(Mono.class, Authentication.class);
+		Authentication auth = buildAuthenticationWithPrincipal("current_authentication");
+		Context context = ReactiveSecurityContextHolder.withAuthentication(auth);
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+		Authentication authentication = (Authentication) argument.subscriberContext(context).cast(Mono.class).block().block();
+		assertThat(authentication).isSameAs(auth);
+		ReactiveSecurityContextHolder.clearContext();
+	}
+
+	@Test
+	public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenInjectSecurityContext() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("currentSecurityWithErrorOnInvalidType").build().arg(Mono.class, SecurityContext.class);
+		Authentication auth = buildAuthenticationWithPrincipal("current_security_with_error_on_invalid_type");
+		Context context = ReactiveSecurityContextHolder.withAuthentication(auth);
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+		SecurityContext securityContext = (SecurityContext) argument.subscriberContext(context).cast(Mono.class).block().block();
+		assertThat(securityContext.getAuthentication()).isSameAs(auth);
+		ReactiveSecurityContextHolder.clearContext();
+	}
+
+	@Test
+	public void metaAnnotationWhenCurrentSecurityWithErrorOnInvalidTypeThenMisMatch() throws Exception {
+		MethodParameter parameter = ResolvableMethod.on(getClass()).named("currentSecurityWithErrorOnInvalidTypeMisMatch").build().arg(Mono.class, String.class);
+		Authentication auth = buildAuthenticationWithPrincipal("current_security_with_error_on_invalid_type_mismatch");
+		Context context = ReactiveSecurityContextHolder.withAuthentication(auth);
+		Mono<Object> argument = resolver.resolveArgument(parameter, bindingContext, exchange);
+		assertThatExceptionOfType(ClassCastException.class).isThrownBy(() -> argument.subscriberContext(context).cast(Mono.class).block().block());
 		ReactiveSecurityContextHolder.clearContext();
 	}
 
@@ -245,6 +286,32 @@ public class CurrentSecurityContextArgumentResolverTests {
 
 	void errorOnInvalidTypeWhenExplicitTrue(@CurrentSecurityContext(errorOnInvalidType = true) Mono<String> implicit) {}
 
+	void currentCustomSecurityContext(@CurrentCustomSecurityContext Mono<SecurityContext> monoSecurityContext) {}
+
+	void currentAuthentication(@CurrentAuthentication Mono<Authentication> authentication) {}
+
+	void currentSecurityWithErrorOnInvalidType(@CurrentSecurityWithErrorOnInvalidType Mono<SecurityContext> monoSecurityContext) {}
+
+	void currentSecurityWithErrorOnInvalidTypeMisMatch(@CurrentSecurityWithErrorOnInvalidType Mono<String> typeMisMatch) {}
+
+	@Target({ ElementType.PARAMETER })
+	@Retention(RetentionPolicy.RUNTIME)
+	@CurrentSecurityContext
+	static @interface CurrentCustomSecurityContext {
+	}
+
+	@Target({ ElementType.PARAMETER })
+	@Retention(RetentionPolicy.RUNTIME)
+	@CurrentSecurityContext(expression = "authentication")
+	static @interface CurrentAuthentication {
+	}
+
+	@Target({ ElementType.PARAMETER })
+	@Retention(RetentionPolicy.RUNTIME)
+	@CurrentSecurityContext(errorOnInvalidType = true)
+	static @interface CurrentSecurityWithErrorOnInvalidType {
+	}
+
 	static class CustomSecurityContext implements SecurityContext {
 		private Authentication authentication;
 		public CustomSecurityContext(Authentication authentication) {