2
0
Эх сурвалжийг харах

Do Not Invalidate Current Session When It Is Registered

Closes gh-15066
Joaquin Santana 1 жил өмнө
parent
commit
927840fe88

+ 3 - 0
web/src/main/java/org/springframework/security/web/server/authentication/InvalidateLeastUsedServerMaximumSessionsExceededHandler.java

@@ -24,6 +24,7 @@ import reactor.core.publisher.Flux;
 import reactor.core.publisher.Mono;
 
 import org.springframework.security.core.session.ReactiveSessionInformation;
+import org.springframework.util.Assert;
 import org.springframework.web.server.session.WebSessionStore;
 
 /**
@@ -42,12 +43,14 @@ public final class InvalidateLeastUsedServerMaximumSessionsExceededHandler
 	private final WebSessionStore webSessionStore;
 
 	public InvalidateLeastUsedServerMaximumSessionsExceededHandler(WebSessionStore webSessionStore) {
+		Assert.notNull(webSessionStore, "webSessionStore cannot be null");
 		this.webSessionStore = webSessionStore;
 	}
 
 	@Override
 	public Mono<Void> handle(MaximumSessionsContext context) {
 		List<ReactiveSessionInformation> sessions = new ArrayList<>(context.getSessions());
+		sessions.removeIf((session) -> session.getSessionId().equals(context.getCurrentSession().getId()));
 		sessions.sort(Comparator.comparing(ReactiveSessionInformation::getLastAccessTime));
 		int maximumSessionsExceededBy = sessions.size() - context.getMaximumSessionsAllowed() + 1;
 		List<ReactiveSessionInformation> leastRecentlyUsedSessionsToInvalidate = sessions.subList(0,

+ 41 - 5
web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java

@@ -23,6 +23,7 @@ import org.junit.jupiter.api.BeforeEach;
 import org.junit.jupiter.api.Test;
 import reactor.core.publisher.Mono;
 
+import org.springframework.mock.web.server.MockWebSession;
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.session.ReactiveSessionInformation;
 import org.springframework.security.web.server.authentication.InvalidateLeastUsedServerMaximumSessionsExceededHandler;
@@ -34,6 +35,7 @@ import static org.mockito.BDDMockito.given;
 import static org.mockito.Mockito.atLeastOnce;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.verifyNoMoreInteractions;
 
@@ -59,18 +61,21 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests {
 		ReactiveSessionInformation session2 = mock(ReactiveSessionInformation.class);
 		given(session1.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760010L));
 		given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L));
+		given(session1.getSessionId()).willReturn("session1");
 		given(session2.getSessionId()).willReturn("session2");
 		given(session2.invalidate()).willReturn(Mono.empty());
 
 		MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
-				List.of(session1, session2), 2, null);
+				List.of(session1, session2), 2, createWebSession());
 
 		this.handler.handle(context).block();
 
 		verify(session2).invalidate();
 		verify(session1).getLastAccessTime(); // used by comparator to sort the sessions
 		verify(session2).getLastAccessTime(); // used by comparator to sort the sessions
-		verify(session2).getSessionId(); // used to invalidate session against the
+		verify(session1).getSessionId();
+		verify(session2, times(2)).getSessionId(); // used to invalidate session against
+													// the
 		// WebSessionStore
 		verify(this.webSessionStore).removeSession("session2");
 		verifyNoMoreInteractions(this.webSessionStore);
@@ -90,16 +95,18 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests {
 		given(session2.invalidate()).willReturn(Mono.empty());
 		given(session1.getSessionId()).willReturn("session1");
 		given(session2.getSessionId()).willReturn("session2");
+		given(session3.getSessionId()).willReturn("session3");
 
 		MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
-				List.of(session1, session2, session3), 2, null);
+				List.of(session1, session2, session3), 2, createWebSession());
 		this.handler.handle(context).block();
 
 		// @formatter:off
 		verify(session1).invalidate();
 		verify(session2).invalidate();
-		verify(session1).getSessionId();
-		verify(session2).getSessionId();
+		verify(session1, times(2)).getSessionId();
+		verify(session2, times(2)).getSessionId();
+		verify(session3).getSessionId();
 		verify(session1, atLeastOnce()).getLastAccessTime(); // used by comparator to sort the sessions
 		verify(session2, atLeastOnce()).getLastAccessTime(); // used by comparator to sort the sessions
 		verify(session3, atLeastOnce()).getLastAccessTime(); // used by comparator to sort the sessions
@@ -112,4 +119,33 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests {
 		// @formatter:on
 	}
 
+	@Test
+	void handleWhenCurrentSessionIsRegisteredThenDoNotInvalidateCurrentSession() {
+		ReactiveSessionInformation session1 = mock(ReactiveSessionInformation.class);
+		ReactiveSessionInformation session2 = mock(ReactiveSessionInformation.class);
+		MockWebSession currentSession = createWebSession();
+		given(session1.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760010L));
+		given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L));
+		given(session1.getSessionId()).willReturn("session1");
+		given(session2.getSessionId()).willReturn(currentSession.getId());
+		given(session1.invalidate()).willReturn(Mono.empty());
+
+		MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class),
+				List.of(session1, session2), 1, currentSession);
+
+		this.handler.handle(context).block();
+
+		verify(session1).invalidate();
+		verify(session2).getSessionId();
+		verify(session1, times(2)).getSessionId();
+		verify(this.webSessionStore).removeSession("session1");
+		verifyNoMoreInteractions(this.webSessionStore);
+		verifyNoMoreInteractions(session2);
+		verifyNoMoreInteractions(session1);
+	}
+
+	private MockWebSession createWebSession() {
+		return new MockWebSession();
+	}
+
 }