@@ -1,8 +1,13 @@
 import os
+import sys
 import torch
 import argparse
 import glob
 
+from typing import *
+
+sys.path.append(".")
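+# appended so the repo-local `kernels` module (imported lazily in merge_weights below) resolves when run from the repo root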
+
 SEQUENTIAL_LAYERS = [
     "input_layernorm.weight",
     "input_layernorm.bias",
@@ -35,10 +40,25 @@ def parse_arguments():
     parser.add_argument("--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
     parser.add_argument("--output-folder", default=None, type=str, help="Output SAT checkpoint folder")
     parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
-    return parser.parse_args()
+    parser.add_argument("--quantization-bit-width", default=None, type=int, help="Quantization bit width (4 or 8)")
+
+    args = parser.parse_args()
+    if args.quantization_bit_width is not None:
+        assert args.quantization_bit_width in [4, 8]
+
+    return args
 
 
-def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_glu):
+def merge_weights(
+    key: str,
+    sd_list: List[Dict],
+    tp_index: int,
+    original_tp: int,
+    target_tp: int,
+    cat_dim: int,
+    is_glu: bool,
+    quantization_bit_width: Optional[int],
+) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
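+    # returns the merged tensor, or a (quantized int weight, per-row scale) pair when quantization is requested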
     if original_tp >= target_tp:
         if is_glu:
             if original_tp > target_tp:
@@ -64,10 +84,23 @@ def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_gl
         else:
             # without clone, torch will save entire tensor
             merged_sd = torch.chunk(sd_list[0][key], num_part, dim=cat_dim)[tp_index % num_part].clone()
+
+    if quantization_bit_width is not None:
+        from kernels import compress_int4_weight
+
+        weight = merged_sd.cuda()
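+        # symmetric absmax quantization: the per-row scale maps each row's max |w| to 2^(b-1) - 1 (127 for int8, 7 for int4)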
+        weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (quantization_bit_width - 1)) - 1)).half()
+        weight = torch.round(weight / weight_scale[:, None]).to(torch.int8)
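+        # at this point weight_scale[:, None] * weight approximates the original rows; for 4-bit, pack two int4 values per byte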
+        if quantization_bit_width == 4:
+            weight = compress_int4_weight(weight)
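+        # move back to CPU so the saved checkpoint holds plain CPU tensors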
+        return weight.cpu(), weight_scale.cpu()
+
     return merged_sd
 
 
-def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
+def create_checkpoint(
+    sd_list: List[Dict], tp_index: int, original_tp: int, target_tp: int, quantization_bit_width: Optional[int]
+) -> Dict:
     new_sd = {}
     for key in sd_list[0].keys():
         name = ".".join(key.split(".")[3 if key.startswith("transformer.layers") else 1 :])
@@ -82,7 +115,10 @@ def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
             target_tp=target_tp,
             cat_dim=LAYER_CONCAT_DIM.get(name, 0),
             is_glu=name in GLU_LAYERS,
+            quantization_bit_width=quantization_bit_width if name in QUANTIZED_LAYERS else None,
         )
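+        # merge_weights returned a (weight, scale) tuple for quantized layers; split it into "<key>" and "<key>_scale"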
+        if name in QUANTIZED_LAYERS:
+            new_sd[key], new_sd[f"{key}_scale"] = new_sd[key]
 
     new_sd = {"module": new_sd}
     return new_sd
@@ -110,7 +146,7 @@ def main(args):
                 else [i // (args.target_tp // original_tp)]
             )
         ]
-        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp), save_path)
+        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp, args.quantization_bit_width), save_path)
 
 
 if __name__ == "__main__":