# convert_mel_hifigan_chkpt.py (2.1 KB)

  1. import numpy as np
  2. import torch
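"""
Apparent hyperparameter-name renames (source name -> target name):

    upsample_scales    -> upsample_rates
    resblock_dilations -> resblock_dilation_sizes
    in_channels        -> model_in_dim
    out_channels       -> upsample_initial_channel

None of these names is referenced by this script, which only renames
state-dict keys; any config translation has to be done separately.
NOTE(review): the initial upsampling width may correspond to a `channels`
field rather than `out_channels` in the source config -- please verify.
"""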
  9. def main() -> None:
  10. # chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
  11. chkpt_root = "/checkpoint/mjhwang/experiments/231112-mel_vocoder-ai_speech_24khz/train_train_highquality_speech_20231111_no16khz_100000_hifigan.v1_8gpu_adapt"
  12. chkpt = torch.load(f"{chkpt_root}/checkpoint-400000steps.pkl")
  13. del chkpt["model"]["discriminator"]
  14. conv_seq_map = {
  15. ".1.bias": ".bias",
  16. ".1.weight_g": ".weight_g",
  17. ".1.weight_v": ".weight_v",
  18. }
  19. def update_key(k: str) -> str:
  20. if k.startswith("input_conv"):
  21. k = k.replace("input_conv", "conv_pre")
  22. elif k.startswith("upsamples"):
  23. k = k.replace("upsamples", "ups")
  24. for _k, _v in conv_seq_map.items():
  25. k = k.replace(_k, _v)
  26. elif k.startswith("blocks"):
  27. k = k.replace("blocks", "resblocks")
  28. for _k, _v in conv_seq_map.items():
  29. k = k.replace(_k, _v)
  30. elif k.startswith("output_conv"):
  31. k = k.replace("output_conv", "conv_post")
  32. for _k, _v in conv_seq_map.items():
  33. k = k.replace(_k, _v)
  34. return k
  35. chkpt["model"] = {update_key(k): v for k, v in chkpt["model"]["generator"].items()}
  36. stats_path = f"{chkpt_root}/stats.npy"
  37. stats = np.load(stats_path)
  38. mean = torch.from_numpy(stats[0].reshape(-1)).float()
  39. scale = torch.from_numpy(stats[1].reshape(-1)).float()
  40. chkpt["model"]["mean"] = mean
  41. chkpt["model"]["scale"] = scale
  42. for k in ["optimizer", "scheduler", "steps", "epochs"]:
  43. del chkpt[k]
  44. # out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
  45. out_path = "/large_experiments/seamless/workstream/expressivity/oss/checkpoints/melhifigan_20231121.pt"
  46. torch.save(chkpt, out_path)
  47. if __name__ == "__main__":
  48. main()