浏览代码

Test AuthorizeReturnObject in Reactive

Issue gh-14597
Josh Cummings 11 月之前
父节点
当前提交
fee5dd30c0

+ 222 - 0
config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostReactiveMethodSecurityConfigurationTests.java

@@ -18,6 +18,10 @@ package org.springframework.security.config.annotation.method.configuration;
 
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
 
 import org.junit.jupiter.api.Test;
 import org.junit.jupiter.api.extension.ExtendWith;
@@ -34,12 +38,18 @@ import org.springframework.context.annotation.Role;
 import org.springframework.security.access.AccessDeniedException;
 import org.springframework.security.access.PermissionEvaluator;
 import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler;
+import org.springframework.security.access.expression.method.MethodSecurityExpressionHandler;
+import org.springframework.security.access.hierarchicalroles.RoleHierarchy;
+import org.springframework.security.access.hierarchicalroles.RoleHierarchyImpl;
 import org.springframework.security.access.prepost.PostAuthorize;
 import org.springframework.security.access.prepost.PostFilter;
 import org.springframework.security.access.prepost.PreAuthorize;
 import org.springframework.security.access.prepost.PreFilter;
 import org.springframework.security.authorization.AuthorizationDeniedException;
+import org.springframework.security.authorization.method.AuthorizationAdvisorProxyFactory;
+import org.springframework.security.authorization.method.AuthorizeReturnObject;
 import org.springframework.security.authorization.method.PrePostTemplateDefaults;
+import org.springframework.security.config.Customizer;
 import org.springframework.security.config.test.SpringTestContext;
 import org.springframework.security.config.test.SpringTestContextExtension;
 import org.springframework.security.core.annotation.AnnotationTemplateExpressionDefaults;
@@ -49,6 +59,7 @@ import org.springframework.test.context.junit.jupiter.SpringExtension;
 
 import static org.assertj.core.api.Assertions.assertThat;
 import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatNoException;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.eq;
 import static org.mockito.BDDMockito.given;
