Browse Source

SEC-1012: Refactor SessionRegistry interface to use Java 5 generics.

Luke Taylor 16 years ago
parent
commit
ba6664f77f

+ 21 - 21
core/src/main/java/org/springframework/security/authentication/concurrent/ConcurrentSessionControllerImpl.java

@@ -15,6 +15,8 @@
 
 package org.springframework.security.authentication.concurrent;
 
+import java.util.List;
+
 import org.springframework.security.core.Authentication;
 import org.springframework.security.core.AuthenticationException;
 import org.springframework.security.core.SpringSecurityMessageSource;
@@ -29,14 +31,13 @@ import org.springframework.util.Assert;
 
 
 /**
- * Base implementation of {@link ConcurrentSessionControllerImpl} which prohibits simultaneous logins.<p>By default
- * uses {@link SessionRegistryImpl}, although any <code>SessionRegistry</code> may be used.</p>
+ * Base implementation of {@link ConcurrentSessionControllerImpl} which prohibits simultaneous logins.
  *
  * @author Ben Alex
  * @version $Id$
  */
 public class ConcurrentSessionControllerImpl implements ConcurrentSessionController, InitializingBean,
-    MessageSourceAware {
+        MessageSourceAware {
     //~ Instance fields ================================================================================================
 
     protected MessageSourceAccessor messages = SpringSecurityMessageSource.getAccessor();
@@ -61,10 +62,10 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl
      * @param allowableSessions DOCUMENT ME!
      * @param registry an instance of the <code>SessionRegistry</code> for subclass use
      *
-     * @throws ConcurrentLoginException DOCUMENT ME!
+     * @throws ConcurrentLoginException if the
      */
-    protected void allowableSessionsExceeded(String sessionId, SessionInformation[] sessions, int allowableSessions,
-        SessionRegistry registry) {
+    protected void allowableSessionsExceeded(String sessionId, List<SessionInformation> sessions, int allowableSessions,
+            SessionRegistry registry) {
         if (exceptionIfMaximumExceeded || (sessions == null)) {
             throw new ConcurrentLoginException(messages.getMessage("ConcurrentSessionControllerImpl.exceededAllowed",
                     new Object[] {new Integer(allowableSessions)},
@@ -74,30 +75,25 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl
         // Determine least recently used session, and mark it for invalidation
         SessionInformation leastRecentlyUsed = null;
 
-        for (int i = 0; i < sessions.length; i++) {
+        for (int i = 0; i < sessions.size(); i++) {
             if ((leastRecentlyUsed == null)
-                    || sessions[i].getLastRequest().before(leastRecentlyUsed.getLastRequest())) {
-                leastRecentlyUsed = sessions[i];
+                    || sessions.get(i).getLastRequest().before(leastRecentlyUsed.getLastRequest())) {
+                leastRecentlyUsed = sessions.get(i);
             }
         }
 
         leastRecentlyUsed.expireNow();
     }
 
-    public void checkAuthenticationAllowed(Authentication request)
-        throws AuthenticationException {
+    public void checkAuthenticationAllowed(Authentication request) throws AuthenticationException {
         Assert.notNull(request, "Authentication request cannot be null (violation of interface contract)");
 
         Object principal = SessionRegistryUtils.obtainPrincipalFromAuthentication(request);
         String sessionId = SessionRegistryUtils.obtainSessionIdFromAuthentication(request);
 
-        SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false);
-
-        int sessionCount = 0;
+        final List<SessionInformation> sessions = sessionRegistry.getAllSessions(principal, false);
 
-        if (sessions != null) {
-            sessionCount = sessions.length;
-        }
+        int sessionCount = sessions == null ? 0 : sessions.size();
 
         int allowableSessions = getMaximumSessionsForThisUser(request);
         Assert.isTrue(allowableSessions != 0, "getMaximumSessionsForThisUser() must return either -1 to allow "
@@ -106,13 +102,17 @@ public class ConcurrentSessionControllerImpl implements ConcurrentSessionControl
         if (sessionCount < allowableSessions) {
             // They haven't got too many login sessions running at present
             return;
-        } else if (allowableSessions == -1) {
+        }
+
+        if (allowableSessions == -1) {
             // We permit unlimited logins
             return;
-        } else if (sessionCount == allowableSessions) {
+        }
+
+        if (sessionCount == allowableSessions) {
             // Only permit it though if this request is associated with one of the sessions
-            for (int i = 0; i < sessionCount; i++) {
-                if (sessions[i].getSessionId().equals(sessionId)) {
+            for (SessionInformation si : sessions) {
+                if (si.getSessionId().equals(sessionId)) {
                     return;
                 }
             }

+ 4 - 2
core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistry.java

@@ -15,6 +15,8 @@
 
 package org.springframework.security.authentication.concurrent;
 
+import java.util.List;
+
 /**
  * Maintains a registry of <code>SessionInformation</code> instances.
  *
@@ -29,7 +31,7 @@ public interface SessionRegistry {
      *
      * @return each of the unique principals, which can then be presented to {@link #getAllSessions(Object, boolean)}.
      */
-    Object[] getAllPrincipals();
+    List<Object> getAllPrincipals();
 
     /**
      * Obtains all the known sessions for the specified principal. Sessions that have been destroyed are not
@@ -41,7 +43,7 @@ public interface SessionRegistry {
      *
      * @return the matching sessions for this principal, or <code>null</code> if none were found
      */
-    SessionInformation[] getAllSessions(Object principal, boolean includeExpiredSessions);
+    List<SessionInformation> getAllSessions(Object principal, boolean includeExpiredSessions);
 
     /**
      * Obtains the session information for the specified <code>sessionId</code>. Even expired sessions are

+ 7 - 6
core/src/main/java/org/springframework/security/authentication/concurrent/SessionRegistryImpl.java

@@ -16,6 +16,7 @@
 package org.springframework.security.authentication.concurrent;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collections;
 import java.util.Date;
 import java.util.HashMap;
@@ -57,18 +58,18 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
 
     // ~ Methods =======================================================================================================
 
-    public Object[] getAllPrincipals() {
-        return principals.keySet().toArray();
+    public List<Object> getAllPrincipals() {
+        return Arrays.asList(principals.keySet().toArray());
     }
 
-    public SessionInformation[] getAllSessions(Object principal, boolean includeExpiredSessions) {
-        Set<String> sessionsUsedByPrincipal = principals.get(principal);
+    public List<SessionInformation> getAllSessions(Object principal, boolean includeExpiredSessions) {
+        final Set<String> sessionsUsedByPrincipal = principals.get(principal);
 
         if (sessionsUsedByPrincipal == null) {
             return null;
         }
 
-        List<SessionInformation> list = new ArrayList<SessionInformation>();
+        List<SessionInformation> list = new ArrayList<SessionInformation>(sessionsUsedByPrincipal.size());
 
         synchronized (sessionsUsedByPrincipal) {
             for (String sessionId : sessionsUsedByPrincipal) {
@@ -84,7 +85,7 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
             }
         }
 
-        return (SessionInformation[]) list.toArray(new SessionInformation[0]);
+        return list;
     }
 
     public SessionInformation getSessionInformation(String sessionId) {

+ 16 - 14
core/src/test/java/org/springframework/security/authentication/concurrent/SessionRegistryImplTests.java

@@ -18,6 +18,7 @@ package org.springframework.security.authentication.concurrent;
 import static org.junit.Assert.*;
 
 import java.util.Date;
+import java.util.List;
 
 import org.junit.Before;
 import org.junit.Test;
@@ -77,8 +78,9 @@ public class SessionRegistryImplTests {
         sessionRegistry.registerNewSession(sessionId2, principal1);
         sessionRegistry.registerNewSession(sessionId3, principal2);
 
-        assertEquals(principal1, sessionRegistry.getAllPrincipals()[0]);
-        assertEquals(principal2, sessionRegistry.getAllPrincipals()[1]);
+        assertEquals(2, sessionRegistry.getAllPrincipals().size());
+        assertTrue(sessionRegistry.getAllPrincipals().contains(principal1));
+        assertTrue(sessionRegistry.getAllPrincipals().contains(principal2));
     }
 
     @Test
@@ -95,7 +97,7 @@ public class SessionRegistryImplTests {
         assertNotNull(sessionRegistry.getSessionInformation(sessionId).getLastRequest());
 
         // Retrieve existing session by principal
-        assertEquals(1, sessionRegistry.getAllSessions(principal, false).length);
+        assertEquals(1, sessionRegistry.getAllSessions(principal, false).size());
 
         // Sleep to ensure SessionRegistryImpl will update time
         Thread.sleep(1000);
@@ -107,7 +109,7 @@ public class SessionRegistryImplTests {
         assertTrue(retrieved.after(currentDateTime));
 
         // Check it retrieves correctly when looked up via principal
-        assertEquals(retrieved, sessionRegistry.getAllSessions(principal, false)[0].getLastRequest());
+        assertEquals(retrieved, sessionRegistry.getAllSessions(principal, false).get(0).getLastRequest());
 
         // Clear session information
         sessionRegistry.removeSessionInformation(sessionId);
@@ -124,13 +126,13 @@ public class SessionRegistryImplTests {
         String sessionId2 = "9876543210";
 
         sessionRegistry.registerNewSession(sessionId1, principal);
-        SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false);
-        assertEquals(1, sessions.length);
+        List<SessionInformation> sessions = sessionRegistry.getAllSessions(principal, false);
+        assertEquals(1, sessions.size());
         assertTrue(contains(sessionId1, principal));
 
         sessionRegistry.registerNewSession(sessionId2, principal);
         sessions = sessionRegistry.getAllSessions(principal, false);
-        assertEquals(2, sessions.length);
+        assertEquals(2, sessions.size());
         assertTrue(contains(sessionId2, principal));
 
         // Expire one session
@@ -149,18 +151,18 @@ public class SessionRegistryImplTests {
         String sessionId2 = "9876543210";
 
         sessionRegistry.registerNewSession(sessionId1, principal);
-        SessionInformation[] sessions = sessionRegistry.getAllSessions(principal, false);
-        assertEquals(1, sessions.length);
+        List<SessionInformation> sessions = sessionRegistry.getAllSessions(principal, false);
+        assertEquals(1, sessions.size());
         assertTrue(contains(sessionId1, principal));
 
         sessionRegistry.registerNewSession(sessionId2, principal);
         sessions = sessionRegistry.getAllSessions(principal, false);
-        assertEquals(2, sessions.length);
+        assertEquals(2, sessions.size());
         assertTrue(contains(sessionId2, principal));
 
         sessionRegistry.removeSessionInformation(sessionId1);
         sessions = sessionRegistry.getAllSessions(principal, false);
-        assertEquals(1, sessions.length);
+        assertEquals(1, sessions.size());
         assertTrue(contains(sessionId2, principal));
 
         sessionRegistry.removeSessionInformation(sessionId2);
@@ -169,10 +171,10 @@ public class SessionRegistryImplTests {
     }
 
     private boolean contains(String sessionId, Object principal) {
-        SessionInformation[] info = sessionRegistry.getAllSessions(principal, false);
+        List<SessionInformation> info = sessionRegistry.getAllSessions(principal, false);
 
-        for (int i = 0; i < info.length; i++) {
-            if (sessionId.equals(info[i].getSessionId())) {
+        for (int i = 0; i < info.size(); i++) {
+            if (sessionId.equals(info.get(i).getSessionId())) {
                 return true;
             }
         }