Parcourir la source

WebSessionReactiveSecurityRepository Supports Cache

Rob Winch il y a 3 ans
Parent
commit
1ef738ba34

+ 14 - 1
web/src/main/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepository.java

@@ -46,6 +46,8 @@ public class WebSessionServerSecurityContextRepository implements ServerSecurity
 
 	private String springSecurityContextAttrName = DEFAULT_SPRING_SECURITY_CONTEXT_ATTR_NAME;
 
+	private boolean cacheSecurityContext;
+
 	/**
 	 * Sets the session attribute name used to save and load the {@link SecurityContext}
 	 * @param springSecurityContextAttrName the session attribute name to use to save and
@@ -56,6 +58,16 @@ public class WebSessionServerSecurityContextRepository implements ServerSecurity
 		this.springSecurityContextAttrName = springSecurityContextAttrName;
 	}
 
+	/**
+	 * If set to true the result of {@link #load(ServerWebExchange)} will use
+	 * {@link Mono#cache()} to prevent multiple lookups.
+	 * @param cacheSecurityContext true if {@link Mono#cache()} should be used, else
+	 * false.
+	 */
+	public void setCacheSecurityContext(boolean cacheSecurityContext) {
+		this.cacheSecurityContext = cacheSecurityContext;
+	}
+
 	@Override
 	public Mono<Void> save(ServerWebExchange exchange, SecurityContext context) {
 		return exchange.getSession().doOnNext((session) -> {
@@ -72,13 +84,14 @@ public class WebSessionServerSecurityContextRepository implements ServerSecurity
 
 	@Override
 	public Mono<SecurityContext> load(ServerWebExchange exchange) {
-		return exchange.getSession().flatMap((session) -> {
+		Mono<SecurityContext> result = exchange.getSession().flatMap((session) -> {
 			SecurityContext context = (SecurityContext) session.getAttribute(this.springSecurityContextAttrName);
 			logger.debug((context != null)
 					? LogMessage.format("Found SecurityContext '%s' in WebSession: '%s'", context, session)
 					: LogMessage.format("No SecurityContext found in WebSession: '%s'", session));
 			return Mono.justOrEmpty(context);
 		});
+		return (cacheSecurityContext) ? result.cache() : result;
 	}
 
 }

+ 26 - 0
web/src/test/java/org/springframework/security/web/server/context/WebSessionServerSecurityContextRepositoryTests.java

@@ -17,14 +17,19 @@
 package org.springframework.security.web.server.context;
 
 import org.junit.jupiter.api.Test;
+import reactor.core.publisher.Mono;
+import reactor.test.publisher.PublisherProbe;
 
 import org.springframework.mock.http.server.reactive.MockServerHttpRequest;
 import org.springframework.mock.web.server.MockServerWebExchange;
 import org.springframework.security.core.context.SecurityContext;
 import org.springframework.security.core.context.SecurityContextImpl;
+import org.springframework.web.server.ServerWebExchange;
 import org.springframework.web.server.WebSession;
 
 import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.BDDMockito.given;
+import static org.mockito.Mockito.mock;
 
 /**
  * @author Rob Winch
@@ -79,4 +84,25 @@ public class WebSessionServerSecurityContextRepositoryTests {
 		assertThat(context).isNull();
 	}
 
+	@Test
+	public void loadWhenCacheSecurityContextThenSubscribeOnce() {
+		PublisherProbe<WebSession> webSession = PublisherProbe.empty();
+		ServerWebExchange exchange = mock(ServerWebExchange.class);
+		given(exchange.getSession()).willReturn(webSession.mono());
+		this.repository.setCacheSecurityContext(true);
+		Mono<SecurityContext> context = this.repository.load(exchange);
+		assertThat(context.block()).isSameAs(context.block());
+		assertThat(webSession.subscribeCount()).isEqualTo(1);
+	}
+
+	@Test
+	public void loadWhenNotCacheSecurityContextThenSubscribeMultiple() {
+		PublisherProbe<WebSession> webSession = PublisherProbe.empty();
+		ServerWebExchange exchange = mock(ServerWebExchange.class);
+		given(exchange.getSession()).willReturn(webSession.mono());
+		Mono<SecurityContext> context = this.repository.load(exchange);
+		assertThat(context.block()).isSameAs(context.block());
+		assertThat(webSession.subscribeCount()).isEqualTo(2);
+	}
+
 }