model.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import Any, Dict

import torch
from fairseq2.models.nllb.builder import NllbConfig
from fairseq2.models.nllb.loader import NllbLoader
from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
from fairseq2.nn.transformer import TransformerNormOrder
from m4t_scripts.train.configs import CustomModelParams, ModelConfig
from seamless_communication.models.unity import (
    UnitYConfig,
    UnitYModel,
    UnitYT2UConfig,
    create_unity_model,
    load_unity_model,
)
from seamless_communication.models.unity.loader import UnitYLoader, load_unity_config

logger = logging.getLogger(__name__)

CPU_DEVICE = torch.device("cpu")

class ModelBuilder:
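    """Builds a UnitY model from a `ModelConfig`, optionally initializing
    sub-modules (w2v2 speech encoder, NLLB text decoder, T2U model) from
    pretrained fairseq1 checkpoints."""
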
    def __init__(
        self,
        config: ModelConfig,
        dtype: torch.dtype = torch.float16,
        device: torch.device = CPU_DEVICE,
    ):
        self.config = config
        self.dtype = dtype
        self.device = device
    @classmethod
    def _sel_and_upd_prefix(
        cls, kv: Dict[str, Any], prefix: str, new_prefix: str = ""
    ) -> Dict[str, Any]:
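        """Select entries of `kv` whose keys start with `prefix` and replace
        that prefix with `new_prefix`."""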
        # fmt: off
        return {new_prefix + k[len(prefix):]: v for k, v in kv.items() if k.startswith(prefix)}
        # fmt: on
    @classmethod
    def _load_pretrained_w2v2_encoder(
        cls, model: UnitYModel, checkpoint_path: str
    ) -> None:
        """Load a w2v2 encoder trained in fairseq1."""
        logger.info(f"Loading w2v2 weights from {checkpoint_path}")
        state_dict = torch.load(checkpoint_path)["model"]
        key_map = Wav2Vec2Loader._fairseq_key_map()
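        # Extend the base fairseq1 -> fairseq2 key map with Conformer-specific
        # renames: conv module, macaron-style FFNs and relative-position
        # self-attention projections/biases.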
        key_map.update(
            {
                r"^encoder.layers\.([0-9]+)\.conv_module.batch_norm.": r"encoder.layers.\1.conv.batch_norm.",
                r"^encoder.layers\.([0-9]+)\.conv_module.depthwise_conv.": r"encoder.layers.\1.conv.depthwise_conv.",
                r"^encoder.layers\.([0-9]+)\.conv_module.pointwise_conv([0-9]+)\.": (
                    r"encoder.layers.\1.conv.pointwise_conv\2."
                ),
                r"^encoder.layers\.([0-9]+)\.conv_module.layer_norm.": r"encoder.layers.\1.conv_layer_norm.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.layer_norm.": r"encoder.layers.\1.ffn\2_layer_norm.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_1.": r"encoder.layers.\1.ffn\2.inner_proj.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_2.": r"encoder.layers.\1.ffn\2.output_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_k\.": r"encoder.layers.\1.self_attn.k_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_q\.": r"encoder.layers.\1.self_attn.q_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_v\.": r"encoder.layers.\1.self_attn.v_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_out\.": r"encoder.layers.\1.self_attn.output_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_pos.weight": (
                    r"encoder.layers.\1.self_attn.sdpa.r_proj.weight"
                ),
                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_u": r"encoder.layers.\1.self_attn.sdpa.u_bias",
                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_v": r"encoder.layers.\1.self_attn.sdpa.v_bias",
                # overrides existing rule
                r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.layer_norm.",
            }
        )
        state_dict = convert_model_state_dict(state_dict=state_dict, key_map=key_map)
        # The w2v2 encoder in fairseq2 has its top-level encoder layer_norm set to None
        for rm_key in ["encoder.layer_norm.bias", "encoder.layer_norm.weight"]:
            del state_dict[rm_key]
        enc_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder.")
        model.speech_encoder.inner.load_state_dict(enc_state_dict, strict=True)  # type: ignore
        logger.info(f"Loaded w2v2 encoder from {checkpoint_path}")
        enc_frontend_state_dict = cls._sel_and_upd_prefix(  # noqa
            kv=state_dict, prefix="encoder_frontend."
        )
        # TODO: reconcile discrepancies between fairseq1 and fairseq2 model designs.
        # fairseq1-based w2v2 checkpoints with conv positional encoders use relpos
        # self-attention, which is not compatible with the fairseq2 model design.
        # model.speech_encoder_frontend.load_state_dict(enc_frontend_state_dict)
        # logger.info(f"Loaded w2v2 encoder frontend from {checkpoint_path}")
    @classmethod
    def _load_pretrained_s2t_decoder(
        cls, model: UnitYModel, checkpoint_path: str
    ) -> None:
        """Load an NLLB decoder trained in fairseq1."""
        logger.info(f"Loading s2t decoder weights from {checkpoint_path}")
        try:
            state_dict = torch.load(checkpoint_path)["model"]
        except ModuleNotFoundError:
            logger.info(
                "If seeing `No module named 'omegaconf'`, run `pip install omegaconf`"
            )
            raise
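        # The fairseq1 checkpoint stores the text decoder under `shared_decoder.`;
        # remap it to the `decoder.` namespace expected by the NLLB key map, then
        # split the result into frontend, decoder and final_proj state dicts.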
        decoder_prefix = "decoder."
        shared_state_dict = cls._sel_and_upd_prefix(
            kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix
        )
        shared_state_dict = convert_model_state_dict(
            state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
        )
        for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
            del shared_state_dict[rm_key]
        decoder_state = cls._sel_and_upd_prefix(
            kv=shared_state_dict, prefix=decoder_prefix, new_prefix=""
        )
        frontend_state = cls._sel_and_upd_prefix(
            kv=shared_state_dict, prefix="decoder_frontend.", new_prefix=""
        )
        proj_state = cls._sel_and_upd_prefix(
            kv=shared_state_dict, prefix="final_proj.", new_prefix=""
        )
        model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
        logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
        model.text_decoder.load_state_dict(decoder_state, strict=True)
        logger.info(f"Loaded s2t decoder weights from {checkpoint_path}")
        model.final_proj.load_state_dict(proj_state, strict=True)
        logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")
    @classmethod
    def _load_pretrained_t2u(
        cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str
    ) -> None:
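        """Load a T2U (text-to-unit) model trained in fairseq1."""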
        logger.info(f"Loading t2u weights from {checkpoint_path}")
        t2u_model = model.t2u_model
        assert t2u_model is not None
        try:
            state_dict = torch.load(checkpoint_path)["model"]
        except ModuleNotFoundError:
            logger.info(
                "If seeing `No module named 'omegaconf'`, run `pip install omegaconf`"
            )
            raise
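        # Rename fairseq1 `encoder.` keys to `synthesizer_encoder.` so the UnitY
        # key map can route them to the T2U synthesizer encoder.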
        state_dict = {
            k.replace("encoder.", "synthesizer_encoder."): v
            for k, v in state_dict.items()
        }
        state_dict = convert_model_state_dict(
            state_dict=state_dict,
            key_map=UnitYLoader._fairseq_key_map(config=model_config),
        )
        t2u_state_dict = cls._sel_and_upd_prefix(
            kv=state_dict, prefix="t2u_model.", new_prefix=""
        )
        t2u_model.load_state_dict(t2u_state_dict)
        logger.info(f"Loaded t2u weights from {checkpoint_path}")
    def build_model(
        self,
    ) -> UnitYModel:
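        """Build the UnitY model described by `self.config`, loading pretrained
        sub-module weights where checkpoint paths are configured."""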
        config = self.config
        logger.info("Initializing model")
        if config.from_model is not None:
            logger.info(f"Loading model and weights from `{config.from_model}`")
            return load_unity_model(
                config.from_model, device=self.device, dtype=self.dtype
            )
        if config.from_model_config is not None:
            logger.info(f"Loading UnitY config from `{config.from_model_config}`")
            model_config = load_unity_config(config.from_model_config)
        elif config.custom_params is not None:
            logger.info("Creating custom UnitY config")
            model_config = self._build_custom_model_config()
        else:
            raise ValueError(
                "One of from_model, from_model_config or custom_params must be set"
            )
        logger.info("Building model")
        model = create_unity_model(
            config=model_config, dtype=self.dtype, device=self.device
        )
        if self.config.pretrained_w2v2_path is not None:
            self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)
        if self.config.pretrained_s2t_decoder_path is not None:
            self._load_pretrained_s2t_decoder(
                model, self.config.pretrained_s2t_decoder_path
            )
        if self.config.pretrained_t2u_path is not None:
            self._load_pretrained_t2u(
                model, model_config, self.config.pretrained_t2u_path
            )
        logger.info(f"Number of model params: {self._get_num_model_params(model)}")
        return model
    @classmethod
    def _get_num_model_params(cls, model: torch.nn.Module) -> int:
        return sum(p.numel() for p in model.parameters())
    def _build_custom_model_config(self) -> UnitYConfig:
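        """Build a UnitYConfig from the custom parameters in `self.config`."""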
        config = self.config.custom_params
        assert config is not None
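        # Fields not exposed through CustomModelParams are hard-coded below.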
        return UnitYConfig(
            model_dim=config.model_embed_dim,
            w2v2_encoder_config=Wav2Vec2EncoderConfig(
                model_dim=config.model_embed_dim,
                max_seq_len=4096,
                feature_dim=160,
                use_fbank=True,
                first_pass_dropout_p=0.0,
                layer_norm_features=config.w2v2_encoder_layers_layernorm_features,
                feature_extractor_layer_descs=[],
                feature_extractor_bias=False,
                feature_extractor_layer_norm_convs=False,
                feature_grad_scale=0,
                num_fbank_channels=80,
                fbank_stride=2,
                sample_fbank_every_k=1,
                pos_encoder_type=config.w2v2_pos_encoder_type,
                pos_encoder_depth=config.w2v2_pos_encoder_depth,
                pos_conv_kernel_size=config.w2v2_pos_conv_kernel_size,
                num_pos_conv_groups=config.w2v2_num_pos_conv_groups,
                use_conformer=config.w2v2_encoder_layers_use_conformer,
                num_encoder_layers=config.w2v2_encoder_layers,
                num_encoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 4,
                dropout_p=0.0,
                attn_dropout_p=0.0,
                layer_drop_p=0.0,
                norm_order=TransformerNormOrder.POST,
                depthwise_conv_kernel_size=31,
            ),
            mt_model_config=NllbConfig(
                model_dim=config.model_embed_dim,
                max_seq_len=1024,
                vocabulary_size=config.nllb_vocabulary_size,  # num_tokens + langs + spec symbols
                pad_idx=0,
                num_encoder_layers=config.nllb_encoder_layers,
                num_decoder_layers=config.nllb_decoder_layers,
                num_encoder_attn_heads=16,
                num_decoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 8,
                dropout_p=0.1,
            ),
            t2u_config=UnitYT2UConfig(
                model_dim=config.model_embed_dim,
                unit_max_seq_len=2048,
                unit_vocabulary_size=config.unit_vocabulary_size,
                unit_pad_idx=1,
                num_encoder_layers=config.t2u_encoder_layers,
                num_decoder_layers=config.t2u_decoder_layers,
                nar_decoder_frontend_config=None,
                nar_decoder_config=None,
                num_encoder_attn_heads=16,
                num_decoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 8,
                dropout_p=0.1,
            ),
            use_text_encoder=True,
            use_conformer_adaptor=False,
            num_adaptor_layers=1,
            adaptor_kernel_size=8,
            adaptor_stride=8,
            adaptor_layer_norm=True,
            adaptor_dropout_p=0.1,
        )


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
    )
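    # Example: build a UnitY model from custom parameters, initializing the
    # speech encoder, text decoder and T2U model from pretrained checkpoints.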
    config = ModelConfig(
        custom_params=CustomModelParams(
            nllb_vocabulary_size=256103,
        ),
        pretrained_w2v2_path="/fsx-ust/spopuri/datasets/PT_CKPT/w2v2/w2vbert2rpq_600m_al5.pt",
        pretrained_s2t_decoder_path="/fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt",
        pretrained_t2u_path="/fsx-ust/spopuri/datasets/PT_CKPT/T2U/V5_10K_p2_14_80K.pt",
    )
    builder = ModelBuilder(config=config)
    model = builder.build_model()