Sfoglia il codice sorgente

Add checkpoint tensor parallel conversion script

Sengxian 3 anni fa
parent
commit
6b410ef9d2
1 ha cambiato i file con 118 aggiunte e 0 eliminazioni
  1. 118 0
      tools/convert_tp.py

+ 118 - 0
tools/convert_tp.py

@@ -0,0 +1,118 @@
+import os
+import torch
+import argparse
+import glob
+
+SEQUENTIAL_LAYERS = [
+    "input_layernorm.weight",
+    "input_layernorm.bias",
+    "attention.dense.bias",
+    "post_attention_layernorm.weight",
+    "post_attention_layernorm.bias",
+    "mlp.dense_4h_to_h.bias",
+    "attention.rotary_emb.inv_freq",
+    "final_layernorm.weight",
+    "final_layernorm.bias",
+]
+
+GLU_LAYERS = [
+    "mlp.dense_h_to_4h.weight",
+    "mlp.dense_h_to_4h.bias",
+]
+
+QUANTIZED_LAYERS = [
+    "attention.dense.weight",
+    "attention.query_key_value.weight",
+    "mlp.dense_h_to_4h.weight",
+    "mlp.dense_4h_to_h.weight",
+]
+
+LAYER_CONCAT_DIM = {"attention.dense.weight": 1, "mlp.dense_4h_to_h.weight": 1}
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
+    parser.add_argument("--output-folder", default=None, type=str, help="Output SAT checkpoint folder")
+    parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
+    return parser.parse_args()
+
+
+def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_glu):
+    if original_tp >= target_tp:
+        if is_glu:
+            if original_tp > target_tp:
+                num_part = original_tp // target_tp
+                assert len(sd_list) == num_part
+                part1, part2 = [], []
+                for i in range(len(sd_list)):
+                    chunks = torch.chunk(sd_list[i][key], 2, dim=cat_dim)
+                    part1.append(chunks[0])
+                    part2.append(chunks[1])
+                merged_sd = torch.cat(part1 + part2, dim=cat_dim)
+            else:
+                merged_sd = sd_list[0][key]
+        else:
+            merged_sd = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
+    else:
+        assert len(sd_list) == 1
+        num_part = target_tp // original_tp
+        if is_glu:
+            offset = tp_index % num_part
+            chunks = torch.chunk(sd_list[0][key], num_part * 2, dim=cat_dim)
+            merged_sd = torch.cat([chunks[offset], chunks[num_part + offset]], dim=cat_dim)
+        else:
+            # without clone, torch will save entire tensor
+            merged_sd = torch.chunk(sd_list[0][key], num_part, dim=cat_dim)[tp_index % num_part].clone()
+    return merged_sd
+
+
+def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
+    new_sd = {}
+    for key in sd_list[0].keys():
+        name = ".".join(key.split(".")[3 if key.startswith("transformer.layers") else 1 :])
+        if name in SEQUENTIAL_LAYERS:
+            new_sd[key] = sd_list[0][key]
+        else:
+            new_sd[key] = merge_weights(
+                key,
+                sd_list,
+                tp_index=tp_index,
+                original_tp=original_tp,
+                target_tp=target_tp,
+                cat_dim=LAYER_CONCAT_DIM.get(name, 0),
+                is_glu=name in GLU_LAYERS,
+            )
+    new_sd = {"module": new_sd}
+    return new_sd
+
+
+def main(args):
+    iteration = open(os.path.join(args.input_folder, "latest"), "r").read()
+    original_tp = len(glob.glob(os.path.join(args.input_folder, iteration, "mp_rank_*_model_states.pt")))
+    print(f"Iteration {iteration} from {args.input_folder} to {args.output_folder}")
+    os.makedirs(args.output_folder, exist_ok=True)
+    with open(os.path.join(args.output_folder, "latest"), "w") as file:
+        file.write(str(iteration))
+    os.makedirs(os.path.join(args.output_folder, iteration), exist_ok=True)
+
+    for i in range(0, args.target_tp):
+        save_path = os.path.join(args.output_folder, iteration, f"mp_rank_{i:02}_model_states.pt")
+        print(f"Processing {save_path}")
+        num_parts = original_tp // args.target_tp
+        sd_list = [
+            torch.load(
+                os.path.join(args.input_folder, iteration, f"mp_rank_{j:02}_model_states.pt"), map_location="cpu"
+            )["module"]
+            for j in (
+                range(i * num_parts, (i + 1) * num_parts)
+                if args.target_tp <= original_tp
+                else [i // (args.target_tp // original_tp)]
+            )
+        ]
+        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp), save_path)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)