소스 검색

Improve Upgrading

Closes gh-11259
Josh Cummings 4 년 전
부모
커밋
e6297d3bf7

+ 34 - 13
crypto/src/main/java/org/springframework/security/crypto/bcrypt/BCrypt.java

@@ -526,35 +526,47 @@ public class BCrypt {
 	 * @param safety bit 16 is set when the safety measure is requested
 	 * @return an array containing the binary hashed password
 	 */
-	private byte[] crypt_raw(byte password[], byte salt[], int log_rounds, boolean sign_ext_bug, int safety) {
-		int rounds, i, j;
+	private byte[] crypt_raw(byte password[], byte salt[], int log_rounds, boolean sign_ext_bug, int safety,
+			boolean for_check) {
 		int cdata[] = bf_crypt_ciphertext.clone();
 		int clen = cdata.length;
-		byte ret[];
 
+		long rounds;
 		if (log_rounds < 4 || log_rounds > 31) {
-			throw new IllegalArgumentException("Bad number of rounds");
+			if (!for_check) {
+				throw new IllegalArgumentException("Bad number of rounds");
+			}
+			if (log_rounds != 0) {
+				throw new IllegalArgumentException("Bad number of rounds");
+			}
+			rounds = 0;
+		}
+		else {
+			rounds = roundsForLogRounds(log_rounds);
+			if (rounds < 16 || rounds > Integer.MAX_VALUE) {
+				throw new IllegalArgumentException("Bad number of rounds");
+			}
 		}
-		rounds = 1 << log_rounds;
+
 		if (salt.length != BCRYPT_SALT_LEN) {
 			throw new IllegalArgumentException("Bad salt length");
 		}
 
 		init_key();
 		ekskey(salt, password, sign_ext_bug, safety);
-		for (i = 0; i < rounds; i++) {
+		for (int i = 0; i < rounds; i++) {
 			key(password, sign_ext_bug, safety);
 			key(salt, false, safety);
 		}
 
-		for (i = 0; i < 64; i++) {
-			for (j = 0; j < (clen >> 1); j++) {
+		for (int i = 0; i < 64; i++) {
+			for (int j = 0; j < (clen >> 1); j++) {
 				encipher(cdata, j << 1);
 			}
 		}
 
-		ret = new byte[clen * 4];
-		for (i = 0, j = 0; i < clen; i++) {
+		byte[] ret = new byte[clen * 4];
+		for (int i = 0, j = 0; i < clen; i++) {
 			ret[j++] = (byte) ((cdata[i] >> 24) & 0xff);
 			ret[j++] = (byte) ((cdata[i] >> 16) & 0xff);
 			ret[j++] = (byte) ((cdata[i] >> 8) & 0xff);
@@ -563,6 +575,10 @@ public class BCrypt {
 		return ret;
 	}
 
+	private static String hashpwforcheck(byte[] passwordb, String salt) {
+		return hashpw(passwordb, salt, true);
+	}
+
 	/**
 	 * Hash a password using the OpenBSD bcrypt scheme
 	 * @param password the password to hash
@@ -584,6 +600,10 @@ public class BCrypt {
 	 * @return the hashed password
 	 */
 	public static String hashpw(byte passwordb[], String salt) {
+		return hashpw(passwordb, salt, false);
+	}
+
+	private static String hashpw(byte passwordb[], String salt, boolean for_check) {
 		BCrypt B;
 		String real_salt;
 		byte saltb[], hashed[];
@@ -633,7 +653,7 @@ public class BCrypt {
 		}
 
 		B = new BCrypt();
-		hashed = B.crypt_raw(passwordb, saltb, rounds, minor == 'x', minor == 'a' ? 0x10000 : 0);
+		hashed = B.crypt_raw(passwordb, saltb, rounds, minor == 'x', minor == 'a' ? 0x10000 : 0, for_check);
 
 		rs.append("$2");
 		if (minor >= 'a') {
@@ -740,7 +760,8 @@ public class BCrypt {
 	 * @return true if the passwords match, false otherwise
 	 */
 	public static boolean checkpw(String plaintext, String hashed) {
-		return equalsNoEarlyReturn(hashed, hashpw(plaintext, hashed));
+		byte[] passwordb = plaintext.getBytes(StandardCharsets.UTF_8);
+		return equalsNoEarlyReturn(hashed, hashpwforcheck(passwordb, hashed));
 	}
 
 	/**
@@ -751,7 +772,7 @@ public class BCrypt {
 	 * @since 5.3
 	 */
 	public static boolean checkpw(byte[] passwordb, String hashed) {
-		return equalsNoEarlyReturn(hashed, hashpw(passwordb, hashed));
+		return equalsNoEarlyReturn(hashed, hashpwforcheck(passwordb, hashed));
 	}
 
 	static boolean equalsNoEarlyReturn(String a, String b) {

+ 14 - 0
crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptPasswordEncoderTests.java

@@ -208,4 +208,18 @@ public class BCryptPasswordEncoderTests {
 		assertThatIllegalArgumentException().isThrownBy(() -> encoder.matches(null, "does-not-matter"));
 	}
 
+	@Test
+	public void upgradeWhenNoRoundsThenTrue() {
+		BCryptPasswordEncoder encoder = new BCryptPasswordEncoder();
+		assertThat(encoder.upgradeEncoding("$2a$00$9N8N35BVs5TLqGL3pspAte5OWWA2a2aZIs.EGp7At7txYakFERMue")).isTrue();
+	}
+
+	@Test
+	public void checkWhenNoRoundsThenTrue() {
+		BCryptPasswordEncoder encoder = new BCryptPasswordEncoder();
+		assertThat(encoder.matches("password", "$2a$00$9N8N35BVs5TLqGL3pspAte5OWWA2a2aZIs.EGp7At7txYakFERMue"))
+				.isTrue();
+		assertThat(encoder.matches("wrong", "$2a$00$9N8N35BVs5TLqGL3pspAte5OWWA2a2aZIs.EGp7At7txYakFERMue")).isFalse();
+	}
+
 }

+ 7 - 0
crypto/src/test/java/org/springframework/security/crypto/bcrypt/BCryptTests.java

@@ -456,4 +456,11 @@ public class BCryptTests {
 		assertThat(BCrypt.equalsNoEarlyReturn("test", "pass")).isFalse();
 	}
 
+	@Test
+	public void checkpwWhenZeroRoundsThenMatches() {
+		String password = "$2a$00$9N8N35BVs5TLqGL3pspAte5OWWA2a2aZIs.EGp7At7txYakFERMue";
+		assertThat(BCrypt.checkpw("password", password)).isTrue();
+		assertThat(BCrypt.checkpw("wrong", password)).isFalse();
+	}
+
 }