|  | @@ -1,5 +1,5 @@
 | 
	
		
			
				|  |  |  /*
 | 
	
		
			
				|  |  | - * Copyright 2002-2019 the original author or authors.
 | 
	
		
			
				|  |  | + * Copyright 2002-2024 the original author or authors.
 | 
	
		
			
				|  |  |   *
 | 
	
		
			
				|  |  |   * Licensed under the Apache License, Version 2.0 (the "License");
 | 
	
		
			
				|  |  |   * you may not use this file except in compliance with the License.
 | 
	
	
		
			
				|  | @@ -16,20 +16,36 @@
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  package org.springframework.security.config.annotation.method.configuration;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +import java.util.ArrayList;
 | 
	
		
			
				|  |  | +import java.util.List;
 | 
	
		
			
				|  |  | +import java.util.Map;
 | 
	
		
			
				|  |  | +import java.util.concurrent.ConcurrentHashMap;
 | 
	
		
			
				|  |  | +import java.util.function.Function;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  import org.junit.jupiter.api.Test;
 | 
	
		
			
				|  |  |  import org.junit.jupiter.api.extension.ExtendWith;
 | 
	
		
			
				|  |  | +import reactor.core.publisher.Flux;
 | 
	
		
			
				|  |  | +import reactor.core.publisher.Mono;
 | 
	
		
			
				|  |  | +import reactor.test.StepVerifier;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import org.springframework.beans.factory.annotation.Autowired;
 | 
	
		
			
				|  |  |  import org.springframework.context.annotation.Bean;
 | 
	
		
			
				|  |  |  import org.springframework.context.annotation.Configuration;
 | 
	
		
			
				|  |  |  import org.springframework.expression.EvaluationContext;
 | 
	
		
			
				|  |  | +import org.springframework.security.access.AccessDeniedException;
 | 
	
		
			
				|  |  |  import org.springframework.security.access.expression.SecurityExpressionRoot;
 | 
	
		
			
				|  |  |  import org.springframework.security.access.expression.method.DefaultMethodSecurityExpressionHandler;
 | 
	
		
			
				|  |  |  import org.springframework.security.access.intercept.method.MockMethodInvocation;
 | 
	
		
			
				|  |  | +import org.springframework.security.access.prepost.PostAuthorize;
 | 
	
		
			
				|  |  | +import org.springframework.security.access.prepost.PostFilter;
 | 
	
		
			
				|  |  | +import org.springframework.security.access.prepost.PreAuthorize;
 | 
	
		
			
				|  |  | +import org.springframework.security.access.prepost.PreFilter;
 | 
	
		
			
				|  |  |  import org.springframework.security.authentication.TestingAuthenticationToken;
 | 
	
		
			
				|  |  | +import org.springframework.security.authorization.method.AuthorizeReturnObject;
 | 
	
		
			
				|  |  |  import org.springframework.security.config.core.GrantedAuthorityDefaults;
 | 
	
		
			
				|  |  |  import org.springframework.security.config.test.SpringTestContext;
 | 
	
		
			
				|  |  |  import org.springframework.security.config.test.SpringTestContextExtension;
 | 
	
		
			
				|  |  | +import org.springframework.security.core.context.ReactiveSecurityContextHolder;
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import static org.assertj.core.api.Assertions.assertThat;
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -85,6 +101,112 @@ public class ReactiveMethodSecurityConfigurationTests {
 | 
	
		
			
				|  |  |  		assertThat(root.hasRole("ABC")).isTrue();
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void findByIdWhenAuthorizedResultThenAuthorizes() {
 | 
	
		
			
				|  |  | +		this.spring.register(AuthorizeResultConfig.class).autowire();
 | 
	
		
			
				|  |  | +		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
 | 
	
		
			
				|  |  | +		TestingAuthenticationToken pilot = new TestingAuthenticationToken("user", "pass", "airplane:read");
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findById("1")
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getAltitude)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.expectNextCount(1)
 | 
	
		
			
				|  |  | +			.verifyComplete();
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findById("1")
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getSeats)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.expectNextCount(1)
 | 
	
		
			
				|  |  | +			.verifyComplete();
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void findByIdWhenUnauthorizedResultThenDenies() {
 | 
	
		
			
				|  |  | +		this.spring.register(AuthorizeResultConfig.class).autowire();
 | 
	
		
			
				|  |  | +		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
 | 
	
		
			
				|  |  | +		TestingAuthenticationToken pilot = new TestingAuthenticationToken("user", "pass", "seating:read");
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findById("1")
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getSeats)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.expectNextCount(1)
 | 
	
		
			
				|  |  | +			.verifyComplete();
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findById("1")
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getAltitude)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.verifyError(AccessDeniedException.class);
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void findAllWhenUnauthorizedResultThenDenies() {
 | 
	
		
			
				|  |  | +		this.spring.register(AuthorizeResultConfig.class).autowire();
 | 
	
		
			
				|  |  | +		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
 | 
	
		
			
				|  |  | +		TestingAuthenticationToken pilot = new TestingAuthenticationToken("user", "pass", "seating:read");
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findAll()
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getSeats)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.expectNextCount(2)
 | 
	
		
			
				|  |  | +			.verifyComplete();
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findAll()
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getAltitude)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.verifyError(AccessDeniedException.class);
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void removeWhenAuthorizedResultThenRemoves() {
 | 
	
		
			
				|  |  | +		this.spring.register(AuthorizeResultConfig.class).autowire();
 | 
	
		
			
				|  |  | +		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
 | 
	
		
			
				|  |  | +		TestingAuthenticationToken pilot = new TestingAuthenticationToken("user", "pass", "seating:read");
 | 
	
		
			
				|  |  | +		StepVerifier.create(flights.remove("1").contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.verifyComplete();
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void findAllWhenPostFilterThenFilters() {
 | 
	
		
			
				|  |  | +		this.spring.register(AuthorizeResultConfig.class).autowire();
 | 
	
		
			
				|  |  | +		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
 | 
	
		
			
				|  |  | +		TestingAuthenticationToken pilot = new TestingAuthenticationToken("user", "pass", "airplane:read");
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findAll()
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getPassengers)
 | 
	
		
			
				|  |  | +				.flatMap(Passenger::getName)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.expectNext("Marie Curie", "Ada Lovelace", "Albert Einstein")
 | 
	
		
			
				|  |  | +			.verifyComplete();
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void findAllWhenPreFilterThenFilters() {
 | 
	
		
			
				|  |  | +		this.spring.register(AuthorizeResultConfig.class).autowire();
 | 
	
		
			
				|  |  | +		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
 | 
	
		
			
				|  |  | +		TestingAuthenticationToken pilot = new TestingAuthenticationToken("user", "pass", "airplane:read");
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findAll()
 | 
	
		
			
				|  |  | +				.flatMap((flight) -> flight.board(Flux.just("John Doe", "John")).then(Mono.just(flight)))
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getPassengers)
 | 
	
		
			
				|  |  | +				.flatMap(Passenger::getName)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.expectNext("Marie Curie", "Ada Lovelace", "John Doe", "Albert Einstein", "John Doe")
 | 
	
		
			
				|  |  | +			.verifyComplete();
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@Test
 | 
	
		
			
				|  |  | +	public void findAllWhenNestedPreAuthorizeThenAuthorizes() {
 | 
	
		
			
				|  |  | +		this.spring.register(AuthorizeResultConfig.class).autowire();
 | 
	
		
			
				|  |  | +		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
 | 
	
		
			
				|  |  | +		TestingAuthenticationToken pilot = new TestingAuthenticationToken("user", "pass", "seating:read");
 | 
	
		
			
				|  |  | +		StepVerifier
 | 
	
		
			
				|  |  | +			.create(flights.findAll()
 | 
	
		
			
				|  |  | +				.flatMap(Flight::getPassengers)
 | 
	
		
			
				|  |  | +				.flatMap(Passenger::getName)
 | 
	
		
			
				|  |  | +				.contextWrite(ReactiveSecurityContextHolder.withAuthentication(pilot)))
 | 
	
		
			
				|  |  | +			.verifyError(AccessDeniedException.class);
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  	@Configuration
 | 
	
		
			
				|  |  |  	@EnableReactiveMethodSecurity // this imports ReactiveMethodSecurityConfiguration
 | 
	
		
			
				|  |  |  	static class WithRolePrefixConfiguration {
 | 
	
	
		
			
				|  | @@ -108,4 +230,112 @@ public class ReactiveMethodSecurityConfigurationTests {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | +	@EnableReactiveMethodSecurity
 | 
	
		
			
				|  |  | +	@Configuration
 | 
	
		
			
				|  |  | +	static class AuthorizeResultConfig {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@Bean
 | 
	
		
			
				|  |  | +		FlightRepository flights() {
 | 
	
		
			
				|  |  | +			FlightRepository flights = new FlightRepository();
 | 
	
		
			
				|  |  | +			Flight one = new Flight("1", 35000d, 35);
 | 
	
		
			
				|  |  | +			one.board(Flux.just("Marie Curie", "Kevin Mitnick", "Ada Lovelace")).block();
 | 
	
		
			
				|  |  | +			flights.save(one).block();
 | 
	
		
			
				|  |  | +			Flight two = new Flight("2", 32000d, 72);
 | 
	
		
			
				|  |  | +			two.board(Flux.just("Albert Einstein")).block();
 | 
	
		
			
				|  |  | +			flights.save(two).block();
 | 
	
		
			
				|  |  | +			return flights;
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@Bean
 | 
	
		
			
				|  |  | +		Function<Passenger, Mono<Boolean>> isNotKevin() {
 | 
	
		
			
				|  |  | +			return (passenger) -> passenger.getName().map((name) -> !name.equals("Kevin Mitnick"));
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	@AuthorizeReturnObject
 | 
	
		
			
				|  |  | +	static class FlightRepository {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private final Map<String, Flight> flights = new ConcurrentHashMap<>();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		Flux<Flight> findAll() {
 | 
	
		
			
				|  |  | +			return Flux.fromIterable(this.flights.values());
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		Mono<Flight> findById(String id) {
 | 
	
		
			
				|  |  | +			return Mono.just(this.flights.get(id));
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		Mono<Flight> save(Flight flight) {
 | 
	
		
			
				|  |  | +			this.flights.put(flight.getId(), flight);
 | 
	
		
			
				|  |  | +			return Mono.just(flight);
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		Mono<Void> remove(String id) {
 | 
	
		
			
				|  |  | +			this.flights.remove(id);
 | 
	
		
			
				|  |  | +			return Mono.empty();
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	static class Flight {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private final String id;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private final Double altitude;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private final Integer seats;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		private final List<Passenger> passengers = new ArrayList<>();
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		Flight(String id, Double altitude, Integer seats) {
 | 
	
		
			
				|  |  | +			this.id = id;
 | 
	
		
			
				|  |  | +			this.altitude = altitude;
 | 
	
		
			
				|  |  | +			this.seats = seats;
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		String getId() {
 | 
	
		
			
				|  |  | +			return this.id;
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@PreAuthorize("hasAuthority('airplane:read')")
 | 
	
		
			
				|  |  | +		Mono<Double> getAltitude() {
 | 
	
		
			
				|  |  | +			return Mono.just(this.altitude);
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@PreAuthorize("hasAnyAuthority('seating:read', 'airplane:read')")
 | 
	
		
			
				|  |  | +		Mono<Integer> getSeats() {
 | 
	
		
			
				|  |  | +			return Mono.just(this.seats);
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@AuthorizeReturnObject
 | 
	
		
			
				|  |  | +		@PostAuthorize("hasAnyAuthority('seating:read', 'airplane:read')")
 | 
	
		
			
				|  |  | +		@PostFilter("@isNotKevin.apply(filterObject)")
 | 
	
		
			
				|  |  | +		Flux<Passenger> getPassengers() {
 | 
	
		
			
				|  |  | +			return Flux.fromIterable(this.passengers);
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@PreAuthorize("hasAnyAuthority('seating:read', 'airplane:read')")
 | 
	
		
			
				|  |  | +		@PreFilter("filterObject.contains(' ')")
 | 
	
		
			
				|  |  | +		Mono<Void> board(Flux<String> passengers) {
 | 
	
		
			
				|  |  | +			return passengers.doOnNext((passenger) -> this.passengers.add(new Passenger(passenger))).then();
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	public static class Passenger {
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		String name;
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		public Passenger(String name) {
 | 
	
		
			
				|  |  | +			this.name = name;
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +		@PreAuthorize("hasAuthority('airplane:read')")
 | 
	
		
			
				|  |  | +		public Mono<String> getName() {
 | 
	
		
			
				|  |  | +			return Mono.just(this.name);
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  | +	}
 | 
	
		
			
				|  |  | +
 | 
	
		
			
				|  |  |  }
 |