ソースを参照

Add Kotlin support to PreFilter and PostFilter annotations

Closes gh-15093
Blagoja Stamatovski 1 年間 前
コミット
63f48167bd

+ 14 - 0
core/spring-security-core.gradle

@@ -1,6 +1,8 @@
+import org.jetbrains.kotlin.gradle.tasks.KotlinCompile
 import java.util.concurrent.Callable
 
 apply plugin: 'io.spring.convention.spring-module'
+apply plugin: 'kotlin'
 
 dependencies {
 	management platform(project(":spring-security-dependencies"))
@@ -31,6 +33,9 @@ dependencies {
 	testImplementation "org.springframework:spring-test"
 	testImplementation 'org.skyscreamer:jsonassert'
 	testImplementation 'org.springframework:spring-test'
+	testImplementation 'org.jetbrains.kotlin:kotlin-reflect'
+	testImplementation 'org.jetbrains.kotlin:kotlin-stdlib-jdk8'
+	testImplementation 'io.mockk:mockk'
 
 	testRuntimeOnly 'org.hsqldb:hsqldb'
 }
@@ -57,3 +62,12 @@ Callable<String> springVersion() {
 	return  (Callable<String>) { project.configurations.compileClasspath.resolvedConfiguration.resolvedArtifacts
     .find { it.name == 'spring-core' }.moduleVersion.id.version }
 }
+
+tasks.withType(KotlinCompile).configureEach {
+	kotlinOptions {
+		languageVersion = "1.7"
+		apiVersion = "1.7"
+		freeCompilerArgs = ["-Xjsr305=strict", "-Xsuppress-version-warnings"]
+		jvmTarget = "17"
+	}
+}

+ 31 - 13
core/src/main/java/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandler.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2024 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -52,6 +52,7 @@ import org.springframework.util.Assert;
  *
  * @author Luke Taylor
  * @author Evgeniy Cheban
+ * @author Blagoja Stamatovski
  * @since 3.0
  */
 public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpressionHandler<MethodInvocation>
@@ -109,12 +110,13 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
 	}
 
 	/**
-	 * Filters the {@code filterTarget} object (which must be either a collection, array,
-	 * map or stream), by evaluating the supplied expression.
+	 * Filters the {@code filterTarget} object (which must be either a {@link Collection},
+	 * {@code Array}, {@link Map} or {@link Stream}), by evaluating the supplied
+	 * expression.
 	 * <p>
-	 * If a {@code Collection} or {@code Map} is used, the original instance will be
-	 * modified to contain the elements for which the permission expression evaluates to
-	 * {@code true}. For an array, a new array instance will be returned.
+	 * Returns new instances of the same type as the supplied {@code filterTarget} object
+	 * @return The filtered {@link Collection}, {@code Array}, {@link Map} or
+	 * {@link Stream}
 	 */
 	@Override
 	public Object filter(Object filterTarget, Expression filterExpression, EvaluationContext ctx) {
@@ -151,9 +153,17 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
 			}
 		}
 		this.logger.debug(LogMessage.format("Retaining elements: %s", retain));
-		filterTarget.clear();
-		filterTarget.addAll(retain);
-		return filterTarget;
+		try {
+			filterTarget.clear();
+			filterTarget.addAll(retain);
+			return filterTarget;
+		}
+		catch (UnsupportedOperationException unsupportedOperationException) {
+			this.logger.debug(LogMessage.format(
+					"Collection threw exception: %s. Will return a new instance instead of mutating its state.",
+					unsupportedOperationException.getMessage()));
+			return retain;
+		}
 	}
 
 	private Object filterArray(Object[] filterTarget, Expression filterExpression, EvaluationContext ctx,
@@ -178,7 +188,7 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
 		return filtered;
 	}
 
-	private <K, V> Object filterMap(final Map<K, V> filterTarget, Expression filterExpression, EvaluationContext ctx,
+	private <K, V> Object filterMap(Map<K, V> filterTarget, Expression filterExpression, EvaluationContext ctx,
 			MethodSecurityExpressionOperations rootObject) {
 		Map<K, V> retain = new LinkedHashMap<>(filterTarget.size());
 		this.logger.debug(LogMessage.format("Filtering map with %s elements", filterTarget.size()));
@@ -189,9 +199,17 @@ public class DefaultMethodSecurityExpressionHandler extends AbstractSecurityExpr
 			}
 		}
 		this.logger.debug(LogMessage.format("Retaining elements: %s", retain));
