123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118 |
- import os
- import torch
- import argparse
- import glob
# All names below are parameter keys with their leading prefix removed, i.e. the
# ``name`` that create_checkpoint derives from each state-dict key.

# Copied verbatim from the first source shard instead of being merged or split.
SEQUENTIAL_LAYERS = [
    "input_layernorm.weight", "input_layernorm.bias",
    "attention.dense.bias",
    "post_attention_layernorm.weight", "post_attention_layernorm.bias",
    "mlp.dense_4h_to_h.bias",
    "attention.rotary_emb.inv_freq",
    "final_layernorm.weight", "final_layernorm.bias",
]

# Tensors made of two halves concatenated along the shard dimension; each half
# is merged/split independently by merge_weights.
GLU_LAYERS = ["mlp.dense_h_to_4h.weight", "mlp.dense_h_to_4h.bias"]

# Not referenced elsewhere in this script.
QUANTIZED_LAYERS = [
    "attention.dense.weight", "attention.query_key_value.weight",
    "mlp.dense_h_to_4h.weight", "mlp.dense_4h_to_h.weight",
]

# Shard/concat dimension per parameter; anything not listed uses dim 0.
LAYER_CONCAT_DIM = {
    "attention.dense.weight": 1,
    "mlp.dense_4h_to_h.weight": 1,
}
def parse_arguments():
    """Parse the command-line flags of the checkpoint conversion script.

    Returns:
        argparse.Namespace with ``input_folder`` and ``output_folder`` (both
        required paths) and ``target_tp`` (int, default 4).
    """
    parser = argparse.ArgumentParser()
    # Both folders are mandatory: a missing value would otherwise surface later
    # as an opaque TypeError from os.path.join(None, ...).
    parser.add_argument("--input-folder", required=True, type=str, help="Input SAT checkpoint folder")
    parser.add_argument("--output-folder", required=True, type=str, help="Output SAT checkpoint folder")
    parser.add_argument("--target-tp", default=4, type=int, help="Target TP degree")
    return parser.parse_args()
def merge_weights(key, sd_list, tp_index, original_tp, target_tp, cat_dim, is_glu):
    """Return the tensor stored under ``key`` for target TP rank ``tp_index``.

    Args:
        key: state-dict key to produce.
        sd_list: source state dicts feeding this rank — several when merging
            (``original_tp >= target_tp``), exactly one when splitting.
        tp_index: index of the target rank being built.
        original_tp: TP degree of the source checkpoint.
        target_tp: TP degree of the output checkpoint.
        cat_dim: dimension along which the tensor is split / concatenated.
        is_glu: True when the tensor is two halves concatenated along
            ``cat_dim``; each half is then sharded independently.

    Assumes one TP degree is a whole multiple of the other; not checked here.
    """
    if original_tp < target_tp:
        # Split: a single source shard is cut into ``factor`` target slices.
        assert len(sd_list) == 1
        factor = target_tp // original_tp
        source = sd_list[0][key]
        slot = tp_index % factor
        if not is_glu:
            # clone(): a chunk is a view, and torch.save would otherwise write the
            # whole backing tensor.
            return torch.chunk(source, factor, dim=cat_dim)[slot].clone()
        # Slice each half separately: pieces[:factor] cover the first half,
        # pieces[factor:] the second.
        pieces = torch.chunk(source, 2 * factor, dim=cat_dim)
        return torch.cat([pieces[slot], pieces[factor + slot]], dim=cat_dim)

    # Merge (or same degree): stitch the source shards back together.
    if not is_glu:
        return torch.cat([shard[key] for shard in sd_list], dim=cat_dim)
    if original_tp == target_tp:
        return sd_list[0][key]
    factor = original_tp // target_tp
    assert len(sd_list) == factor
    halves = [torch.chunk(shard[key], 2, dim=cat_dim) for shard in sd_list]
    # All first halves come first, then all second halves.
    firsts = [pair[0] for pair in halves]
    seconds = [pair[1] for pair in halves]
    return torch.cat(firsts + seconds, dim=cat_dim)
