Browse Source

Support Spring Data container types for AuthorizeReturnObject

Closes gh-15994

Signed-off-by: Evgeniy Cheban <mister.cheban@gmail.com>
Evgeniy Cheban 4 months ago
parent
commit
fd4f06a66e

+ 51 - 1
config/src/main/java/org/springframework/security/config/annotation/method/configuration/AuthorizationProxyDataConfiguration.java

@@ -1,5 +1,5 @@
 /*
- * Copyright 2002-2024 the original author or authors.
+ * Copyright 2002-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -16,13 +16,22 @@
 
 package org.springframework.security.config.annotation.method.configuration;
 
+import java.util.List;
+
 import org.springframework.aop.framework.AopInfrastructureBean;
 import org.springframework.beans.factory.config.BeanDefinition;
 import org.springframework.context.annotation.Bean;
 import org.springframework.context.annotation.Configuration;
 import org.springframework.context.annotation.Role;
+import org.springframework.core.Ordered;
+import org.springframework.data.domain.PageImpl;
+import org.springframework.data.domain.SliceImpl;
+import org.springframework.data.geo.GeoPage;
+import org.springframework.data.geo.GeoResult;
+import org.springframework.data.geo.GeoResults;
 import org.springframework.security.aot.hint.SecurityHintsRegistrar;
 import org.springframework.security.authorization.AuthorizationProxyFactory;
+import org.springframework.security.authorization.method.AuthorizationAdvisorProxyFactory;
 import org.springframework.security.data.aot.hint.AuthorizeReturnObjectDataHintsRegistrar;
 
 @Configuration(proxyBeanMethods = false)
@@ -34,4 +43,45 @@ final class AuthorizationProxyDataConfiguration implements AopInfrastructureBean
 		return new AuthorizeReturnObjectDataHintsRegistrar(proxyFactory);
 	}
 
+	@Bean
+	@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
+	DataTargetVisitor dataTargetVisitor() {
+		return new DataTargetVisitor();
+	}
+
+	private static final class DataTargetVisitor implements AuthorizationAdvisorProxyFactory.TargetVisitor, Ordered {
+
+		private static final int DEFAULT_ORDER = 200;
+
+		@Override
+		public Object visit(AuthorizationAdvisorProxyFactory proxyFactory, Object target) {
+			if (target instanceof GeoResults<?> geoResults) {
+				return new GeoResults<>(proxyFactory.proxy(geoResults.getContent()), geoResults.getAverageDistance());
+			}
+			if (target instanceof GeoResult<?> geoResult) {
+				return new GeoResult<>(proxyFactory.proxy(geoResult.getContent()), geoResult.getDistance());
+			}
+			if (target instanceof GeoPage<?> geoPage) {
+				GeoResults<?> results = new GeoResults<>(proxyFactory.proxy(geoPage.getContent()),
+						geoPage.getAverageDistance());
+				return new GeoPage<>(results, geoPage.getPageable(), geoPage.getTotalElements());
+			}
+			if (target instanceof PageImpl<?> page) {
+				List<?> content = proxyFactory.proxy(page.getContent());
+				return new PageImpl<>(content, page.getPageable(), page.getTotalElements());
+			}
+			if (target instanceof SliceImpl<?> slice) {
+				List<?> content = proxyFactory.proxy(slice.getContent());
+				return new SliceImpl<>(content, slice.getPageable(), slice.hasNext());
+			}
+			return null;
+		}
+
+		@Override
+		public int getOrder() {
+			return DEFAULT_ORDER;
+		}
+
+	}
+
 }

+ 101 - 11
config/src/test/java/org/springframework/security/config/annotation/method/configuration/PrePostMethodSecurityConfigurationTests.java

@@ -63,7 +63,14 @@ import org.springframework.context.annotation.Role;
 import org.springframework.context.event.EventListener;
 import org.springframework.core.annotation.AnnotationAwareOrderComparator;
 import org.springframework.core.annotation.AnnotationConfigurationException;
-import org.springframework.core.annotation.Order;
+import org.springframework.data.domain.Page;
+import org.springframework.data.domain.PageImpl;
+import org.springframework.data.domain.Slice;
+import org.springframework.data.domain.SliceImpl;
+import org.springframework.data.geo.Distance;
+import org.springframework.data.geo.GeoPage;
+import org.springframework.data.geo.GeoResult;
+import org.springframework.data.geo.GeoResults;
 import org.springframework.http.HttpStatus;
 import org.springframework.http.HttpStatusCode;
 import org.springframework.http.MediaType;
@@ -756,6 +763,28 @@ public class PrePostMethodSecurityConfigurationTests {
 		assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude);
 	}
 
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findGeoResultByIdWhenAuthorizedResultThenAuthorizes() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		GeoResult<Flight> geoResultFlight = flights.findGeoResultFlightById("1");
+		Flight flight = geoResultFlight.getContent();
+		assertThatNoException().isThrownBy(flight::getAltitude);
+		assertThatNoException().isThrownBy(flight::getSeats);
+	}
+
+	@Test
+	@WithMockUser(authorities = "seating:read")
+	public void findGeoResultByIdWhenUnauthorizedResultThenDenies() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		GeoResult<Flight> geoResultFlight = flights.findGeoResultFlightById("1");
+		Flight flight = geoResultFlight.getContent();
+		assertThatNoException().isThrownBy(flight::getSeats);
+		assertThatExceptionOfType(AccessDeniedException.class).isThrownBy(flight::getAltitude);
+	}
+
 	@Test
 	@WithMockUser(authorities = "airplane:read")
 	public void findByIdWhenAuthorizedResponseEntityThenAuthorizes() {
@@ -827,6 +856,46 @@ public class PrePostMethodSecurityConfigurationTests {
 				.doesNotContain("Kevin Mitnick"));
 	}
 
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findPageWhenPostFilterThenFilters() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findPage()
+			.forEach((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName)
+				.doesNotContain("Kevin Mitnick"));
+	}
+
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findSliceWhenPostFilterThenFilters() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findSlice()
+			.forEach((flight) -> assertThat(flight.getPassengers()).extracting(Passenger::getName)
+				.doesNotContain("Kevin Mitnick"));
+	}
+
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findGeoPageWhenPostFilterThenFilters() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findGeoPage()
+			.forEach((flight) -> assertThat(flight.getContent().getPassengers()).extracting(Passenger::getName)
+				.doesNotContain("Kevin Mitnick"));
+	}
+
+	@Test
+	@WithMockUser(authorities = "airplane:read")
+	public void findGeoResultsWhenPostFilterThenFilters() {
+		this.spring.register(AuthorizeResultConfig.class).autowire();
+		FlightRepository flights = this.spring.getContext().getBean(FlightRepository.class);
+		flights.findGeoResults()
+			.forEach((flight) -> assertThat(flight.getContent().getPassengers()).extracting(Passenger::getName)
+				.doesNotContain("Kevin Mitnick"));
+	}
+
 	@Test
 	@WithMockUser(authorities = "airplane:read")
 	public void findAllWhenPreFilterThenFilters() {
@@ -1762,16 +1831,8 @@ public class PrePostMethodSecurityConfigurationTests {
 
 		@Bean
 		@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
-		@Order(1)
-		static TargetVisitor mock() {
-			return Mockito.mock(TargetVisitor.class);
-		}
-
-		@Bean
-		@Role(BeanDefinition.ROLE_INFRASTRUCTURE)
-		@Order(0)
-		static TargetVisitor skipValueTypes() {
-			return TargetVisitor.defaultsSkipValueTypes();
+		static TargetVisitor customTargetVisitor() {
+			return TargetVisitor.of(Mockito.mock(), TargetVisitor.defaultsSkipValueTypes());
 		}
 
 		@Bean
@@ -1802,10 +1863,39 @@ public class PrePostMethodSecurityConfigurationTests {
 			return this.flights.values().iterator();
 		}
 
+		Page<Flight> findPage() {
+			return new PageImpl<>(new ArrayList<>(this.flights.values()));
+		}
+
+		Slice<Flight> findSlice() {
+			return new SliceImpl<>(new ArrayList<>(this.flights.values()));
+		}
+
+		GeoPage<Flight> findGeoPage() {
+			List<GeoResult<Flight>> results = new ArrayList<>();
+			for (Flight flight : this.flights.values()) {
+				results.add(new GeoResult<>(flight, new Distance(flight.altitude)));
+			}
+			return new GeoPage<>(new GeoResults<>(results));
+		}
+
+		GeoResults<Flight> findGeoResults() {
+			List<GeoResult<Flight>> results = new ArrayList<>();
+			for (Flight flight : this.flights.values()) {
+				results.add(new GeoResult<>(flight, new Distance(flight.altitude)));
+			}
+			return new GeoResults<>(results);
+		}
+
 		Flight findById(String id) {
 			return this.flights.get(id);
 		}
 
+		GeoResult<Flight> findGeoResultFlightById(String id) {
+			Flight flight = this.flights.get(id);
+			return new GeoResult<>(flight, new Distance(flight.altitude));
+		}
+
 		Flight save(Flight flight) {
 			this.flights.put(flight.getId(), flight);
 			return flight;