# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
from typing import Any, Dict

import torch
from fairseq2.models.nllb.builder import NllbConfig
from fairseq2.models.nllb.loader import NllbLoader
from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
from fairseq2.nn.transformer import TransformerNormOrder

from m4t_scripts.train.configs import CustomModelParams, ModelConfig
from seamless_communication.models.unity import (
    UnitYConfig,
    UnitYModel,
    UnitYT2UConfig,
    create_unity_model,
    load_unity_model,
)
from seamless_communication.models.unity.loader import UnitYLoader, load_unity_config

logger = logging.getLogger(__name__)

CPU_DEVICE = torch.device("cpu")


class ModelBuilder:
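    """Builds a UnitY model for training, optionally initializing its speech
    encoder, text decoder, and T2U sub-model from pretrained fairseq1 checkpoints."""
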
    def __init__(
        self,
        config: ModelConfig,
        dtype: torch.dtype = torch.float16,
        device: torch.device = CPU_DEVICE,
    ):
        self.config = config
        self.dtype = dtype
        self.device = device

    @classmethod
    def _sel_and_upd_prefix(cls, kv: Dict[str, Any], prefix: str, new_prefix: str = "") -> Dict[str, Any]:
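        """Select entries whose keys start with `prefix` and re-key them under `new_prefix`"""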
        # fmt: off
        return {new_prefix + k[len(prefix):]: v for k, v in kv.items() if k.startswith(prefix)}
        # fmt: on

    @classmethod
    def _load_pretrained_w2v2_encoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
        """Load w2v2 encoder model trained in fairseq1"""
        logger.info(f"Loading w2v2 weights from {checkpoint_path}")
        state_dict = torch.load(checkpoint_path)["model"]
        key_map = Wav2Vec2Loader._fairseq_key_map()
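        # Extend the base fairseq1 -> fairseq2 key map with Conformer-specific renames
        # (conv module, ffn1/ffn2 blocks, relative-position self-attention parameters).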
        key_map.update(
            {
                r"^encoder.layers\.([0-9]+)\.conv_module.batch_norm.": r"encoder.layers.\1.conv.batch_norm.",
                r"^encoder.layers\.([0-9]+)\.conv_module.depthwise_conv.": r"encoder.layers.\1.conv.depthwise_conv.",
                r"^encoder.layers\.([0-9]+)\.conv_module.pointwise_conv([0-9]+)\.": (
                    r"encoder.layers.\1.conv.pointwise_conv\2."
                ),
                r"^encoder.layers\.([0-9]+)\.conv_module.layer_norm.": r"encoder.layers.\1.conv_layer_norm.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.layer_norm.": r"encoder.layers.\1.ffn\2_layer_norm.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_1.": r"encoder.layers.\1.ffn\2.inner_proj.",
                r"^encoder.layers\.([0-9]+)\.ffn([0-9]+)\.w_2.": r"encoder.layers.\1.ffn\2.output_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_k\.": r"encoder.layers.\1.self_attn.k_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_q\.": r"encoder.layers.\1.self_attn.q_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_v\.": r"encoder.layers.\1.self_attn.v_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_out\.": r"encoder.layers.\1.self_attn.output_proj.",
                r"^encoder.layers\.([0-9]+)\.self_attn.linear_pos.weight": (
                    r"encoder.layers.\1.self_attn.sdpa.r_proj.weight"
                ),
                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_u": r"encoder.layers.\1.self_attn.sdpa.u_bias",
                r"^encoder.layers\.([0-9]+)\.self_attn.pos_bias_v": r"encoder.layers.\1.self_attn.sdpa.v_bias",
                # overrides existing rule
                r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.layer_norm.",
            }
        )
        state_dict = convert_model_state_dict(state_dict=state_dict, key_map=key_map)
        # The fairseq2 w2v2 encoder has its encoder layer_norm set to None
        for rm_key in ["encoder.layer_norm.bias", "encoder.layer_norm.weight"]:
            del state_dict[rm_key]
        enc_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder.")
        model.speech_encoder.inner.load_state_dict(enc_state_dict, strict=True)  # type: ignore
        logger.info(f"Loaded w2v2 encoder from {checkpoint_path}")
        enc_frontend_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder_frontend.")  # noqa
        # TODO: reconcile discrepancies between fr1 and fr2 model designs.
        # fr1-based w2v2 checkpoints with conv positional encoders use relpos self-attention,
        # which is not compatible with the fr2 model design.
        # model.speech_encoder_frontend.load_state_dict(enc_frontend_state_dict)
        # logger.info(f"Loaded w2v2 encoder frontend from {checkpoint_path}")

    @classmethod
    def _load_pretrained_s2t_decoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
        """Load NLLB decoder trained in fairseq1"""
        logger.info(f"Loading s2t decoder weights from {checkpoint_path}")
        try:
            state_dict = torch.load(checkpoint_path)["model"]
        except ModuleNotFoundError:
            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
            raise
        decoder_prefix = "decoder."
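        # The fairseq1 checkpoint stores the text decoder under the `shared_decoder.` prefix;
        # re-key it to `decoder.` so the standard NLLB key map applies, then split the result
        # into frontend, decoder, and final_proj states.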
        shared_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix)
        shared_state_dict = convert_model_state_dict(
            state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
        )
        for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
            del shared_state_dict[rm_key]
        decoder_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix=decoder_prefix, new_prefix="")
        frontend_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="decoder_frontend.", new_prefix="")
        proj_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="final_proj.", new_prefix="")
        model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
        logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
        model.text_decoder.load_state_dict(decoder_state, strict=True)
        logger.info(f"Loaded s2t decoder weights from {checkpoint_path}")
        model.final_proj.load_state_dict(proj_state, strict=True)
        logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")

    @classmethod
    def _load_pretrained_t2u(cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str) -> None:
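        """Load T2U (unit generation) model trained in fairseq1"""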
- logger.info(f"Loading t2u weights from {checkpoint_path}")
- t2u_model = model.t2u_model
- assert t2u_model is not None
- try:
- state_dict = torch.load(checkpoint_path)["model"]
- except ModuleNotFoundError:
- logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
- raise
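        # fairseq1 names the T2U encoder `encoder.`; fairseq2 UnitY expects `synthesizer_encoder.`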
        state_dict = {k.replace("encoder.", "synthesizer_encoder."): v for k, v in state_dict.items()}
        state_dict = convert_model_state_dict(
            state_dict=state_dict, key_map=UnitYLoader._fairseq_key_map(config=model_config)
        )
        t2u_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="t2u_model.", new_prefix="")
        t2u_model.load_state_dict(t2u_state_dict)
        logger.info(f"Loaded t2u weights from {checkpoint_path}")

    def build_model(self) -> UnitYModel:
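        """Build a UnitYModel: load a released model with weights, or build one from a
        named config or custom parameters and then load any pretrained sub-module checkpoints."""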
        config = self.config
        logger.info("Initializing model")
        if config.from_model is not None:
            logger.info(f"Loading model and weights from `{config.from_model}`")
            return load_unity_model(config.from_model, device=self.device, dtype=self.dtype)
        if config.from_model_config is not None:
            logger.info(f"Loading UnitY config from `{config.from_model_config}`")
            model_config = load_unity_config(config.from_model_config)
        elif config.custom_params is not None:
            logger.info("Creating custom UnitY config")
            model_config = self._build_custom_model_config()
        else:
            raise ValueError("One of from_model, from_model_config, or custom_params must be set")
        logger.info("Building model")
        model = create_unity_model(config=model_config, dtype=self.dtype, device=self.device)
        if self.config.pretrained_w2v2_path is not None:
            self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)
        if self.config.pretrained_s2t_decoder_path is not None:
            self._load_pretrained_s2t_decoder(model, self.config.pretrained_s2t_decoder_path)
        if self.config.pretrained_t2u_path is not None:
            self._load_pretrained_t2u(model, model_config, self.config.pretrained_t2u_path)
        return model

    def _build_custom_model_config(self) -> UnitYConfig:
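        """Assemble a UnitYConfig (w2v2 encoder, NLLB MT model, T2U) from custom parameters"""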
        config = self.config.custom_params
        assert config is not None
        return UnitYConfig(
            model_dim=config.model_embed_dim,
            w2v2_encoder_config=Wav2Vec2EncoderConfig(
                model_dim=config.model_embed_dim,
                max_seq_len=4096,
                feature_dim=160,
                use_fbank=True,
                first_pass_dropout_p=0.0,
                layer_norm_features=config.w2v2_encoder_layers_layernorm_features,
                feature_extractor_layer_descs=[],
                feature_extractor_bias=False,
                feature_extractor_layer_norm_convs=False,
                feature_grad_scale=0,
                num_fbank_channels=80,
                fbank_stride=2,
                sample_fbank_every_k=1,
                pos_encoder_type=config.w2v2_pos_encoder_type,
                pos_encoder_depth=config.w2v2_pos_encoder_depth,
                pos_conv_kernel_size=config.w2v2_pos_conv_kernel_size,
                num_pos_conv_groups=config.w2v2_num_pos_conv_groups,
                use_conformer=config.w2v2_encoder_layers_use_conformer,
                num_encoder_layers=config.w2v2_encoder_layers,
                num_encoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 4,
                dropout_p=0.0,
                attn_dropout_p=0.0,
                layer_drop_p=0.0,
                norm_order=TransformerNormOrder.POST,
                depthwise_conv_kernel_size=31,
            ),
            mt_model_config=NllbConfig(
                model_dim=config.model_embed_dim,
                max_seq_len=1024,
                vocabulary_size=config.nllb_vocabulary_size,  # num_tokens + langs + spec symbols
                pad_idx=0,
                num_encoder_layers=config.nllb_encoder_layers,
                num_decoder_layers=config.nllb_decoder_layers,
                num_encoder_attn_heads=16,
                num_decoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 8,
                dropout_p=0.1,
            ),
            t2u_config=UnitYT2UConfig(
                model_dim=config.model_embed_dim,
                unit_max_seq_len=2048,
                unit_vocabulary_size=config.unit_vocabulary_size,
                unit_pad_idx=1,
                num_encoder_layers=config.t2u_encoder_layers,
                num_decoder_layers=config.t2u_decoder_layers,
                nar_decoder_frontend_config=None,
                nar_decoder_config=None,
                num_encoder_attn_heads=16,
                num_decoder_attn_heads=16,
                ffn_inner_dim=config.model_embed_dim * 8,
                dropout_p=0.1,
            ),
            use_text_encoder=True,
            use_conformer_adaptor=False,
            num_adaptor_layers=1,
            adaptor_kernel_size=8,
            adaptor_stride=8,
            adaptor_layer_norm=True,
            adaptor_dropout_p=0.1,
        )


if __name__ == "__main__":
    logging.basicConfig(
        level=logging.INFO,
        format=f"%(asctime)s %(levelname)s -- %(name)s.{os.getpid()}: %(message)s",
    )
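    # Example invocation: custom model params plus cluster-specific pretrained checkpoint
    # paths for the w2v2 encoder, S2T decoder, and T2U model.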
    config = ModelConfig(
        custom_params=CustomModelParams(
            nllb_vocabulary_size=256103,
        ),
        pretrained_w2v2_path="/fsx-ust/spopuri/datasets/PT_CKPT/w2v2/w2vbert2rpq_600m_al5.pt",
        pretrained_s2t_decoder_path="/fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt",
        pretrained_t2u_path="/fsx-ust/spopuri/datasets/PT_CKPT/T2U/V5_10K_p2_14_80K.pt",
    )
    builder = ModelBuilder(config=config)
    model = builder.build_model()