@@ -320,6 +331,84 @@ public class PrePostReactiveMethodSecurityConfigurationTests {
 			.containsExactly("dave");
 	}
 
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findByIdWhenAuthorizedResultThenAuthorizes() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		Flight flight = flights.findById("1").block();
+		assertThatNoException().isThrownBy(flight::getAltitude);
+		assertThatNoException().isThrownBy(flight::getSeats);
+	}
+
+	@Test
+	@WithMockUser(authorities = "seating:read")
+	public void findByIdWhenUnauthorizedResultThenDenies() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		Flight flight = flights.findById("1").block();
+		assertThatNoException().isThrownBy(flight::getSeats);
+		assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> flight.getAltitude().block());
+	}
+
+	@Test
+	@WithMockUser(authorities = "seating:read")
+	public void findAllWhenUnauthorizedResultThenDenies() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findAll().collectList().block().forEach((flight) -> {
+			assertThatNoException().isThrownBy(flight::getSeats);
+			assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(() -> flight.getAltitude().block());
+		});
+	}
+
+	@Test
+	public void removeWhenAuthorizedResultThenRemoves() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.remove("1");
+	}
+
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findAllWhenPostFilterThenFilters() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findAll()
+			.collectList()
+			.block()
+			.forEach((flight) -> assertThat(flight.getPassengers().collectList().block())
+				.extracting((p) -> p.getName().block())
+				.doesNotContain("Kevin Mitnick"));
+	}
+
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findAllWhenPreFilterThenFilters() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findAll().collectList().block().forEach((flight) -> {
+			flight.board(Flux.just("John")).block();
+			assertThat(flight.getPassengers().collectList().block()).extracting((p) -> p.getName().block())
+				.doesNotContain("John");
+			flight.board(Flux.just("John Doe")).block();
+			assertThat(flight.getPassengers().collectList().block()).extracting((p) -> p.getName().block())
+				.contains("John Doe");
+		});
+	}
+
+	@Test
+	@WithMockUser(authorities = "seating:read")
+	public void findAllWhenNestedPreAuthorizeThenAuthorizes() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findAll().collectList().block().forEach((flight) -> {
+			List<Passenger> passengers = flight.getPassengers().collectList().block();
+			passengers.forEach((passenger) -> assertThatExceptionOfType(AccessDeniedException.class)
+				.isThrownBy(() -> passenger.getName().block()));
+		});
+	}
+
 	@Configuration
 	@EnableReactiveMethodSecurity
 	static class MethodSecurityServiceEnabledConfig {
@@ -484,4 +573,137 @@ public class PrePostReactiveMethodSecurityConfigurationTests {
 
 	}
 
+	@EnableReactiveMethodSecurity
+	@Configuration
+	public static class AuthorizeResultConfig {
+
+		@Bean
+		@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
+		static Customizer<AuthorizationAdvisorProxyFactory> skipValueTypes() {
+			return (f) -> f.setTargetVisitor(AuthorizationAdvisorProxyFactory.TargetVisitor.defaultsSkipValueTypes());
+		}
+
+		@Bean
+		FlightRepository flights() {
+			FlightRepository flights = new FlightRepository();
+			Flight one = new Flight("1", 35000d, 35);
+			one.board(Flux.just("Marie Curie", "Kevin Mitnick", "Ada Lovelace")).block();
+			flights.save(one).block();
+			Flight two = new Flight("2", 32000d, 72);
+			two.board(Flux.just("Albert Einstein")).block();
+			flights.save(two).block();
+			return flights;
+		}
+
+		@Bean
+		static MethodSecurityExpressionHandler expressionHandler() {
+			RoleHierarchy hierarchy = RoleHierarchyImpl.withRolePrefix("")
+				.role("airplane:read")
+				.implies("seating:read")
+				.build();
+			DefaultMethodSecurityExpressionHandler expressionHandler = new DefaultMethodSecurityExpressionHandler();
+			expressionHandler.setRoleHierarchy(hierarchy);
+			return expressionHandler;
+		}
+
+		@Bean
+		Authz authz() {
+			return new Authz();
+		}
+
+		public static class Authz {
+
+			public Mono<Boolean> isNotKevinMitnick(Passenger passenger) {
+				return passenger.getName().map((n) -> !"Kevin Mitnick".equals(n));
+			}
+
+		}
+
+	}
+
+	@AuthorizeReturnObject
+	static class FlightRepository {
+
+		private final Map<String, Flight> flights = new ConcurrentHashMap<>();
+
+		Flux<Flight> findAll() {
+			return Flux.fromIterable(this.flights.values());
+		}
+
+		Mono<Flight> findById(String id) {
+			return Mono.just(this.flights.get(id));
+		}
+
+		Mono<Flight> save(Flight flight) {
+			this.flights.put(flight.getId(), flight);
+			return Mono.just(flight);
+		}
+
+		Mono<Void> remove(String id) {
+			this.flights.remove(id);
+			return Mono.empty();
+		}
+
+	}
+
+	@AuthorizeReturnObject
+	static class Flight {
+
+		private final String id;
+
+		private final Double altitude;
+
+		private final Integer seats;
+
+		private final List<Passenger> passengers = new ArrayList<>();
+
+		Flight(String id, Double altitude, Integer seats) {
+			this.id = id;
+			this.altitude = altitude;
+			this.seats = seats;
+		}
+
+		String getId() {
+			return this.id;
+		}
+
+		@PreAuthorize("hasAuthority('airplane:read')")
+		Mono<Double> getAltitude() {
+			return Mono.just(this.altitude);
+		}
+
+		@PreAuthorize("hasAuthority('seating:read')")
+		Mono<Integer> getSeats() {
+			return Mono.just(this.seats);
+		}
+
+		@PostAuthorize("hasAuthority('seating:read')")
+		@PostFilter("@authz.isNotKevinMitnick(filterObject)")
+		Flux<Passenger> getPassengers() {
+			return Flux.fromIterable(this.passengers);
+		}
+
+		@PreAuthorize("hasAuthority('seating:read')")
+		@PreFilter("filterObject.contains(' ')")
+		Mono<Void> board(Flux<String> passengers) {
+			return passengers.doOnNext((passenger) -> this.passengers.add(new Passenger(passenger))).then(Mono.empty());
+		}
+
+	}
+
+	public static class Passenger {
+
+		String name;
+
+		public Passenger(String name) {
+			this.name = name;
+		}
+
+		@PreAuthorize("hasAuthority('airplane:read')")
+		public Mono<String> getName() {
+			return Mono.just(this.name);
+		}
+
+	}
+
 }