ソースを参照

SEC-1338: Applied submitted patch, making use of java.util.concurrent classes in place of traditional synchronization.

Luke Taylor 15 年 前
コミット
ca44ebd3cc

+ 30 - 34
core/src/main/java/org/springframework/security/core/session/SessionRegistryImpl.java

@@ -15,21 +15,16 @@
 
 package org.springframework.security.core.session;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Date;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.springframework.context.ApplicationListener;
 import org.springframework.util.Assert;
 
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentMap;
+import java.util.concurrent.CopyOnWriteArraySet;
+
 /**
  * Default implementation of {@link org.springframework.security.core.session.SessionRegistry SessionRegistry}
  * which listens for {@link org.springframework.security.core.session.SessionDestroyedEvent SessionDestroyedEvent}s
@@ -49,14 +44,14 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
     protected final Log logger = LogFactory.getLog(SessionRegistryImpl.class);
 
     /** <principal:Object,SessionIdSet> */
-    private final Map<Object,Set<String>> principals = Collections.synchronizedMap(new HashMap<Object,Set<String>>());
+    private final ConcurrentMap<Object,Set<String>> principals = new ConcurrentHashMap<Object,Set<String>>();
     /** <sessionId:Object,SessionInformation> */
-    private final Map<String, SessionInformation> sessionIds = Collections.synchronizedMap(new HashMap<String, SessionInformation>());
+    private final Map<String, SessionInformation> sessionIds = new ConcurrentHashMap<String, SessionInformation>();
 
     //~ Methods ========================================================================================================
 
     public List<Object> getAllPrincipals() {
-        return Arrays.asList(principals.keySet().toArray());
+        return new ArrayList<Object>(principals.keySet());
     }
 
     public List<SessionInformation> getAllSessions(Object principal, boolean includeExpiredSessions) {
@@ -68,17 +63,15 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
 
         List<SessionInformation> list = new ArrayList<SessionInformation>(sessionsUsedByPrincipal.size());
 
-        synchronized (sessionsUsedByPrincipal) {
-            for (String sessionId : sessionsUsedByPrincipal) {
-                SessionInformation sessionInformation = getSessionInformation(sessionId);
+        for (String sessionId : sessionsUsedByPrincipal) {
+            SessionInformation sessionInformation = getSessionInformation(sessionId);
 
-                if (sessionInformation == null) {
-                    continue;
-                }
+            if (sessionInformation == null) {
+                continue;
+            }
 
-                if (includeExpiredSessions || !sessionInformation.isExpired()) {
-                    list.add(sessionInformation);
-                }
+            if (includeExpiredSessions || !sessionInformation.isExpired()) {
+                list.add(sessionInformation);
             }
         }
 
@@ -88,7 +81,7 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
     public SessionInformation getSessionInformation(String sessionId) {
         Assert.hasText(sessionId, "SessionId required as per interface contract");
 
-        return (SessionInformation) sessionIds.get(sessionId);
+        return sessionIds.get(sessionId);
     }
 
     public void onApplicationEvent(SessionDestroyedEvent event) {
@@ -106,7 +99,7 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
         }
     }
 
-    public synchronized void registerNewSession(String sessionId, Object principal) {
+    public void registerNewSession(String sessionId, Object principal) {
         Assert.hasText(sessionId, "SessionId required as per interface contract");
         Assert.notNull(principal, "Principal required as per interface contract");
 
@@ -123,8 +116,12 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
         Set<String> sessionsUsedByPrincipal = principals.get(principal);
 
         if (sessionsUsedByPrincipal == null) {
-            sessionsUsedByPrincipal = Collections.synchronizedSet(new HashSet<String>(4));
-            principals.put(principal, sessionsUsedByPrincipal);
+            sessionsUsedByPrincipal = new CopyOnWriteArraySet<String>();
+            Set<String> prevSessionsUsedByPrincipal = principals.putIfAbsent(principal,
+                    sessionsUsedByPrincipal);
+            if (prevSessionsUsedByPrincipal != null) {
+                sessionsUsedByPrincipal = prevSessionsUsedByPrincipal;
+            }
         }
 
         sessionsUsedByPrincipal.add(sessionId);
@@ -159,20 +156,19 @@ public class SessionRegistryImpl implements SessionRegistry, ApplicationListener
             logger.debug("Removing session " + sessionId + " from principal's set of registered sessions");
         }
 
-        synchronized (sessionsUsedByPrincipal) {
-            sessionsUsedByPrincipal.remove(sessionId);
+        sessionsUsedByPrincipal.remove(sessionId);
 
-            if (sessionsUsedByPrincipal.size() == 0) {
-                // No need to keep object in principals Map anymore
-                if (logger.isDebugEnabled()) {
-                    logger.debug("Removing principal " + info.getPrincipal() + " from registry");
-                }
-                principals.remove(info.getPrincipal());
+        if (sessionsUsedByPrincipal.isEmpty()) {
+            // No need to keep object in principals Map anymore
+            if (logger.isDebugEnabled()) {
+                logger.debug("Removing principal " + info.getPrincipal() + " from registry");
             }
+            principals.remove(info.getPrincipal());
         }
 
         if (logger.isTraceEnabled()) {
             logger.trace("Sessions used by '" + info.getPrincipal() + "' : " + sessionsUsedByPrincipal);
         }
     }
+
 }