Quellcode durchsuchen

SEC-588: Completed JdbcTokenRepositoryImpl and added extra update method to PersistentTokenRepository interface.

Luke Taylor vor 18 Jahren
Ursprung
Commit
7caa1587b3

+ 6 - 6
core/src/main/java/org/springframework/security/ui/rememberme/AbstractRememberMeServices.java

@@ -48,10 +48,10 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
     private int tokenValiditySeconds = 1209600; // 14 days
 
     public void afterPropertiesSet() throws Exception {
-        Assert.hasLength(key);        
+        Assert.hasLength(key);
         Assert.hasLength(parameter);
         Assert.hasLength(cookieName);
-        Assert.notNull(userDetailsService);        
+        Assert.notNull(userDetailsService);
     }
 
     /**
@@ -81,7 +81,7 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
             cancelCookie(request, response);
             throw cte;
         } catch (UsernameNotFoundException noUser) {
-            cancelCookie(request, response);            
+            cancelCookie(request, response);
             logger.debug("Remember-me login was valid but corresponding user not found.", noUser);
             return null;
         } catch (InvalidCookieException invalidCookie) {
@@ -90,7 +90,7 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
             return null;
         } catch (RememberMeAuthenticationException e) {
             cancelCookie(request, response);
-            logger.debug("autoLogin failed", e);
+            logger.debug(e.getMessage());
             return null;
         }
 
@@ -286,7 +286,7 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
         }
         cancelCookie(request, response);
     }
-    
+
     public void setCookieName(String cookieName) {
         this.cookieName = cookieName;
     }
@@ -322,7 +322,7 @@ public abstract class AbstractRememberMeServices implements RememberMeServices,
     public int getTokenValiditySeconds() {
         return tokenValiditySeconds;
     }
-    
+
     public AuthenticationDetailsSource getAuthenticationDetailsSource() {
         return authenticationDetailsSource;
     }

+ 137 - 12
core/src/main/java/org/springframework/security/ui/rememberme/JdbcTokenRepositoryImpl.java

@@ -1,34 +1,159 @@
 package org.springframework.security.ui.rememberme;
 
+import org.springframework.dao.DataAccessException;
+import org.springframework.dao.IncorrectResultSizeDataAccessException;
+import org.springframework.jdbc.core.SqlParameter;
 import org.springframework.jdbc.core.support.JdbcDaoSupport;
+import org.springframework.jdbc.object.MappingSqlQuery;
+import org.springframework.jdbc.object.SqlUpdate;
+
+import javax.sql.DataSource;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Types;
+import java.util.Date;
 
 /**
- * 
+ * JDBC based persistent login token repository implementation.
+ *
  * @author Luke Taylor
  * @version $Id$
  */
 public class JdbcTokenRepositoryImpl extends JdbcDaoSupport implements PersistentTokenRepository {
-    //~ Static fields/initializers =====================================================================================    
-    public static final String DEF_TOKEN_BY_SERIES_QUERY =
-            "select username,series,token from persistent_logins where series = ?";
-    public static final String DEF_INSERT_TOKEN_STATEMENT =
-            "insert into persistent_logins (username,series,token) values(?,?,?)";
-    public static final String DEF_REMOVE_USER_TOKENS_STATEMENT =
+    //~ Static fields/initializers =====================================================================================
+
+    /** Default SQL for creating the database table to store the tokens */
+    public static final String CREATE_TABLE_SQL =
+            "create table persistent_logins (username varchar(64) not null, series varchar(64) primary key, " +
+                    "token varchar(64) not null, last_used timestamp not null)";
+    /** The default SQL used by the <tt>getTokenBySeries</tt> query */
+    public static final String DEF_TOKEN_BY_SERIES_SQL =
+            "select username,series,token,last_used from persistent_logins where series = ?";
+    /** The default SQL used by <tt>createNewToken</tt> */
+    public static final String DEF_INSERT_TOKEN_SQL =
+            "insert into persistent_logins (username, series, token, last_used) values(?,?,?,?)";
+    /** The default SQL used by <tt>updateToken</tt> */
+    public static final String DEF_UPDATE_TOKEN_SQL =
+            "update persistent_logins set token = ?, last_used = ? where series = ?";
+    /** The default SQL used by <tt>removeUserTokens</tt> */
+    public static final String DEF_REMOVE_USER_TOKENS_SQL =
             "delete from persistent_logins where username = ?";
 
     //~ Instance fields ================================================================================================
 
-    private String tokensBySeriesQuery = DEF_TOKEN_BY_SERIES_QUERY;
-    private String insertTokenStatement = DEF_INSERT_TOKEN_STATEMENT;
-    private String removeUserTokensStatement = DEF_REMOVE_USER_TOKENS_STATEMENT;
+    private String tokensBySeriesSql = DEF_TOKEN_BY_SERIES_SQL;
+    private String insertTokenSql = DEF_INSERT_TOKEN_SQL;
+    private String updateTokenSql = DEF_UPDATE_TOKEN_SQL;
+    private String removeUserTokensSql = DEF_REMOVE_USER_TOKENS_SQL;
+    private boolean createTableOnStartup;
+
+    protected MappingSqlQuery tokensBySeriesMapping;
+    protected SqlUpdate insertToken;
+    protected SqlUpdate updateToken;
+    protected SqlUpdate removeUserTokens;
+
+    protected void initDao() {
+        tokensBySeriesMapping = new TokensBySeriesMapping(getDataSource());
+        insertToken = new InsertToken(getDataSource());
+        updateToken = new UpdateToken(getDataSource());
+        removeUserTokens = new RemoveUserTokens(getDataSource());
+
+        if (createTableOnStartup) {
+            getJdbcTemplate().execute(CREATE_TABLE_SQL);
+        }
+    }
+
+    public void createNewToken(PersistentRememberMeToken token) {
+        insertToken.update(
+                new Object[] {token.getUsername(), token.getSeries(), token.getTokenValue(), token.getDate()});
+    }
 
-    public void saveToken(PersistentRememberMeToken token) {
+    public void updateToken(String series, String tokenValue, Date lastUsed) {
+        updateToken.update(new Object[] {tokenValue, new Date(), series});
     }
 
+    /**
+     * Loads the token data for the supplied series identifier.
+     *
+     * If an error occurs, it will be reported and null will be returned (since the result should just be a failed
+     * persistent login).
+     *
+     * @param seriesId
+     * @return the token matching the series, or null if no match found or an exception occurred.
+     */
     public PersistentRememberMeToken getTokenForSeries(String seriesId) {
+        try {
+            return (PersistentRememberMeToken) tokensBySeriesMapping.findObject(seriesId);
+        } catch(IncorrectResultSizeDataAccessException moreThanOne) {
+            logger.error("Querying token for series '" + seriesId + "' returned more than one value. Series" +
+                    "should be unique");
+        } catch(DataAccessException e) {
+            logger.error("Failed to load token for series " + seriesId, e);
+        }
+
         return null;
     }
 
-    public void removeAllTokens(String username) {
+    public void removeUserTokens(String username) {
+        removeUserTokens.update(username);
+    }
+
+    /**
+     * Intended for convenience in debugging. Will create the persistent_tokens database table when the class
+     * is initialized during the initDao method.
+     *
+     * @param createTableOnStartup set to true to execute the
+     */
+    public void setCreateTableOnStartup(boolean createTableOnStartup) {
+        this.createTableOnStartup = createTableOnStartup;
+    }
+
+    //~ Inner Classes ==================================================================================================
+
+    protected class TokensBySeriesMapping extends MappingSqlQuery {
+        protected TokensBySeriesMapping(DataSource ds) {
+            super(ds, tokensBySeriesSql);
+            declareParameter(new SqlParameter(Types.VARCHAR));
+            compile();
+        }
+
+        protected Object mapRow(ResultSet rs, int rowNum) throws SQLException {
+            PersistentRememberMeToken token =
+                    new PersistentRememberMeToken(rs.getString(1), rs.getString(2), rs.getString(3), rs.getTimestamp(4));
+
+            return token;
+        }
+    }
+
+    protected class UpdateToken extends SqlUpdate {
+
+        public UpdateToken(DataSource ds) {
+            super(ds, updateTokenSql);
+            setMaxRowsAffected(1);
+            declareParameter(new SqlParameter(Types.VARCHAR));
+            declareParameter(new SqlParameter(Types.TIMESTAMP));
+            declareParameter(new SqlParameter(Types.VARCHAR));
+            compile();
+        }
+    }
+
+    protected class InsertToken extends SqlUpdate {
+
+        public InsertToken(DataSource ds) {
+            super(ds, insertTokenSql);
+            declareParameter(new SqlParameter(Types.VARCHAR));
+            declareParameter(new SqlParameter(Types.VARCHAR));
+            declareParameter(new SqlParameter(Types.VARCHAR));
+            declareParameter(new SqlParameter(Types.TIMESTAMP));
+            compile();
+        }
+    }
+
+    protected class RemoveUserTokens extends SqlUpdate {
+        public RemoveUserTokens(DataSource ds) {
+            super(ds, removeUserTokensSql);
+            declareParameter(new SqlParameter(Types.VARCHAR));
+            compile();
+        }
     }
 }

+ 38 - 26
core/src/main/java/org/springframework/security/ui/rememberme/PersistentTokenBasedRememberMeServices.java

@@ -1,6 +1,7 @@
 package org.springframework.security.ui.rememberme;
 
 import org.apache.commons.codec.binary.Base64;
+import org.springframework.dao.DataAccessException;
 import org.springframework.security.Authentication;
 
 import javax.servlet.http.HttpServletRequest;
@@ -10,7 +11,7 @@ import java.util.Arrays;
 import java.util.Date;
 
 /**
- * {@link RememberMeServices} implementation based on Barry Jaspan's 
+ * {@link RememberMeServices} implementation based on Barry Jaspan's
  * <a href="http://jaspan.com/improved_persistent_login_cookie_best_practice">Improved Persistent Login Cookie
  * Best Practice</a>.
  *
@@ -80,7 +81,7 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
         // We have a match for this user/series combination
         if (!presentedToken.equals(token.getTokenValue())) {
             // Token doesn't match series value. Delete all logins for this user and throw an exception to warn them.
-            tokenRepository.removeAllTokens(token.getUsername());
+            tokenRepository.removeUserTokens(token.getUsername());
 
             throw new CookieTheftException(messages.getMessage("PersistentTokenBasedRememberMeServices.cookieStolen",
                     "Invalid remember-me token (Series/token) mismatch. Implies previous cookie theft attack."));
@@ -90,10 +91,22 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
             throw new RememberMeAuthenticationException("Remember-me login has expired");
         }
 
-        // Token also matches, so login is valid. create and save new token with the *same* series number.
-        PersistentRememberMeToken newToken = createNewToken(token.getUsername(), token.getSeries());
+        // Token also matches, so login is valid. Update the token value, keeping the *same* series number.
+        if (logger.isDebugEnabled()) {
+            logger.debug("Refreshing persistent login token for user '" + token.getUsername() + "', series '" +
+                    token.getSeries() + "'");
+        }
+
+        PersistentRememberMeToken newToken = new PersistentRememberMeToken(token.getUsername(),
+                token.getSeries(), generateTokenData(), new Date());
 
-        addCookie(newToken, request, response);
+        try {
+            tokenRepository.updateToken(newToken.getSeries(), newToken.getTokenValue(), newToken.getDate());
+            addCookie(newToken, request, response);
+        } catch (DataAccessException e) {
+            logger.error("Failed to update token: ", e);
+            throw new RememberMeAuthenticationException("Autologin failed due to data access problem");
+        }
 
         return token.getUsername();
     }
@@ -104,32 +117,31 @@ public class PersistentTokenBasedRememberMeServices extends AbstractRememberMeSe
      *
      */
     protected void onLoginSuccess(HttpServletRequest request, HttpServletResponse response, Authentication successfulAuthentication) {
-        PersistentRememberMeToken token = createNewToken(successfulAuthentication.getName(), null);
-        addCookie(token, request, response);
-    }
-
-    private PersistentRememberMeToken createNewToken(String username, String series) {
-        if (logger.isDebugEnabled()) {
-            logger.debug(series == null ? "Creating new" : "Renewing" +
-                    " persistent login token for user " + username);
-        }
+        String username = successfulAuthentication.getName();
 
-        if (series == null) {
-            byte[] newSeries = new byte[seriesLength];
-            random.nextBytes(newSeries);
-            series = new String(Base64.encodeBase64(newSeries));
-            logger.debug("New series: " + series);
-        }
+        logger.debug("Creating new persistent login for user " + username);
 
-        byte[] token = new byte[tokenLength];
-        random.nextBytes(token);
+        PersistentRememberMeToken persistentToken = new PersistentRememberMeToken(username, generateSeriesData(),
+                generateTokenData(), new Date());
+        try {
+            tokenRepository.createNewToken(persistentToken);
+            addCookie(persistentToken, request, response);
+        } catch (DataAccessException e) {
+            logger.error("Failed to save persistent token ", e);
 
-        PersistentRememberMeToken persistentToken = new PersistentRememberMeToken(username, series,
-                new String(Base64.encodeBase64(token)), new Date());
+        }
+    }
 
-        tokenRepository.saveToken(persistentToken);
+    protected String generateSeriesData() {
+        byte[] newSeries = new byte[seriesLength];
+        random.nextBytes(newSeries);
+        return new String(Base64.encodeBase64(newSeries));
+    }
 
-        return persistentToken;
+    protected String generateTokenData() {
+        byte[] newToken = new byte[tokenLength];
+        random.nextBytes(newToken);
+        return new String(Base64.encodeBase64(newToken));
     }
 
     private void addCookie(PersistentRememberMeToken token, HttpServletRequest request, HttpServletResponse response) {

+ 12 - 2
core/src/main/java/org/springframework/security/ui/rememberme/PersistentTokenRepository.java

@@ -1,15 +1,25 @@
 package org.springframework.security.ui.rememberme;
 
+import java.util.Date;
+
 /**
+ * The abstraction used by {@link PersistentTokenBasedRememberMeServices} to store the persistent
+ * login tokens for a user.
+ *
+ * @see JdbcTokenRepositoryImpl
+ * @see InMemoryTokenRepositoryImpl 
+ *
  * @author Luke Taylor
  * @version $Id$
  */
 public interface PersistentTokenRepository {
 
-    void saveToken(PersistentRememberMeToken token);
+    void createNewToken(PersistentRememberMeToken token);
+
+    void updateToken(String series, String tokenValue, Date lastUsed);
 
     PersistentRememberMeToken getTokenForSeries(String seriesId);
 
-    void removeAllTokens(String username);
+    void removeUserTokens(String username);
 
 }

+ 132 - 0
core/src/test/java/org/springframework/security/ui/rememberme/JdbcTokenRepositoryImplTests.java

@@ -0,0 +1,132 @@
+package org.springframework.security.ui.rememberme;
+
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.jdbc.datasource.DriverManagerDataSource;
+
+import org.junit.After;
+import static org.junit.Assert.*;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.sql.Timestamp;
+import java.util.Date;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * @author Luke Taylor
+ * @version $Id$
+ */
+public class JdbcTokenRepositoryImplTests {
+    private static DriverManagerDataSource dataSource;
+    private JdbcTokenRepositoryImpl repo;
+    private JdbcTemplate template;
+
+    @BeforeClass
+    public static void createDataSource() {
+        dataSource = new DriverManagerDataSource();
+        dataSource.setDriverClassName("org.hsqldb.jdbcDriver");
+        dataSource.setUrl("jdbc:hsqldb:mem:tokenrepotest");
+        dataSource.setUsername("sa");
+        dataSource.setPassword("");
+    }
+
+    @Before
+    public void populateDatabase() {
+        repo = new JdbcTokenRepositoryImpl();
+        repo.setDataSource(dataSource);
+        repo.initDao();
+        template = repo.getJdbcTemplate();
+        template.execute("create table persistent_logins (username varchar not null, " +
+                "series varchar not null, token varchar not null, last_used timestamp not null)");
+    }
+
+    @After
+    public void clearData() {
+        template.execute("drop table persistent_logins");
+    }
+
+    @Test
+    public void createNewTokenInsertsCorrectData() {
+        Date currentDate = new Date();
+        PersistentRememberMeToken token = new PersistentRememberMeToken("joeuser", "joesseries", "atoken", currentDate);
+        repo.createNewToken(token);
+
+        Map results = template.queryForMap("select * from persistent_logins");
+
+        assertEquals(currentDate, results.get("last_used"));
+        assertEquals("joeuser", results.get("username"));
+        assertEquals("joesseries", results.get("series"));
+        assertEquals("atoken", results.get("token"));
+    }
+
+    @Test
+    public void retrievingTokenReturnsCorrectData() {
+
+        template.execute("insert into persistent_logins (series, username, token, last_used) values " +
+                "('joesseries', 'joeuser', 'atoken', '2007-10-09 18:19:25.000000000')");
+        PersistentRememberMeToken token = repo.getTokenForSeries("joesseries");
+
+        assertEquals("joeuser", token.getUsername());
+        assertEquals("joesseries", token.getSeries());
+        assertEquals("atoken", token.getTokenValue());
+        assertEquals(Timestamp.valueOf("2007-10-09 18:19:25.000000000"), token.getDate());
+    }
+
+    @Test
+    public void retrievingTokenWithDuplicateSeriesReturnsNull() {
+        template.execute("insert into persistent_logins (series, username, token, last_used) values " +
+                "('joesseries', 'joeuser', 'atoken2', '2007-10-19 18:19:25.000000000')");
+        template.execute("insert into persistent_logins (series, username, token, last_used) values " +
+                "('joesseries', 'joeuser', 'atoken', '2007-10-09 18:19:25.000000000')");
+
+//        List results = template.queryForList("select * from persistent_logins where series = 'joesseries'");
+
+        assertNull(repo.getTokenForSeries("joesseries"));
+    }
+
+    @Test
+    public void removingUserTokensDeletesData() {
+        template.execute("insert into persistent_logins (series, username, token, last_used) values " +
+                "('joesseries2', 'joeuser', 'atoken2', '2007-10-19 18:19:25.000000000')");
+        template.execute("insert into persistent_logins (series, username, token, last_used) values " +
+                "('joesseries', 'joeuser', 'atoken', '2007-10-09 18:19:25.000000000')");
+
+       // List results = template.queryForList("select * from persistent_logins where series = 'joesseries'");
+
+        repo.removeUserTokens("joeuser");
+
+        List results = template.queryForList("select * from persistent_logins where username = 'joeuser'");
+
+        assertEquals(0, results.size());
+    }
+
+    @Test
+    public void updatingTokenModifiesTokenValueAndLastUsed() {
+        Timestamp ts = new Timestamp(System.currentTimeMillis() - 1);
+        template.execute("insert into persistent_logins (series, username, token, last_used) values " +
+                "('joesseries', 'joeuser', 'atoken', '" + ts.toString() + "')");
+        repo.updateToken("joesseries", "newtoken", new Date());
+
+        Map results = template.queryForMap("select * from persistent_logins where series = 'joesseries'");
+
+        assertEquals("joeuser", results.get("username"));
+        assertEquals("joesseries", results.get("series"));
+        assertEquals("newtoken", results.get("token"));
+        Date lastUsed = (Date) results.get("last_used");
+        assertTrue(lastUsed.getTime() > ts.getTime());
+    }
+
+    @Test
+    public void createTableOnStartupCreatesCorrectTable() {
+        template.execute("drop table persistent_logins");
+        repo = new JdbcTokenRepositoryImpl();
+        repo.setDataSource(dataSource);
+        repo.setCreateTableOnStartup(true);
+        repo.initDao();
+
+        template.queryForList("select username,series,token,last_used from persistent_logins");
+    }
+
+}

+ 8 - 3
core/src/test/java/org/springframework/security/ui/rememberme/PersistentTokenBasedRememberMeServicesTests.java

@@ -24,7 +24,7 @@ public class PersistentTokenBasedRememberMeServicesTests {
 
     @Test(expected = InvalidCookieException.class)
     public void loginIsRejectedWithWrongNumberOfCookieTokens() {
-        services.processAutoLoginCookie(new String[] {"series", "token", "extra"}, new MockHttpServletRequest(), 
+        services.processAutoLoginCookie(new String[] {"series", "token", "extra"}, new MockHttpServletRequest(),
                 new MockHttpServletResponse());
     }
 
@@ -101,15 +101,20 @@ public class PersistentTokenBasedRememberMeServicesTests {
             storedToken = token;
         }
 
-        public void saveToken(PersistentRememberMeToken token) {
+        public void createNewToken(PersistentRememberMeToken token) {
             storedToken = token;
         }
 
+        public void updateToken(String series, String tokenValue, Date lastUsed) {
+            storedToken = new PersistentRememberMeToken(storedToken.getUsername(), storedToken.getSeries(),
+                    tokenValue, lastUsed);
+        }
+
         public PersistentRememberMeToken getTokenForSeries(String seriesId) {
             return storedToken;
         }
 
-        public void removeAllTokens(String username) {
+        public void removeUserTokens(String username) {
         }
 
         PersistentRememberMeToken getStoredToken() {