@@ -0,0 +1,118 @@
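+# Repartition a SAT checkpoint (one mp_rank_XX_model_states.pt shard per
+# tensor-parallel rank) to a new TP degree by merging or splitting every
+# partitioned weight.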
+import argparse
+import glob
+import os
+
+import torch
+
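+# Parameters replicated across tensor-parallel ranks; the first loaded shard's
+# copy is taken verbatim.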
+SEQUENTIAL_LAYERS = [
+    "input_layernorm.weight",
+    "input_layernorm.bias",
+    "attention.dense.bias",
+    "post_attention_layernorm.weight",
+    "post_attention_layernorm.bias",
+    "mlp.dense_4h_to_h.bias",
+    "attention.rotary_emb.inv_freq",
+    "final_layernorm.weight",
+    "final_layernorm.bias",
+]
+
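+# GLU layers store two stacked halves in every shard, so resharding has to
+# re-chunk and regroup the halves rather than split the tensor naively.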
+GLU_LAYERS = [
+    "mlp.dense_h_to_4h.weight",
+    "mlp.dense_h_to_4h.bias",
+]
+
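+# Presumably the weights that get quantized; not referenced elsewhere in this
+# script.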
+QUANTIZED_LAYERS = [
+    "attention.dense.weight",
+    "attention.query_key_value.weight",
+    "mlp.dense_h_to_4h.weight",
+    "mlp.dense_4h_to_h.weight",
+]
+
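+# attention.dense and mlp.dense_4h_to_h are concatenated along dim 1; all
+# other partitioned tensors along dim 0 (the .get() default below).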
+LAYER_CONCAT_DIM = {"attention.dense.weight": 1, "mlp.dense_4h_to_h.weight": 1}
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
+    parser.add_argument("--output-folder", default=None, type=str, help="Output SAT checkpoint folder")
+    parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
+    return parser.parse_args()
+
+
+def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_glu):
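+    """Return the weight for output rank tp_index: merge the shards in sd_list
+    when original_tp >= target_tp, otherwise cut one slice out of the single
+    input shard."""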
+    if original_tp >= target_tp:
+        if is_glu:
+            if original_tp > target_tp:
+                num_part = original_tp // target_tp
+                assert len(sd_list) == num_part
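+                # Each input shard stores its slice of the two GLU halves
+                # stacked together; regroup so all first halves precede all
+                # second halves in the merged tensor.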
+                part1, part2 = [], []
+                for i in range(len(sd_list)):
+                    chunks = torch.chunk(sd_list[i][key], 2, dim=cat_dim)
+                    part1.append(chunks[0])
+                    part2.append(chunks[1])
+                merged_sd = torch.cat(part1 + part2, dim=cat_dim)
+            else:
+                merged_sd = sd_list[0][key]
+        else:
+            merged_sd = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
+    else:
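+        # original_tp < target_tp: each output rank slices its portion out of
+        # a single input shard.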
+        assert len(sd_list) == 1
+        num_part = target_tp // original_tp
+        if is_glu:
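+            # Chunk into 2 * num_part pieces and take slice `offset` from each
+            # of the two stacked GLU halves.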
+            offset = tp_index % num_part
+            chunks = torch.chunk(sd_list[0][key], num_part * 2, dim=cat_dim)
+            merged_sd = torch.cat([chunks[offset], chunks[num_part + offset]], dim=cat_dim)
+        else:
+            # Without .clone() the chunk is a view, and torch.save would
+            # serialize the entire underlying tensor.
+            merged_sd = torch.chunk(sd_list[0][key], num_part, dim=cat_dim)[tp_index % num_part].clone()
+    return merged_sd
+
+
+def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
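+    """Assemble output rank tp_index's checkpoint dict ({"module": state_dict})."""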
+    new_sd = {}
+    for key in sd_list[0].keys():
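+        # Drop the "transformer.layers.<i>." prefix (or the leading module
+        # name) to recover the bare parameter name used in the tables above.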
+        name = ".".join(key.split(".")[3 if key.startswith("transformer.layers") else 1 :])
+        if name in SEQUENTIAL_LAYERS:
+            new_sd[key] = sd_list[0][key]
+        else:
+            new_sd[key] = merge_weights(
+                key,
+                sd_list,
+                tp_index=tp_index,
+                original_tp=original_tp,
+                target_tp=target_tp,
+                cat_dim=LAYER_CONCAT_DIM.get(name, 0),
+                is_glu=name in GLU_LAYERS,
+            )
+    return {"module": new_sd}
+
+
+
+def main(args):
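+    # "latest" names the iteration subfolder that holds one
+    # mp_rank_XX_model_states.pt shard per tensor-parallel rank.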
+    with open(os.path.join(args.input_folder, "latest"), "r") as file:
+        iteration = file.read().strip()
+    original_tp = len(glob.glob(os.path.join(args.input_folder, iteration, "mp_rank_*_model_states.pt")))
+    print(f"Iteration {iteration} from {args.input_folder} to {args.output_folder}")
+    os.makedirs(args.output_folder, exist_ok=True)
+    with open(os.path.join(args.output_folder, "latest"), "w") as file:
+        file.write(iteration)
+    os.makedirs(os.path.join(args.output_folder, iteration), exist_ok=True)
+
+    for i in range(args.target_tp):
+        save_path = os.path.join(args.output_folder, iteration, f"mp_rank_{i:02}_model_states.pt")
+        print(f"Processing {save_path}")
+        num_parts = original_tp // args.target_tp
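+        # Shrinking TP: output rank i merges input ranks
+        # [i * num_parts, (i + 1) * num_parts). Growing TP: it slices input
+        # rank i // (target_tp // original_tp).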
+        sd_list = [
+            torch.load(
+                os.path.join(args.input_folder, iteration, f"mp_rank_{j:02}_model_states.pt"), map_location="cpu"
+            )["module"]
+            for j in (
+                range(i * num_parts, (i + 1) * num_parts)
+                if args.target_tp <= original_tp
+                else [i // (args.target_tp // original_tp)]
+            )
+        ]
+        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp), save_path)
+
+
+if __name__ == "__main__":
+    args = parse_arguments()
+    main(args)
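+
+# Example invocation (script name illustrative):
+#   python change_tp.py --input-folder checkpoints/model \
+#       --output-folder checkpoints/model-tp2 --target-tp 2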