# model.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import logging
import os
from typing import Any, Dict

import torch
from fairseq2.models.nllb.builder import NllbConfig
from fairseq2.models.nllb.loader import NllbLoader
from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
from fairseq2.nn.transformer import TransformerNormOrder

from m4t_scripts.train.configs import CustomModelParams, ModelConfig
from seamless_communication.models.unity import (
    UnitYConfig,
    UnitYModel,
    UnitYT2UConfig,
    create_unity_model,
    load_unity_model,
)
from seamless_communication.models.unity.loader import UnitYLoader, load_unity_config

logger = logging.getLogger(__name__)

CPU_DEVICE = torch.device("cpu")

class ModelBuilder:
    def __init__(
        self,
        config: ModelConfig,
        dtype: torch.dtype = torch.float16,
        device: torch.device = CPU_DEVICE,
    ):
        self.config = config
        self.dtype = dtype
        self.device = device

    @classmethod
    def _sel_and_upd_prefix(cls, kv: Dict[str, Any], prefix: str, new_prefix: str = "") -> Dict[str, Any]:
        """Select entries whose key starts with `prefix` and rename them to start with `new_prefix`."""
        # fmt: off
        return {new_prefix + k[len(prefix):]: v for k, v in kv.items() if k.startswith(prefix)}
        # fmt: on
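
    # Example of the selection behavior above (illustrative tensors t1, t2):
    #   _sel_and_upd_prefix({"encoder.w": t1, "decoder.w": t2}, prefix="encoder.")
    #   -> {"w": t1}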

    @classmethod
    def _load_pretrained_w2v2_encoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
        """Load w2v2 encoder model trained in fairseq1"""
        logger.info(f"Loading w2v2 weights from {checkpoint_path}")
        state_dict = torch.load(checkpoint_path)["model"]
        key_map = Wav2Vec2Loader._fairseq_key_map()
        key_map.update(
            {
                r"^encoder.layers\.([0-9]+)\.conv_module.batch_norm.": r"encoder.layers.\1.conv.batch_norm.",
                r"^encoder.layers\.([0-9]+)\.conv_module.depthwise_conv.": r"encoder.layers.\1.conv.depthwise_conv.",
                r"^encoder.layers\.([0-9]+)\.conv_module.pointwise_conv([0-9]+)\.": (
                    r"encoder.layers.\1.conv.pointwise_conv\2."
                ),
                r"^encoder.layers\.([0-9]+)\.conv_module.layer_norm.": r"encoder.layers.\1.conv_layer_norm.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.layer_norm.": r"encoder.layers.\1.ffn\2_layer_norm.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_1.": r"encoder.layers.\1.ffn\2.inner_proj.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_2.": r"encoder.layers.\1.ffn\2.output_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_k\.": r"encoder.layers.\1.self_attn.k_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_q\.": r"encoder.layers.\1.self_attn.q_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_v\.": r"encoder.layers.\1.self_attn.v_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_out\.": r"encoder.layers.\1.self_attn.output_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_pos.weight": (
                    r"encoder.layers.\1.self_attn.sdpa.r_proj.weight"
                ),
                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_u": r"encoder.layers.\1.self_attn.sdpa.u_bias",
                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_v": r"encoder.layers.\1.self_attn.sdpa.v_bias",
                # overrides the existing rule for final_layer_norm
                r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.layer_norm.",
            }
        )
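        # The key map rewrites fairseq1 parameter names into their fairseq2
        # counterparts, e.g. (illustrative):
        #   "encoder.layers.0.self_attn.linear_q.weight"
        #     -> "encoder.layers.0.self_attn.q_proj.weight"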
        state_dict = convert_model_state_dict(state_dict=state_dict, key_map=key_map)
        # The w2v2 encoder in fairseq2 has its encoder layer_norm set to None
        for rm_key in ["encoder.layer_norm.bias", "encoder.layer_norm.weight"]:
            del state_dict[rm_key]
        enc_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder.")
        model.speech_encoder.inner.load_state_dict(enc_state_dict, strict=True)  # type: ignore
        logger.info(f"Loaded w2v2 encoder from {checkpoint_path}")
        enc_frontend_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder_frontend.")  # noqa
        # TODO: reconcile discrepancies between the fairseq1 and fairseq2 model designs.
        # fairseq1-based w2v2 checkpoints with conv positional encoders use relpos self-attention,
        # which is not compatible with the fairseq2 model design.
        # model.speech_encoder_frontend.load_state_dict(enc_frontend_state_dict)
        # logger.info(f"Loaded w2v2 encoder frontend from {checkpoint_path}")

    @classmethod
    def _load_pretrained_s2t_decoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
        """Load NLLB decoder trained in fairseq1"""
        logger.info(f"Loading s2t decoder weights from {checkpoint_path}")
        try:
            state_dict = torch.load(checkpoint_path)["model"]
        except ModuleNotFoundError:
            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
            raise
        decoder_prefix = "decoder."
        shared_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix)
        shared_state_dict = convert_model_state_dict(
            state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
        )
        for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
            del shared_state_dict[rm_key]
        decoder_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix=decoder_prefix, new_prefix="")
        frontend_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="decoder_frontend.", new_prefix="")
        proj_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="final_proj.", new_prefix="")
        model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
        logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
        model.text_decoder.load_state_dict(decoder_state, strict=True)
        logger.info(f"Loaded s2t decoder weights from {checkpoint_path}")
        model.final_proj.load_state_dict(proj_state, strict=True)
        logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")
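
    # Illustrative flow of the renames above (exact targets depend on
    # NllbLoader._fairseq_key_map; the second hop is an assumption):
    #   "shared_decoder.layers.0.fc1.weight"
    #     -> "decoder.layers.0.fc1.weight"               # prefix swap
    #     -> "decoder.layers.0.ffn.inner_proj.weight"    # fairseq1 -> fairseq2 map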

    @classmethod
    def _load_pretrained_t2u(cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str) -> None:
        """Load t2u model trained in fairseq1"""
        logger.info(f"Loading t2u weights from {checkpoint_path}")
        t2u_model = model.t2u_model
        assert t2u_model is not None
        try:
            state_dict = torch.load(checkpoint_path)["model"]
        except ModuleNotFoundError:
            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
            raise
        state_dict = {k.replace("encoder.", "synthesizer_encoder."): v for k, v in state_dict.items()}
        state_dict = convert_model_state_dict(
            state_dict=state_dict, key_map=UnitYLoader._fairseq_key_map(config=model_config)
        )
        t2u_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="t2u_model.", new_prefix="")
        t2u_model.load_state_dict(t2u_state_dict)
        logger.info(f"Loaded t2u weights from {checkpoint_path}")

    def build_model(self) -> UnitYModel:
        config = self.config
        logger.info("Initializing model")
        if config.from_model is not None:
            logger.info(f"Loading model and weights from `{config.from_model}`")
            return load_unity_model(config.from_model, device=self.device, dtype=self.dtype)
        if config.from_model_config is not None:
            logger.info(f"Loading UnitY config from `{config.from_model_config}`")
            model_config = load_unity_config(config.from_model_config)
        elif config.custom_params is not None:
            logger.info("Creating custom UnitY config")
            model_config = self._build_custom_model_config()
        else:
            raise ValueError("One of `from_model`, `from_model_config` or `custom_params` must be set")
        logger.info("Building model")
        model = create_unity_model(config=model_config, dtype=self.dtype, device=self.device)
        if self.config.pretrained_w2v2_path is not None:
            self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)
        if self.config.pretrained_s2t_decoder_path is not None:
            self._load_pretrained_s2t_decoder(model, self.config.pretrained_s2t_decoder_path)
        if self.config.pretrained_t2u_path is not None:
            self._load_pretrained_t2u(model, model_config, self.config.pretrained_t2u_path)
        return model
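
    # Illustrative only: building entirely from a released checkpoint instead of
    # assembling from separately pretrained parts (the model card name below is
    # an assumption, not taken from this script):
    #   builder = ModelBuilder(config=ModelConfig(from_model="seamlessM4T_medium"))
    #   model = builder.build_model()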

    def _build_custom_model_config(self) -> UnitYConfig:
        config = self.config.custom_params
        assert config is not None
        return UnitYConfig(
            model_dim=config.model_embed_dim,
            w2v2_encoder_config=Wav2Vec2EncoderConfig(
                model_dim=config.model_embed_dim,
                max_seq_len=4096,
                feature_dim=160,
                use_fbank=True,
                first_pass_dropout_p=0.0,
                layer_norm_features=config.w2v2_encoder_layers_layernorm_features,
                feature_extractor_layer_descs=[],
                feature_extractor_bias=False,
                feature_extractor_layer_norm_convs=False,
                feature_grad_scale=0,
                num_fbank_channels=80,
                fbank_stride=2,
                sample_fbank_every_k=1,
                pos_encoder_type=config.w2v2_pos_encoder_type,
                pos_encoder_depth=config.w2v2_pos_encoder_depth,
                pos_conv_kernel_size=config.w2v2_pos_conv_kernel_size,
                num_pos_conv_groups=config.w2v2_num_pos_conv_groups,
                use_conformer=config.w2v2_encoder_layers_use_conformer,
                num_encoder_layers=config.w2v2_encoder_layers,
                num_encoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 4,
                dropout_p=0.0,
                attn_dropout_p=0.0,
                layer_drop_p=0.0,
                norm_order=TransformerNormOrder.POST,
                depthwise_conv_kernel_size=31,
            ),
            mt_model_config=NllbConfig(
                model_dim=config.model_embed_dim,
                max_seq_len=1024,
                vocabulary_size=config.nllb_vocabulary_size,  # num_tokens + langs + spec symbols
                pad_idx=0,
                num_encoder_layers=config.nllb_encoder_layers,
                num_decoder_layers=config.nllb_decoder_layers,
                num_encoder_attn_heads=16,
                num_decoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 8,
                dropout_p=0.1,
            ),
            t2u_config=UnitYT2UConfig(
                model_dim=config.model_embed_dim,
                unit_max_seq_len=2048,
                unit_vocabulary_size=config.unit_vocabulary_size,
                unit_pad_idx=1,
                num_encoder_layers=config.t2u_encoder_layers,
                num_decoder_layers=config.t2u_decoder_layers,
                nar_decoder_frontend_config=None,
                nar_decoder_config=None,
                num_encoder_attn_heads=16,
                num_decoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 8,
                dropout_p=0.1,
            ),
            use_text_encoder=True,
            use_conformer_adaptor=False,
            num_adaptor_layers=1,
            adaptor_kernel_size=8,
            adaptor_stride=8,
            adaptor_layer_norm=True,
            adaptor_dropout_p=0.1,
        )


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
    )
    config = ModelConfig(
        custom_params=CustomModelParams(
            nllb_vocabulary_size=256103,
        ),
        pretrained_w2v2_path="/fsx-ust/spopuri/datasets/PT_CKPT/w2v2/w2vbert2rpq_600m_al5.pt",
        pretrained_s2t_decoder_path="/fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt",
        pretrained_t2u_path="/fsx-ust/spopuri/datasets/PT_CKPT/T2U/V5_10K_p2_14_80K.pt",
    )
    builder = ModelBuilder(config=config)
    model = builder.build_model()
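    # Hypothetical follow-up (not part of the original script): persist the
    # assembled weights so later runs can skip the per-component loading above.
    # torch.save(model.state_dict(), "unity_custom_init.pt")  # path is illustrative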