# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
#
# The rules to blend the p2v decoder, mel-vocoder and the watermarking:
#
# Step 1) Make the big sequential module `layers` that consists of:
# - PostNet (last layer of the p2v decoder): 5 layers
# - mel-vocoder layers (conv_pre, ups, resblocks, conv_post): 18 layers
# - watermarking encoder and decoder: 32 layers
#
# Step 2) Take the 32 layers of the watermarking, split them into 4 blocks of
# 8 layers, and mix these blocks into the previous layers.
#
# The final mixed architecture SPVM (Spaghetti Pretssel Vocoder Model):
#
# <P2V: Post Net>
#        |
# <Block 1 of Watermarker> ------
#        |                      |
#        \/                     |
# <Melvocoder: Conv_pre>        |
#        | (skip)               |
# <Block 2 of Watermarker> -----|
#        |                      |
#        \/                     |
# <Melvocoder: Upsampler>       |
#        | (skip)               |
# <Block 3 of Watermarker> -----|
#        |                      |
#        \/                     |
# <Melvocoder: Resblocks>       |
#        | (skip)               |
# <Block 4 of Watermarker> -----|
#        |                      |
#        \/                     |
# <Melvocoder: Conv_post>       |
#        |                      |
#        | <--------------------|
#        |
#        \/
# watermarked wavs
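#
# Resulting index layout of the merged `layers` module, derived from the
# renaming rules below (5 + 18 + 32 = 55 layers in total):
#
#   layers.0-4    PostNet convolutions
#   layers.5-12   watermarker encoder.model.0-7   (block 1)
#   layers.13     mel-vocoder conv_pre
#   layers.14-21  watermarker encoder.model.8-15  (block 2)
#   layers.22-25  mel-vocoder ups.0-3
#   layers.26-33  watermarker decoder.model.0-7   (block 3)
#   layers.34-45  mel-vocoder resblocks.0-11
#   layers.46-53  watermarker decoder.model.8-15  (block 4)
#   layers.54     mel-vocoder conv_post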
from pathlib import Path
from argparse import ArgumentParser
from typing import Any, Mapping, Match

import torch
from fairseq2.models.utils.checkpoint import (
    convert_fairseq_checkpoint,
    convert_model_state_dict,
    load_checkpoint,
)

def pretssel_key_map() -> Mapping[str, str]:
    """
    The rule for renaming the layers of the Pretssel model checkpoint:
    - Merge decoder.postnet into `layers`
    """
    from seamless_communication.models.pretssel.loader import _fairseq_key_map  # noqa

    key_map = _fairseq_key_map(None)  # type: ignore[arg-type]
    del key_map[r"^decoder\.postnet\."]
    key_map[r"^decoder\.postnet\.convolutions\."] = r"layers."
    return key_map

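
# Illustrative effect of pretssel_key_map (the exact parameter suffix depends
# on the PostNet implementation; only the key prefix is rewritten):
#   decoder.postnet.convolutions.0.<...> -> layers.0.<...>
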
def vocoder_key_map() -> Mapping[str, Any]:
    """
    Rename layers in the mel-vocoder checkpoint. We flatten the vocoder arch and put
    everything into `layers`, in which the `postnet_size` layers from the PostNet are
    already present. In other words:
    - conv_pre -> layers.<postnet_size + watermark_size / 4>
    - ups.i -> layers.<postnet_size + 1 + i + watermark_size / 2>
    - resblocks.i -> layers.<postnet_size + 1 + ups_size + i + 3 * watermark_size / 4>
    - conv_post -> layers.<postnet_size + 1 + ups_size + resblocks_size + watermark_size>
    """
    return {
        # fmt: off
        # postnet_size = 5, 1st wm block = 8 -> 13
        r"^conv_pre\.": r"layers.13.",  # noqa, 2nd wm block = 8 -> +8
        r"^ups\.([0-9]+)\.": lambda x: f"layers.{int(x.group(1)) + 22}.",  # noqa, ups_size = 4, 3rd wm block = 8 -> +12
        r"^resblocks\.([0-9]+)\.": lambda x: f"layers.{int(x.group(1)) + 34}.",  # noqa, resblocks_size = 12, 4th wm block = 8 -> +20
        r"^conv_post\.": r"layers.54.",
        # fmt: on
    }

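
# Illustrative effect of vocoder_key_map (parameter suffixes such as `.weight`
# are examples; only the key prefix is rewritten):
#   conv_pre.weight      -> layers.13.weight
#   ups.3.weight         -> layers.25.weight
#   resblocks.11.<...>   -> layers.45.<...>
#   conv_post.weight     -> layers.54.weight
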
def wm_key_map() -> Mapping[Any, Any]:
    """
    Flatten the watermarker encoder and decoder into the one sequential module (step 1),
    split the watermarker into 4 blocks and mix them into the layers of the p2v decoder
    and mel-vocoder:
    - encoder.model.[0-7] -> layers.<postnet_size + i> (5 -> 12)
    - encoder.model.[8-15] -> layers.<postnet_size + 1 + i> (14 -> 21)
    - decoder.model.[0-7] -> layers.<postnet_size + 1 + encoder_size + ups_size + i> (26 -> 33)
    - decoder.model.[8-15] -> layers.<postnet_size + 1 + encoder_size + ups_size + resblocks_size + i> (46 -> 53)
    """

    def encoder_layer_index(match_obj: Match[str]) -> str:
        idx = int(match_obj.group(1))
        # First half of the encoder goes after the PostNet
        if idx < 8:
            # postnet_size = 5
            return f"layers.{idx + 5}."
        # Second half of the encoder goes after the mel-vocoder conv_pre
        else:
            # postnet = 5, conv_pre = 1 --> +6
            return f"layers.{idx + 6}."

    def decoder_layer_index(match_obj: Match[str]) -> str:
        idx = int(match_obj.group(1))
        # First half of the decoder goes after the mel-vocoder ups
        if idx < 8:
            # postnet 5, conv_pre 1, encoder 16, ups 4 --> +26
            return f"layers.{idx + 26}."
        else:
            # postnet 5, conv_pre 1, encoder 16, ups 4, resblocks 12 --> +38
            return f"layers.{idx + 38}."

    return {
        r"^encoder\.model\.([0-9]+)\.": encoder_layer_index,
        r"^decoder\.model\.([0-9]+)\.": decoder_layer_index,
    }

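
# Illustrative effect of wm_key_map (only the key prefix is rewritten):
#   encoder.model.0.<...>  -> layers.5.<...>
#   encoder.model.15.<...> -> layers.21.<...>
#   decoder.model.0.<...>  -> layers.26.<...>
#   decoder.model.15.<...> -> layers.53.<...>
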
def combine_chkpts(pretssel_file: str, vocoder_file: str, out_path: str) -> None:
    """Combine the pretssel and melhifigan into one model"""
    pretssel_chkpt = load_checkpoint(pretssel_file)
    pretssel_chkpt = convert_fairseq_checkpoint(pretssel_chkpt, pretssel_key_map())

    vocoder_chkpt = load_checkpoint(vocoder_file)
    vocoder_chkpt = convert_fairseq_checkpoint(vocoder_chkpt, vocoder_key_map())

    wm_ckpt = load_checkpoint(
        "/large_experiments/seamless/nllb/watermarking/checkpoints/ckpt_e9d0008c.th",
    )
    # wm_ckpt is not a fairseq2 checkpoint, so we have to handle it differently
    wm_ckpt = convert_model_state_dict(wm_ckpt, wm_key_map())

    # Merge the state dicts
    ckpt = pretssel_chkpt
    state_dict = ckpt["model"]
    for vocoder_key in vocoder_chkpt["model"]:
        state_dict[vocoder_key] = vocoder_chkpt["model"][vocoder_key]
    for wm_key, wm_val in wm_ckpt.items():
        if wm_key.startswith("layers."):
            state_dict[wm_key] = wm_val

    # Remove obsolete layers
    keys_to_delete = [
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
        "enc_emb_proj.weight",
        "enc_emb_proj.bias",
    ]
    keys_to_delete.extend(
        [
            key
            for key in state_dict
            if key.startswith("decoder.var_adaptor.duration_predictor")
        ]
    )
    for key in keys_to_delete:
        if key in state_dict:
            del state_dict[key]

    # Write the layer-renaming rules next to the checkpoint for reference
    model_mapping_metafile = Path(out_path).with_suffix(".arch")
    with open(model_mapping_metafile, "w", encoding="utf-8") as o:
        o.write(vocoder_key_map.__doc__)  # type: ignore
        o.write("\n")
        o.write(wm_key_map.__doc__)  # type: ignore
        o.write("\n")

    torch.save(ckpt, out_path)

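
# A minimal sanity check one could run on the merged state dict before saving
# (an illustrative sketch, not part of the original pipeline): verify that the
# flattened `layers` module covers exactly the indices 0..54 described in the
# header comment.
def _check_layer_coverage(state_dict: Mapping[str, Any]) -> None:
    import re

    # Collect every `layers.<idx>.` index present in the merged state dict
    indices = set()
    for key in state_dict:
        match = re.match(r"layers\.([0-9]+)\.", key)
        if match:
            indices.add(int(match.group(1)))
    # 5 postnet + 18 vocoder + 32 watermarking layers = 55 entries (0..54)
    missing = set(range(55)) - indices
    assert not missing, f"missing layer indices: {sorted(missing)}"
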
if __name__ == "__main__":
    # fmt: off
    parser = ArgumentParser(description="Compile watermarking into p2v decoder and vocoder")
    parser.add_argument(
        "--pretssel",
        default="/checkpoint/mjhwang/experiments/230930-noiseaug_p2v-mls_multilingual_6lang/231005-noiseaug_p2v-mls_multilingual_6lang-alignfix.config_v2.langemb1.vuv_logit1.denoise.ngpu16/checkpoint_best.pt",
        type=str,
        help="Path to the Pretssel model checkpoint",
    )
    parser.add_argument(
        "--vocoder",
        # default="/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt",
        default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/melhifigan_20231121.pt",
        type=str,
        help="Path to the mel-vocoder checkpoint",
    )
    parser.add_argument(
        "--output",
        default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt",
        # default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-20231121.pt",
        type=str,
        help="Path to the output watermarked model checkpoint",
    )
    # fmt: on
    args = parser.parse_args()
    combine_chkpts(args.pretssel, args.vocoder, args.output)
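
# Example invocation (all arguments fall back to the defaults above when omitted;
# the angle-bracket paths are placeholders):
#   python compile_chkpt.py \
#       --pretssel <pretssel_checkpoint.pt> \
#       --vocoder <mel_vocoder_checkpoint.pt> \
#       --output <merged_output.pt>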