-		filterTarget.clear();
-		filterTarget.putAll(retain);
-		return filterTarget;
+		try {
+			filterTarget.clear();
+			filterTarget.putAll(retain);
+			return filterTarget;
+		}
+		catch (UnsupportedOperationException unsupportedOperationException) {
+			this.logger.debug(LogMessage.format(
+					"Map threw exception: %s. Will return a new instance instead of mutating its state.",
+					unsupportedOperationException.getMessage()));
+			return retain;
+		}
 	}
 
 	private Object filterStream(final Stream<?> filterTarget, Expression filterExpression, EvaluationContext ctx,

+ 236 - 0
core/src/test/kotlin/org/springframework/security/access/expression/method/DefaultMethodSecurityExpressionHandlerKotlinTests.kt

@@ -0,0 +1,236 @@
+/*
+ * Copyright 2002-2024 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.security.access.expression.method
+
+import io.mockk.every
+import io.mockk.mockk
+import org.aopalliance.intercept.MethodInvocation
+import org.assertj.core.api.Assertions.assertThat
+import org.junit.jupiter.api.BeforeEach
+import org.junit.jupiter.api.Test
+import org.springframework.expression.EvaluationContext
+import org.springframework.expression.Expression
+import org.springframework.security.core.Authentication
+import java.util.stream.Stream
+import kotlin.reflect.jvm.internal.impl.load.kotlin.JvmType
+import kotlin.reflect.jvm.javaMethod
+
+/**
+ * @author Blagoja Stamatovski
+ */
+class DefaultMethodSecurityExpressionHandlerKotlinTests {
+    private object Foo {
+        fun bar() {
+        }
+    }
+
+    private lateinit var authentication: Authentication
+    private lateinit var methodInvocation: MethodInvocation
+
+    private val handler: MethodSecurityExpressionHandler = DefaultMethodSecurityExpressionHandler()
+
+    @BeforeEach
+    fun setUp()  {
+        authentication = mockk()
+        methodInvocation = mockk()
+
+        every { methodInvocation.`this` } returns { Foo }
+        every { methodInvocation.method } answers { Foo::bar.javaMethod!! }
+        every { methodInvocation.arguments } answers { arrayOf<JvmType.Object>() }
+    }
+
+    @Test
+    fun `filters non-empty maps`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject.key eq 'key2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val nonEmptyMap: Map<String, String> = mapOf(
+            "key1" to "value1",
+            "key2" to "value2",
+            "key3" to "value3",
+        )
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ nonEmptyMap,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Map::class.java)
+        val result = (filtered as Map<String, String>)
+        assertThat(result).hasSize(1)
+        assertThat(result).containsKey("key2")
+        assertThat(result).containsValue("value2")
+    }
+
+    @Test
+    fun `filters empty maps`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject.key eq 'key2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val emptyMap: Map<String, String> = emptyMap()
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ emptyMap,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Map::class.java)
+        val result = (filtered as Map<String, String>)
+        assertThat(result).hasSize(0)
+    }
+
+    @Test
+    fun `filters non-empty collections`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val nonEmptyCollection: Collection<String> = listOf(
+            "string1",
+            "string2",
+            "string1",
+        )
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ nonEmptyCollection,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Collection::class.java)
+        val result = (filtered as Collection<String>)
+        assertThat(result).hasSize(1)
+        assertThat(result).contains("string2")
+    }
+
+    @Test
+    fun `filters empty collections`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val emptyCollection: Collection<String> = emptyList()
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ emptyCollection,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Collection::class.java)
+        val result = (filtered as Collection<String>)
+        assertThat(result).hasSize(0)
+    }
+
+    @Test
+    fun `filters non-empty arrays`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val nonEmptyArray: Array<String> = arrayOf(
+            "string1",
+            "string2",
+            "string1",
+        )
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ nonEmptyArray,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Array<String>::class.java)
+        val result = (filtered as Array<String>)
+        assertThat(result).hasSize(1)
+        assertThat(result).contains("string2")
+    }
+
+    @Test
+    fun `filters empty arrays`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val emptyArray: Array<String> = emptyArray()
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ emptyArray,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Array<String>::class.java)
+        val result = (filtered as Array<String>)
+        assertThat(result).hasSize(0)
+    }
+
+    @Test
+    fun `filters non-empty streams`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val nonEmptyStream: Stream<String> = listOf(
+            "string1",
+            "string2",
+            "string1",
+        ).stream()
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ nonEmptyStream,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Stream::class.java)
+        val result = (filtered as Stream<String>).toList()
+        assertThat(result).hasSize(1)
+        assertThat(result).contains("string2")
+    }
+
+    @Test
+    fun `filters empty streams`() {
+        val expression: Expression = handler.expressionParser.parseExpression("filterObject eq 'string2'")
+        val context: EvaluationContext = handler.createEvaluationContext(
+            /* authentication = */ authentication,
+            /* invocation = */ methodInvocation,
+        )
+        val emptyStream: Stream<String> = emptyList<String>().stream()
+
+        val filtered: Any = handler.filter(
+            /* filterTarget = */ emptyStream,
+            /* filterExpression = */ expression,
+            /* ctx = */ context,
+        )
+
+        assertThat(filtered).isInstanceOf(Stream::class.java)
+        val result = (filtered as Stream<String>).toList()
+        assertThat(result).hasSize(0)
+    }
+}

+ 106 - 8
docs/modules/ROOT/pages/servlet/authorization/method-security.adoc

@@ -546,9 +546,6 @@ If not, Spring Security will throw an `AccessDeniedException` and return a 403 s
 [[use-prefilter]]
 === Filtering Method Parameters with `@PreFilter`
 
