convert_pretssel_hifigan_chkpt.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. import numpy as np
  2. import torch
  3. """
  4. upsample_scales -> upsample_rates
  5. resblock_dilations -> resblock_dilation_sizes
  6. in_channels -> model_in_dim
  7. out_channels -> upsample_initial_channel
  8. """
  9. def main():
  10. chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
  11. cfg = f"{chkpt_root}/config.yml"
  12. # TODO: display cfg
  13. chkpt = torch.load(f"{chkpt_root}/checkpoint-400000steps.pkl")
  14. del chkpt["model"]["discriminator"]
  15. conv_seq_map = {
  16. ".1.bias": ".bias",
  17. ".1.weight_g": ".weight_g",
  18. ".1.weight_v": ".weight_v",
  19. }
  20. def update_key(k):
  21. if k.startswith("input_conv"):
  22. k = k.replace("input_conv", "conv_pre")
  23. elif k.startswith("upsamples"):
  24. k = k.replace("upsamples", "ups")
  25. for _k, _v in conv_seq_map.items():
  26. k = k.replace(_k, _v)
  27. elif k.startswith("blocks"):
  28. k = k.replace("blocks", "resblocks")
  29. for _k, _v in conv_seq_map.items():
  30. k = k.replace(_k, _v)
  31. elif k.startswith("output_conv"):
  32. k = k.replace("output_conv", "conv_post")
  33. for _k, _v in conv_seq_map.items():
  34. k = k.replace(_k, _v)
  35. return k
  36. chkpt["model"] = {update_key(k): v for k, v in chkpt["model"]["generator"].items()}
  37. stats_path = f"{chkpt_root}/stats.npy"
  38. stats = np.load(stats_path)
  39. mean = torch.from_numpy(stats[0].reshape(-1)).float()
  40. scale = torch.from_numpy(stats[1].reshape(-1)).float()
  41. chkpt["model"]["mean"] = mean
  42. chkpt["model"]["scale"] = scale
  43. for k in ["optimizer", "scheduler", "steps", "epochs"]:
  44. del chkpt[k]
  45. out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
  46. torch.save(chkpt, out_path)
  47. if __name__ == "__main__":
  48. main()