# convert_tp.py
# Convert a SAT (SwissArmyTransformer) checkpoint between tensor-parallel (TP)
# degrees, optionally quantizing selected weight matrices to int8/int4.

import argparse
import glob
import os
import sys
from typing import Dict, List, Optional, Tuple, Union

import torch

# Make local modules (e.g. the `kernels` quantization helpers) importable.
sys.path.append(".")

# Parameters that are replicated across TP ranks and can be copied verbatim.
SEQUENTIAL_LAYERS = [
    "input_layernorm.weight",
    "input_layernorm.bias",
    "attention.dense.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "mlp.dense_4h_to_h.bias",
    "attention.rotary_emb.inv_freq",
    "final_layernorm.weight",
    "final_layernorm.bias",
]

# Fused GLU projections: each rank holds matching slices of both halves
# (gate and up), so resharding has to re-chunk them specially.
GLU_LAYERS = [
    "mlp.dense_h_to_4h.weight",
    "mlp.dense_h_to_4h.bias",
]

# Weight matrices that may be quantized after resharding.
QUANTIZED_LAYERS = [
    "attention.dense.weight",
    "attention.query_key_value.weight",
    "mlp.dense_h_to_4h.weight",
    "mlp.dense_4h_to_h.weight",
]

# Row-parallel weights are partitioned along dim 1; everything else along dim 0.
LAYER_CONCAT_DIM = {"attention.dense.weight": 1, "mlp.dense_4h_to_h.weight": 1}


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-folder", default=None, type=str, help="Input SAT checkpoint folder")
    parser.add_argument("--output-folder", default=None, type=str, help="Output SAT checkpoint folder")
    parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
    parser.add_argument("--quantization-bit-width", default=None, type=int, help="Quantization bit width")
    args = parser.parse_args()
    if args.quantization_bit_width is not None:
        assert args.quantization_bit_width in [4, 8], "Only 4-bit and 8-bit quantization are supported"
    return args


def merge_weights(
    key: str,
    sd_list: List[Dict],
    tp_index: int,
    original_tp: int,
    target_tp: int,
    cat_dim: int,
    is_glu: bool,
    quantization_bit_width: Optional[int],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    if original_tp >= target_tp:
        # Shrinking TP: each new rank absorbs original_tp // target_tp old shards.
        if is_glu:
            if original_tp > target_tp:
                num_part = original_tp // target_tp
                assert len(sd_list) == num_part
                # Re-interleave fused GLU weights: gather the first halves (gate)
                # of all shards, then the second halves (up), so the merged
                # tensor keeps the [gate | up] layout.
                part1, part2 = [], []
                for i in range(len(sd_list)):
                    chunks = torch.chunk(sd_list[i][key], 2, dim=cat_dim)
                    part1.append(chunks[0])
                    part2.append(chunks[1])
                merged_sd = torch.cat(part1 + part2, dim=cat_dim)
            else:
                merged_sd = sd_list[0][key]
        else:
            merged_sd = torch.cat([sd[key] for sd in sd_list], dim=cat_dim)
    else:
        # Expanding TP: each new rank takes one slice of a single old shard.
        assert len(sd_list) == 1
        num_part = target_tp // original_tp
        if is_glu:
            offset = tp_index % num_part
            chunks = torch.chunk(sd_list[0][key], num_part * 2, dim=cat_dim)
            merged_sd = torch.cat([chunks[offset], chunks[num_part + offset]], dim=cat_dim)
        else:
            # Without .clone(), torch.save would serialize the entire original tensor.
            merged_sd = torch.chunk(sd_list[0][key], num_part, dim=cat_dim)[tp_index % num_part].clone()

    if quantization_bit_width is not None:
        # Symmetric per-row absmax quantization to int8; 4-bit weights are
        # additionally packed two-per-byte by compress_int4_weight.
        from kernels import compress_int4_weight

        weight = merged_sd.cuda()
        weight_scale = (weight.abs().max(dim=-1).values / ((2 ** (quantization_bit_width - 1)) - 1)).half()
        weight = torch.round(weight / weight_scale[:, None]).to(torch.int8)
        if quantization_bit_width == 4:
            weight = compress_int4_weight(weight)
        return weight.cpu(), weight_scale.cpu()

    return merged_sd

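# A toy sketch (not part of the original script) of the GLU splitting branch
# above: a fused [gate | up] vector of width 8 going from TP=1 to TP=2, i.e.
# num_part = 2, so the tensor is cut into num_part * 2 = 4 chunks:
#
#   fused  = torch.arange(8)                    # [g0 g1 g2 g3 | u0 u1 u2 u3]
#   chunks = torch.chunk(fused, 4)              # ([0,1], [2,3], [4,5], [6,7])
#   rank0  = torch.cat([chunks[0], chunks[2]])  # tensor([0, 1, 4, 5])
#   rank1  = torch.cat([chunks[1], chunks[3]])  # tensor([2, 3, 6, 7])
#
# Each new rank keeps matching slices of both halves, preserving the fused
# [gate | up] layout locally.
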
def create_checkpoint(
    sd_list: List[Dict], tp_index: int, original_tp: int, target_tp: int, quantization_bit_width: Optional[int]
) -> Dict:
    new_sd = {}
    for key in sd_list[0].keys():
        # Strip the module prefix to get the layer-local name, e.g.
        # "transformer.layers.0.attention.dense.weight" -> "attention.dense.weight".
        name = ".".join(key.split(".")[3 if key.startswith("transformer.layers") else 1 :])
        if name in SEQUENTIAL_LAYERS:
            # Replicated parameter: identical on every rank, copy from the first shard.
            new_sd[key] = sd_list[0][key]
        else:
            new_sd[key] = merge_weights(
                key,
                sd_list,
                tp_index=tp_index,
                original_tp=original_tp,
                target_tp=target_tp,
                cat_dim=LAYER_CONCAT_DIM.get(name, 0),
                is_glu=name in GLU_LAYERS,
                quantization_bit_width=quantization_bit_width if name in QUANTIZED_LAYERS else None,
            )
            # merge_weights returns a (weight, scale) pair only when it quantized;
            # unpacking unconditionally would corrupt the non-quantized path.
            if quantization_bit_width is not None and name in QUANTIZED_LAYERS:
                new_sd[key], new_sd[f"{key}_scale"] = new_sd[key]
    new_sd = {"module": new_sd}
    return new_sd

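# When --quantization-bit-width is given, each quantized weight in the saved
# "module" dict is accompanied by a half-precision per-row scale under a
# "_scale" suffix (key names below are illustrative):
#
#   "transformer.layers.0.attention.dense.weight"        -> int8 / packed int4
#   "transformer.layers.0.attention.dense.weight_scale"  -> float16 row scales
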
def main(args):
    # "latest" names the iteration subfolder that holds the mp_rank_* shards.
    with open(os.path.join(args.input_folder, "latest"), "r") as file:
        iteration = file.read().strip()
    original_tp = len(glob.glob(os.path.join(args.input_folder, iteration, "mp_rank_*_model_states.pt")))
    print(f"Iteration {iteration} from {args.input_folder} to {args.output_folder}")
    os.makedirs(args.output_folder, exist_ok=True)
    with open(os.path.join(args.output_folder, "latest"), "w") as file:
        file.write(iteration)
    os.makedirs(os.path.join(args.output_folder, iteration), exist_ok=True)
    for i in range(args.target_tp):
        save_path = os.path.join(args.output_folder, iteration, f"mp_rank_{i:02}_model_states.pt")
        print(f"Processing {save_path}")
        # Shrinking TP: new rank i merges old shards [i * num_parts, (i + 1) * num_parts).
        # Expanding TP: several new ranks each slice the same old shard.
        num_parts = original_tp // args.target_tp
        sd_list = [
            torch.load(
                os.path.join(args.input_folder, iteration, f"mp_rank_{j:02}_model_states.pt"), map_location="cpu"
            )["module"]
            for j in (
                range(i * num_parts, (i + 1) * num_parts)
                if args.target_tp <= original_tp
                else [i // (args.target_tp // original_tp)]
            )
        ]
        torch.save(create_checkpoint(sd_list, i, original_tp, args.target_tp, args.quantization_bit_width), save_path)

if __name__ == "__main__":
    args = parse_arguments()
    main(args)
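
# Example invocations (paths are illustrative):
#
#   python convert_tp.py --input-folder ckpt-tp4 --output-folder ckpt-tp2 --target-tp 2
#   python convert_tp.py --input-folder ckpt-tp4 --output-folder ckpt-tp8 --target-tp 8
#   python convert_tp.py --input-folder ckpt-tp4 --output-folder ckpt-tp1-int4 \
#       --target-tp 1 --quantization-bit-width 4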