"""Convert a trained mel HiFi-GAN vocoder checkpoint into an inference-only
checkpoint whose generator parameters use the conv_pre/ups/resblocks/conv_post
naming.

Generator config keys are renamed as follows (training name -> target name):

    upsample_scales    -> upsample_rates
    resblock_dilations -> resblock_dilation_sizes
    in_channels        -> model_in_dim
    out_channels       -> upsample_initial_channel
"""

import numpy as np
import torch


def main() -> None:
    # Root of the vocoder training run to convert; an earlier multilingual run
    # is left commented out for reference.
    # chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
    chkpt_root = "/checkpoint/mjhwang/experiments/231112-mel_vocoder-ai_speech_24khz/train_train_highquality_speech_20231111_no16khz_100000_hifigan.v1_8gpu_adapt"

    # Load the training checkpoint and drop the discriminator: only the
    # generator is needed for inference.
    chkpt = torch.load(f"{chkpt_root}/checkpoint-400000steps.pkl")
    del chkpt["model"]["discriminator"]

    # Strip the ".1" sub-module index so that parameter keys address the
    # convolution weights directly.
    conv_seq_map = {
        ".1.bias": ".bias",
        ".1.weight_g": ".weight_g",
        ".1.weight_v": ".weight_v",
    }

    def update_key(k: str) -> str:
        """Rename a generator parameter key to the target module names,
        e.g. update_key("upsamples.0.1.weight_g") == "ups.0.weight_g"."""
        if k.startswith("input_conv"):
            k = k.replace("input_conv", "conv_pre")
        elif k.startswith("upsamples"):
            k = k.replace("upsamples", "ups")
            for _k, _v in conv_seq_map.items():
                k = k.replace(_k, _v)
        elif k.startswith("blocks"):
            k = k.replace("blocks", "resblocks")
            for _k, _v in conv_seq_map.items():
                k = k.replace(_k, _v)
        elif k.startswith("output_conv"):
            k = k.replace("output_conv", "conv_post")
            for _k, _v in conv_seq_map.items():
                k = k.replace(_k, _v)
        return k

    # Flatten the state dict to the remapped generator parameters only.
    chkpt["model"] = {update_key(k): v for k, v in chkpt["model"]["generator"].items()}

    # Attach the feature normalization statistics (mean and scale) from
    # stats.npy as 1-D float tensors.
    stats_path = f"{chkpt_root}/stats.npy"
    stats = np.load(stats_path)
    mean = torch.from_numpy(stats[0].reshape(-1)).float()
    scale = torch.from_numpy(stats[1].reshape(-1)).float()
    chkpt["model"]["mean"] = mean
    chkpt["model"]["scale"] = scale

    # Drop training-only state from the checkpoint.
    for k in ["optimizer", "scheduler", "steps", "epochs"]:
        del chkpt[k]

    # Destination for the converted checkpoint (an earlier path is left
    # commented out for reference).
    # out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
    out_path = "/large_experiments/seamless/workstream/expressivity/oss/checkpoints/melhifigan_20231121.pt"
    torch.save(chkpt, out_path)


if __name__ == "__main__":
    main()
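
# Optional sanity check: a minimal sketch, not part of the original conversion
# flow, showing how the converted checkpoint could be verified after running
# main(). The function name and the specific assertions are illustrative
# assumptions; they only rely on the structure produced above.
# Example: sanity_check("/path/to/converted_checkpoint.pt")
def sanity_check(out_path: str) -> None:
    converted = torch.load(out_path)
    model = converted["model"]
    # Generator keys should use the remapped prefixes; none of the old
    # training-time prefixes should remain.
    assert any(key.startswith("conv_pre") for key in model)
    assert not any(
        key.startswith(("input_conv", "upsamples", "blocks", "output_conv"))
        for key in model
    )
    # Normalization statistics should be attached as 1-D tensors.
    assert model["mean"].ndim == 1 and model["scale"].ndim == 1
    # Training-only state should have been removed.
    assert "optimizer" not in converted and "scheduler" not in converted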