# convert_tp.py
# Convert a SAT checkpoint between tensor-parallel (TP) degrees.
  1. import os
  2. import torch
  3. import argparse
  4. import glob
  5. SEQUENTIAL_LAYERS = [
  6. "input_layernorm.weight",
  7. "input_layernorm.bias",
  8. "attention.dense.bias",
  9. "post_attention_layernorm.weight",
  10. "post_attention_layernorm.bias",
  11. "mlp.dense_4h_to_h.bias",
  12. "attention.rotary_emb.inv_freq",
  13. "final_layernorm.weight",
  14. "final_layernorm.bias",
  15. ]
  16. GLU_LAYERS = [
  17. "mlp.dense_h_to_4h.weight",
  18. "mlp.dense_h_to_4h.bias",
  19. ]
  20. QUANTIZED_LAYERS = [
  21. "attention.dense.weight",
  22. "attention.query_key_value.weight",
  23. "mlp.dense_h_to_4h.weight",
  24. "mlp.dense_4h_to_h.weight",
  25. ]
  26. LAYER_CONCAT_DIM = {"attention.dense.weight": 1, "mlp.dense_4h_to_h.weight": 1}
  27. def parse_arguments():
  28. parser = argparse.ArgumentParser()
  29. parser.add_argument("--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
  30. parser.add_argument("--output-folder", default=None, type=str, help="Output SAT checkpoint folder")
  31. parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
  32. return parser.parse_args()
  33. def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_glu):
  34. if original_tp >= target_tp:
  35. if is_glu:
  36. if original_tp > target_tp:
  37. num_part = original_tp // target_tp
  38. assert len(sd_list) == num_part
  39. part1, part2 = [], []
  40. for i in range(len(sd_list)):
  41. chunks = torch.chunk(sd_list[i][key], 2, dim=cat_dim)
  42. part1.append(chunks[0])
  43. part2.append(chunks[1])
  44. merged_sd = torch.cat(part1 + part2, dim=cat_dim)
  45. else:
  46. merged_sd = sd_list[0][key]
  47. else:
  48. merged_sd = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
  49. else:
  50. assert len(sd_list) == 1
  51. num_part = target_tp // original_tp
  52. if is_glu:
  53. offset = tp_index % num_part
  54. chunks = torch.chunk(sd_list[0][key], num_part * 2, dim=cat_dim)
  55. merged_sd = torch.cat([chunks[offset], chunks[num_part + offset]], dim=cat_dim)
  56. else:
  57. # without clone, torch will save entire tensor
  58. merged_sd = torch.chunk(sd_list[0][key], num_part, dim=cat_dim)[tp_index % num_part].clone()
  59. return merged_sd
  60. def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
  61. new_sd = {}
  62. for key in sd_list[0].keys():
  63. name = ".".join(key.split(".")[3 if key.startswith("transformer.layers") else 1 :])
  64. if name in SEQUENTIAL_LAYERS:
  65. new_sd[key] = sd_list[0][key]
  66. else:
  67. new_sd[key] = merge_weights(
  68. key,
  69. sd_list,
  70. tp_index=tp_index,
  71. original_tp=original_tp,
  72. target_tp=target_tp,
  73. cat_dim=LAYER_CONCAT_DIM.get(name, 0),
  74. is_glu=name in GLU_LAYERS,
  75. )
  76. new_sd = {"module": new_sd}
  77. return new_sd
  78. def main(args):
  79. iteration = open(os.path.join(args.input_folder, "latest"), "r").read()
  80. original_tp = len(glob.glob(os.path.join(args.input_folder, iteration, "mp_rank_*_model_states.pt")))
  81. print(f"Iteration {iteration} from {args.input_folder} to {args.output_folder}")
  82. os.makedirs(args.output_folder, exist_ok=True)
  83. with open(os.path.join(args.output_folder, "latest"), "w") as file:
  84. file.write(str(iteration))
  85. os.makedirs(os.path.join(args.output_folder, iteration), exist_ok=True)
  86. for i in range(0, args.target_tp):
  87. save_path = os.path.join(args.output_folder, iteration, f"mp_rank_{i:02}_model_states.pt")
  88. print(f"Processing {save_path}")
  89. num_parts = original_tp // args.target_tp
  90. sd_list = [
  91. torch.load(
  92. os.path.join(args.input_folder, iteration, f"mp_rank_{j:02}_model_states.pt"), map_location="cpu"
  93. )["module"]
  94. for j in (
  95. range(i * num_parts, (i + 1) * num_parts)
  96. if args.target_tp <= original_tp
  97. else [i // (args.target_tp // original_tp)]
  98. )
  99. ]
  100. torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp), save_path)
  101. if __name__ == "__main__":
  102. args = parse_arguments()
  103. main(args)