
Update conversion script

Sengxian, 2 years ago
commit ec44cfe652
1 changed file with 40 additions and 4 deletions:
  1. tools/convert_tp.py (+40, −4)

tools/convert_tp.py: +40 −4

@@ -1,8 +1,13 @@
 import os
+import sys
 import torch
 import argparse
 import glob
 
+from typing import *
+
+sys.path.append(".")
+
 SEQUENTIAL_LAYERS = [
     "input_layernorm.weight",
     "input_layernorm.bias",
@@ -35,10 +40,25 @@ def parse_arguments():
     parser.add_argument("--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
     parser.add_argument("--output-folder", default=None, type=str, help="Output SAT checkpoint folder")
     parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
-    return parser.parse_args()
+    parser.add_argument("--quantization-bit-width", default=None, type=int, help="Quantization bit width")
+
+    args = parser.parse_args()
+    if args.quantization_bit_width is not None:
+        assert args.quantization_bit_width in [4, 8]
+
+    return args
 
 
-def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_glu):
+def merge_weights(
+    key: str,
+    sd_list: List[Dict],
+    tp_index: int,
+    original_tp: int,
+    target_tp: int,
+    cat_dim: int,
+    is_glu: bool,
+    quantization_bit_width: Optional[int],
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     if original_tp >= target_tp:
         if is_glu:
             if original_tp > target_tp:
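
Note: with the new flag, a quantized conversion can be launched directly from the command line. A possible invocation (the folder paths and target TP degree below are placeholders, not taken from this commit):

    python tools/convert_tp.py \
        --input-folder /path/to/sat/checkpoint \
        --output-folder /path/to/quantized/checkpoint \
        --target-tp 1 \
        --quantization-bit-width 4

The added assert restricts the flag to 4 or 8, the two bit widths the merging code below knows how to produce.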
@@ -64,10 +84,23 @@ def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_gl
         else:
             # without clone, torch will save entire tensor
             merged_sd = torch.chunk(sd_list[0][key], num_part, dim=cat_dim)[tp_index % num_part].clone()
+
+    if quantization_bit_width is not None:
+        from kernels import compress_int4_weight
+
+        weight = merged_sd.cuda()
+        weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (quantization_bit_width - 1)) - 1)).half()
+        weight = torch.round(weight / weight_scale[:, None]).to(torch.int8)
+        if quantization_bit_width == 4:
+            weight = compress_int4_weight(weight)
+        return weight.cpu(), weight_scale.cpu()
+
     return merged_sd
 
 
-def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
+def create_checkpoint(
+    sd_list: List[Dict], tp_index: int, original_tp: int, target_tp: int, quantization_bit_width: Optional[int]
+) -> Dict:
     new_sd = {}
     for key in sd_list[0].keys():
         name = ".".join(key.split(".")[3 if key.startswith("transformer.layers") else 1 :])
@@ -82,7 +115,10 @@ def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
                 target_tp=target_tp,
                 cat_dim=LAYER_CONCAT_DIM.get(name, 0),
                 is_glu=name in GLU_LAYERS,
+                quantization_bit_width=quantization_bit_width if name in QUANTIZED_LAYERS else None,
             )
+            if name in QUANTIZED_LAYERS:
+                new_sd[key], new_sd[f"{key}_scale"] = new_sd[key]
     new_sd = {"module": new_sd}
     return new_sd
 
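Note: for layers listed in QUANTIZED_LAYERS, merge_weights now returns a (weight, scale) pair, and the unpacking above splits it into two state-dict entries: the original key plus a companion key with a _scale suffix. Illustratively (the shard filename and layer key below are placeholders, not taken from this commit), a consumer of a converted shard would see:

    import torch

    sd = torch.load("path/to/converted_shard.pt", map_location="cpu")["module"]
    w = sd["transformer.layers.0.attention.query_key_value.weight"]        # int8 (or int4-packed) weights
    s = sd["transformer.layers.0.attention.query_key_value.weight_scale"]  # fp16 per-row scales
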
@@ -110,7 +146,7 @@ def main(args):
                 else [i // (args.target_tp // original_tp)]
             )
         ]
-        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp), save_path)
+        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp, args.quantization_bit_width), save_path)
 
 
 if __name__ == "__main__":