-[NOTE]
-`@PreFilter` is not yet supported for Kotlin-specific data types; for that reason, only Java snippets are shown
-
 When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PreFilter.html[`@PreFilter`] annotation like so:
 
 [tabs]
@@ -566,6 +563,20 @@ public class BankService {
 	}
 }
 ----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Component
+open class BankService {
+	@PreFilter("filterObject.owner == authentication.name")
+	fun updateAccounts(vararg accounts: Account): Collection<Account> {
+        // ... `accounts` will only contain the accounts owned by the logged-in user
+        return updated
+	}
+}
+----
 ======
 
 This is meant to filter out any values from `accounts` where the expression `filterObject.owner == authentication.name` fails.
@@ -591,6 +602,23 @@ void updateAccountsWhenOwnedThenReturns() {
     assertThat(updated).containsOnly(ownedBy);
 }
 ----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Autowired
+lateinit var bankService: BankService
+
+@WithMockUser(username="owner")
+@Test
+fun updateAccountsWhenOwnedThenReturns() {
+    val ownedBy: Account = ...
+    val notOwnedBy: Account = ...
+    val updated: Collection<Account> = bankService.updateAccounts(ownedBy, notOwnedBy)
+    assertThat(updated).containsOnly(ownedBy)
+}
+----
 ======
 
 [TIP]
@@ -618,6 +646,23 @@ public Collection<Account> updateAccounts(Map<String, Account> accounts)
 @PreFilter("filterObject.owner == authentication.name")
 public Collection<Account> updateAccounts(Stream<Account> accounts)
 ----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@PreFilter("filterObject.owner == authentication.name")
+fun updateAccounts(accounts: Array<Account>): Collection<Account>
+
+@PreFilter("filterObject.owner == authentication.name")
+fun updateAccounts(accounts: Collection<Account>): Collection<Account>
+
+@PreFilter("filterObject.value.owner == authentication.name")
+fun updateAccounts(accounts: Map<String, Account>): Collection<Account>
+
+@PreFilter("filterObject.owner == authentication.name")
+fun updateAccounts(accounts: Stream<Account>): Collection<Account>
+----
 ======
 
 The result is that the above method will only have the `Account` instances where their `owner` attribute matches the logged-in user's `name`.
@@ -625,9 +670,6 @@ The result is that the above method will only have the `Account` instances where
 [[use-postfilter]]
 === Filtering Method Results with `@PostFilter`
 
-[NOTE]
-`@PostFilter` is not yet supported for Kotlin-specific data types; for that reason, only Java snippets are shown
-
 When Method Security is active, you can annotate a method with the {security-api-url}org/springframework/security/access/prepost/PostFilter.html[`@PostFilter`] annotation like so:
 
 [tabs]
@@ -645,6 +687,20 @@ public class BankService {
 	}
 }
 ----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Component
+open class BankService {
+	@PreFilter("filterObject.owner == authentication.name")
+	fun readAccounts(vararg ids: String): Collection<Account> {
+        // ... the return value will be filtered to only contain the accounts owned by the logged-in user
+        return accounts
+	}
+}
+----
 ======
 
 This is meant to filter out any values from the return value where the expression `filterObject.owner == authentication.name` fails.
@@ -669,6 +725,22 @@ void readAccountsWhenOwnedThenReturns() {
     assertThat(accounts.get(0).getOwner()).isEqualTo("owner");
 }
 ----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@Autowired
+lateinit var bankService: BankService
+
+@WithMockUser(username="owner")
+@Test
+fun readAccountsWhenOwnedThenReturns() {
+    val accounts: Collection<Account> = bankService.updateAccounts("owner", "not-owner")
+    assertThat(accounts).hasSize(1)
+    assertThat(accounts[0].owner).isEqualTo("owner")
+}
+----
 ======
 
 [TIP]
@@ -678,7 +750,15 @@ void readAccountsWhenOwnedThenReturns() {
 
 For example, the above `readAccounts` declaration will function the same way as the following other three:
 
-```java
+[tabs]
+======
+Java::
++
+[source,java,role="primary"]
+----
+@PostFilter("filterObject.owner == authentication.name")
+public Collection<Account> readAccounts(String... ids)
+
 @PostFilter("filterObject.owner == authentication.name")
 public Account[] readAccounts(String... ids)
 
@@ -687,7 +767,25 @@ public Map<String, Account> readAccounts(String... ids)
 
 @PostFilter("filterObject.owner == authentication.name")
 public Stream<Account> readAccounts(String... ids)
-```
+----
+
+Kotlin::
++
+[source,kotlin,role="secondary"]
+----
+@PostFilter("filterObject.owner == authentication.name")
+fun readAccounts(vararg ids: String): Collection<Account>
+
+@PostFilter("filterObject.owner == authentication.name")
+fun readAccounts(vararg ids: String): Array<Account>
+
+@PostFilter("filterObject.owner == authentication.name")
+fun readAccounts(vararg ids: String): Map<String, Account>
+
+@PostFilter("filterObject.owner == authentication.name")
+fun readAccounts(vararg ids: String): Stream<Account>
+----
+======
 
 The result is that the above method will return the `Account` instances where their `owner` attribute matches the logged-in user's `name`.