def create_checkpoint(sd_list, tp_index, original_tp, target_tp):
    """Build the checkpoint payload for target TP rank ``tp_index``.

    Args:
        sd_list: source ``"module"`` state dicts feeding this rank.
        tp_index: index of the target rank being built.
        original_tp: TP degree of the source checkpoint.
        target_tp: TP degree of the output checkpoint.

    Returns:
        ``{"module": state_dict}`` with the same keys as ``sd_list[0]``.
    """
    first = sd_list[0]
    module = {}
    for key, tensor in first.items():
        # Keys under "transformer.layers" drop their first three dot-separated
        # components (presumably "transformer.layers.<idx>"); others drop one.
        skip = 3 if key.startswith("transformer.layers") else 1
        name = ".".join(key.split(".")[skip:])
        if name in SEQUENTIAL_LAYERS:
            # Copied verbatim from the first source shard.
            module[key] = tensor
            continue
        module[key] = merge_weights(
            key,
            sd_list,
            tp_index=tp_index,
            original_tp=original_tp,
            target_tp=target_tp,
            cat_dim=LAYER_CONCAT_DIM.get(name, 0),
            is_glu=name in GLU_LAYERS,
        )
    return {"module": module}
def _shard_path(folder, rank):
    """Return the path of the model-state file for TP rank ``rank`` inside ``folder``."""
    return os.path.join(folder, f"mp_rank_{rank:02}_model_states.pt")


def _load_module(folder, rank):
    """Load TP rank ``rank`` from ``folder`` onto the CPU and return its ``"module"`` entry."""
    return torch.load(_shard_path(folder, rank), map_location="cpu")["module"]


def _check_tp_degrees(src_dir, original_tp, target_tp):
    """Raise unless ``original_tp`` source shards can be evenly resharded into ``target_tp``."""
    if original_tp == 0:
        raise FileNotFoundError(f"No mp_rank_*_model_states.pt files found in {src_dir}")
    if target_tp <= 0:
        raise ValueError(f"Target TP degree must be positive, got {target_tp}")
    # Resharding uses integer division; a non-multiple would silently drop
    # source shards (e.g. 3 -> 2) or read a shard that does not exist (2 -> 3).
    if max(original_tp, target_tp) % min(original_tp, target_tp) != 0:
        raise ValueError(
            f"Cannot reshard TP={original_tp} into TP={target_tp}: "
            "one degree must be a multiple of the other"
        )


def main(args):
    """Reshard the SAT checkpoint in ``args.input_folder`` to ``args.target_tp`` ranks.

    The iteration tag is read from ``<input_folder>/latest``. The source TP degree
    is the number of ``mp_rank_*_model_states.pt`` files under
    ``<input_folder>/<tag>``. The output folder receives the same ``latest`` tag
    and one ``mp_rank_XX_model_states.pt`` file per target rank.

    Raises:
        FileNotFoundError: no source shard files exist for the iteration.
        ValueError: the target degree is not positive, or neither TP degree is a
            multiple of the other.
    """
    # strip(): a trailing newline in "latest" would otherwise leak into every path.
    with open(os.path.join(args.input_folder, "latest"), "r") as file:
        iteration = file.read().strip()
    src_dir = os.path.join(args.input_folder, iteration)
    dst_dir = os.path.join(args.output_folder, iteration)
    original_tp = len(glob.glob(os.path.join(src_dir, "mp_rank_*_model_states.pt")))
    target_tp = args.target_tp
    _check_tp_degrees(src_dir, original_tp, target_tp)

    print(f"Iteration {iteration} from {args.input_folder} to {args.output_folder}")
    os.makedirs(args.output_folder, exist_ok=True)
    with open(os.path.join(args.output_folder, "latest"), "w") as file:
        file.write(iteration)
    os.makedirs(dst_dir, exist_ok=True)

    def write_rank(rank, sd_list):
        # Build and save the checkpoint for one target rank.
        save_path = _shard_path(dst_dir, rank)
        print(f"Processing {save_path}")
        torch.save(create_checkpoint(sd_list, rank, original_tp, target_tp), save_path)

    if target_tp <= original_tp:
        # Merge: target rank i combines source ranks i*k .. (i+1)*k-1.
        num_parts = original_tp // target_tp
        for i in range(target_tp):
            sd_list = [_load_module(src_dir, j) for j in range(i * num_parts, (i + 1) * num_parts)]
            write_rank(i, sd_list)
    else:
        # Split: source rank j feeds target ranks j*k .. (j+1)*k-1, so load it
        # once per group rather than once per target rank.
        num_parts = target_tp // original_tp
        for j in range(original_tp):
            source = [_load_module(src_dir, j)]
            for i in range(j * num_parts, (j + 1) * num_parts):
                write_rank(i, source)
if __name__ == "__main__":
    # Run the conversion only when executed as a script, never on import.
    main(parse_arguments())
|