|
@@ -7,17 +7,15 @@
|
|
|
|
|
|
import logging
|
|
|
import os
|
|
|
-from typing import Any, Dict
|
|
|
+from typing import Any, Dict, Optional
|
|
|
|
|
|
import torch
|
|
|
+
|
|
|
from fairseq2.data import VocabularyInfo
|
|
|
from fairseq2.models.nllb.builder import NllbConfig
|
|
|
-from fairseq2.models.nllb.loader import NllbLoader
|
|
|
-from fairseq2.models.utils.checkpoint_loader import convert_model_state_dict
|
|
|
+from fairseq2.models.utils.checkpoint import convert_model_state_dict
|
|
|
from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
|
|
|
-from fairseq2.models.wav2vec2.loader import Wav2Vec2Loader
|
|
|
from fairseq2.nn.transformer import TransformerNormOrder
|
|
|
-
|
|
|
from seamless_communication.cli.m4t.train.configs import CustomModelParams, ModelConfig
|
|
|
from seamless_communication.models.unity import (
|
|
|
UnitYConfig,
|
|
@@ -26,7 +24,10 @@ from seamless_communication.models.unity import (
|
|
|
create_unity_model,
|
|
|
load_unity_model,
|
|
|
)
|
|
|
-from seamless_communication.models.unity.loader import UnitYLoader, load_unity_config
|
|
|
+from seamless_communication.models.unity.loader import (
|
|
|
+ _fairseq_key_map as unity_fairseq_key_map,
|
|
|
+)
|
|
|
+from seamless_communication.models.unity.loader import load_unity_config
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
@@ -60,7 +61,28 @@ class ModelBuilder:
|
|
|
"""Load w2v2 encoder model trained in fairseq1"""
|
|
|
logger.info(f"Loading w2v2 weights from {checkpoint_path}")
|
|
|
state_dict = torch.load(checkpoint_path)["model"]
|
|
|
- key_map = Wav2Vec2Loader._fairseq_key_map()
|
|
|
+ key_map = {
|
|
|
+ # fmt: off
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder.layers.\1.self_attn.output_proj.",
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.fc1\.": r"encoder.layers.\1.ffn.inner_proj.",
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.fc2\.": r"encoder.layers.\1.ffn.output_proj.",
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.ffn_layer_norm.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.",
|
|
|
+ r"^encoder\.embed_tokens\.": r"encoder_frontend.embed.",
|
|
|
+ r"^encoder\.pos_conv\.0\.": r"encoder_frontend.pos_encoder.conv.",
|
|
|
+ r"^feature_extractor\.conv_layers\.([0-9]+)\.0\.": r"encoder_frontend.feature_extractor.layers.\1.conv.",
|
|
|
+ r"^feature_extractor\.conv_layers\.([0-9]+)\.2\.1\.": \
|
|
|
+ r"encoder_frontend.feature_extractor.layers.\1.layer_norm.",
|
|
|
+ r"^feature_extractor\.conv_layers\.0\.2\.": \
|
|
|
+ r"encoder_frontend.feature_extractor.layers.0.group_norm.",
|
|
|
+ r"^layer_norm\.": r"encoder_frontend.post_extract_layer_norm.",
|
|
|
+ r"^post_extract_proj\.": r"encoder_frontend.model_dim_proj.",
|
|
|
+ r"^mask_emb": r"masker.temporal_mask_embed",
|
|
|
+ r"^quantizer\.vars": r"quantizer.entries",
|
|
|
+ r"^quantizer\.weight_proj\.": r"quantizer.entry_proj.",
|
|
|
+ r"^project_q\.": r"final_target_proj.",
|
|
|
+ # fmt: on
|
|
|
+ }
|
|
|
key_map.update(
|
|
|
{
|
|
|
r"^encoder.layers\.([0-9]+)\.conv_module.batch_norm.": r"encoder.layers.\1.conv.batch_norm.",
|
|
@@ -119,8 +141,28 @@ class ModelBuilder:
|
|
|
shared_state_dict = cls._sel_and_upd_prefix(
|
|
|
kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix
|
|
|
)
|
|
|
+ nllb_fairseq_key_map = {
|
|
|
+ # fmt: off
|
|
|
+ r"^encoder\.embed_tokens\.": r"encoder_frontend.embed.",
|
|
|
+ r"^decoder\.embed_tokens\.": r"decoder_frontend.embed.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"decoder.layers.\1.self_attn.output_proj.",
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.self_attn\.out_proj\.": r"encoder.layers.\1.self_attn.output_proj.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.encoder_attn\.out_proj\.": \
|
|
|
+ r"decoder.layers.\1.encoder_decoder_attn.output_proj.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.encoder_attn\.": r"decoder.layers.\1.encoder_decoder_attn.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.encoder_attn_layer_norm\.": \
|
|
|
+ r"decoder.layers.\1.encoder_decoder_attn_layer_norm.",
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.fc1\.": r"encoder.layers.\1.ffn.inner_proj.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.fc1\.": r"decoder.layers.\1.ffn.inner_proj.",
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.fc2\.": r"encoder.layers.\1.ffn.output_proj.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.fc2\.": r"decoder.layers.\1.ffn.output_proj.",
|
|
|
+ r"^encoder\.layers\.([0-9]+)\.final_layer_norm\.": r"encoder.layers.\1.ffn_layer_norm.",
|
|
|
+ r"^decoder\.layers\.([0-9]+)\.final_layer_norm\.": r"decoder.layers.\1.ffn_layer_norm.",
|
|
|
+ r"^decoder\.output_projection\.": r"final_proj.",
|
|
|
+ # fmt: on
|
|
|
+ }
|
|
|
shared_state_dict = convert_model_state_dict(
|
|
|
- state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
|
|
|
+ state_dict=shared_state_dict, key_map=nllb_fairseq_key_map
|
|
|
)
|
|
|
for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
|
|
|
del shared_state_dict[rm_key]
|
|
@@ -133,11 +175,14 @@ class ModelBuilder:
|
|
|
proj_state = cls._sel_and_upd_prefix(
|
|
|
kv=shared_state_dict, prefix="final_proj.", new_prefix=""
|
|
|
)
|
|
|
- model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
|
|
|
+ if model.text_decoder_frontend is not None:
|
|
|
+ model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
|
|
|
logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
|
|
|
- model.text_decoder.load_state_dict(decoder_state, strict=True)
|
|
|
+ if model.text_decoder is not None:
|
|
|
+ model.text_decoder.load_state_dict(decoder_state, strict=True)
|
|
|
logger.info(f"Loaded s2t decoder weights from {checkpoint_path}")
|
|
|
- model.final_proj.load_state_dict(proj_state, strict=True)
|
|
|
+ if model.final_proj is not None:
|
|
|
+ model.final_proj.load_state_dict(proj_state, strict=True)
|
|
|
logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")
|
|
|
|
|
|
@classmethod
|
|
@@ -160,7 +205,7 @@ class ModelBuilder:
|
|
|
}
|
|
|
state_dict = convert_model_state_dict(
|
|
|
state_dict=state_dict,
|
|
|
- key_map=UnitYLoader._fairseq_key_map(config=model_config),
|
|
|
+ key_map=unity_fairseq_key_map(config=model_config),
|
|
|
)
|
|
|
t2u_state_dict = cls._sel_and_upd_prefix(
|
|
|
kv=state_dict, prefix="t2u_model.", new_prefix=""
|
|
@@ -170,7 +215,16 @@ class ModelBuilder:
|
|
|
|
|
|
def build_model(
|
|
|
self,
|
|
|
+ skip_loading_weights: bool = False,
|
|
|
) -> UnitYModel:
|
|
|
+ """
|
|
|
+ Args:
|
|
|
+ skip_loading_weights (bool, optional):
|
|
|
+ Ignores pretrained_w2v2_path, pretrained_s2t_decoder_path, pretrained_t2u_path.
|
|
|
+ Defaults to False.
|
|
|
+ Returns:
|
|
|
+ UnitYModel: initialized UnitY model
|
|
|
+ """
|
|
|
config = self.config
|
|
|
logger.info("Initializing model")
|
|
|
if config.from_model is not None:
|
|
@@ -193,25 +247,38 @@ class ModelBuilder:
|
|
|
model = create_unity_model(
|
|
|
config=model_config, dtype=self.dtype, device=self.device
|
|
|
)
|
|
|
+ if not skip_loading_weights:
|
|
|
+ if self.config.pretrained_w2v2_path is not None:
|
|
|
+ self._load_pretrained_w2v2_encoder(
|
|
|
+ model, self.config.pretrained_w2v2_path
|
|
|
+ )
|
|
|
|
|
|
- if self.config.pretrained_w2v2_path is not None:
|
|
|
- self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)
|
|
|
+ if self.config.pretrained_s2t_decoder_path is not None:
|
|
|
+ self._load_pretrained_s2t_decoder(
|
|
|
+ model, self.config.pretrained_s2t_decoder_path
|
|
|
+ )
|
|
|
|
|
|
- if self.config.pretrained_s2t_decoder_path is not None:
|
|
|
- self._load_pretrained_s2t_decoder(
|
|
|
- model, self.config.pretrained_s2t_decoder_path
|
|
|
- )
|
|
|
+ if self.config.pretrained_t2u_path is not None:
|
|
|
+ self._load_pretrained_t2u(
|
|
|
+ model, model_config, self.config.pretrained_t2u_path
|
|
|
+ )
|
|
|
|
|
|
- if self.config.pretrained_t2u_path is not None:
|
|
|
- self._load_pretrained_t2u(
|
|
|
- model, model_config, self.config.pretrained_t2u_path
|
|
|
+ def _num_s2t_params(model: UnitYModel) -> int:
|
|
|
+ return (
|
|
|
+ self._get_num_model_params(model.speech_encoder_frontend)
|
|
|
+ + self._get_num_model_params(model.speech_encoder)
|
|
|
+ + self._get_num_model_params(model.text_decoder_frontend)
|
|
|
+ + self._get_num_model_params(model.text_decoder)
|
|
|
)
|
|
|
|
|
|
logger.info(f"Number of model params: {self._get_num_model_params(model)}")
|
|
|
+ logger.info(f"Number of S2T params: {_num_s2t_params(model)}")
|
|
|
return model
|
|
|
|
|
|
@classmethod
|
|
|
- def _get_num_model_params(cls, model: torch.nn.Module) -> int:
|
|
|
+ def _get_num_model_params(cls, model: Optional[torch.nn.Module]) -> int:
|
|
|
+ if model is None:
|
|
|
+ return 0
|
|
|
pp = 0
|
|
|
for p in list(model.parameters()):
|
|
|
nn = 1
|
|
@@ -221,14 +288,32 @@ class ModelBuilder:
|
|
|
return pp
|
|
|
|
|
|
def _build_custom_model_config(self) -> UnitYConfig:
|
|
|
- config = self.config.custom_params
|
|
|
+ assert self.config.custom_params is not None
|
|
|
+ config: CustomModelParams = self.config.custom_params
|
|
|
+ num_fbank_channels = (
|
|
|
+ config.num_fbank_channels if config.num_fbank_channels is not None else 80
|
|
|
+ )
|
|
|
+ fbank_stride = config.fbank_stride if config.fbank_stride is not None else 2
|
|
|
+ nllb_ffn_inner_dim = (
|
|
|
+ config.nllb_ffn_inner_dim
|
|
|
+ if config.nllb_ffn_inner_dim is not None
|
|
|
+ else config.model_embed_dim * 8
|
|
|
+ )
|
|
|
+ w2v2_ffn_inner_dim = (
|
|
|
+ config.w2v2_ffn_inner_dim
|
|
|
+ if config.w2v2_ffn_inner_dim is not None
|
|
|
+ else config.model_embed_dim * 4
|
|
|
+ )
|
|
|
assert config is not None
|
|
|
return UnitYConfig(
|
|
|
+ use_gelu=False,
|
|
|
+ use_text_decoder=True,
|
|
|
+ prosody_encoder_config=None,
|
|
|
model_dim=config.model_embed_dim,
|
|
|
w2v2_encoder_config=Wav2Vec2EncoderConfig(
|
|
|
model_dim=config.model_embed_dim,
|
|
|
max_seq_len=4096,
|
|
|
- feature_dim=160,
|
|
|
+ feature_dim=num_fbank_channels * fbank_stride,
|
|
|
use_fbank=True,
|
|
|
first_pass_dropout_p=0.0,
|
|
|
layer_norm_features=config.w2v2_encoder_layers_layernorm_features,
|
|
@@ -236,8 +321,8 @@ class ModelBuilder:
|
|
|
feature_extractor_bias=False,
|
|
|
feature_extractor_layer_norm_convs=False,
|
|
|
feature_grad_scale=0,
|
|
|
- num_fbank_channels=80,
|
|
|
- fbank_stride=2,
|
|
|
+ num_fbank_channels=num_fbank_channels,
|
|
|
+ fbank_stride=fbank_stride,
|
|
|
sample_fbank_every_k=1,
|
|
|
pos_encoder_type=config.w2v2_pos_encoder_type,
|
|
|
pos_encoder_depth=config.w2v2_pos_encoder_depth,
|
|
@@ -246,7 +331,7 @@ class ModelBuilder:
|
|
|
use_conformer=config.w2v2_encoder_layers_use_conformer,
|
|
|
num_encoder_layers=config.w2v2_encoder_layers,
|
|
|
num_encoder_attn_heads=16,
|
|
|
- ffn_inner_dim=config.model_embed_dim * 4,
|
|
|
+ ffn_inner_dim=w2v2_ffn_inner_dim,
|
|
|
dropout_p=0.0,
|
|
|
attn_dropout_p=0.0,
|
|
|
layer_drop_p=0.0,
|
|
@@ -267,10 +352,14 @@ class ModelBuilder:
|
|
|
num_decoder_layers=config.nllb_decoder_layers,
|
|
|
num_encoder_attn_heads=16,
|
|
|
num_decoder_attn_heads=16,
|
|
|
- ffn_inner_dim=config.model_embed_dim * 8,
|
|
|
+ ffn_inner_dim=nllb_ffn_inner_dim,
|
|
|
dropout_p=0.1,
|
|
|
),
|
|
|
t2u_config=UnitYT2UConfig(
|
|
|
+ use_gelu=False,
|
|
|
+ char_pad_idx=0,
|
|
|
+ use_prosody_proj=False,
|
|
|
+ prosody_encoder_dim=0,
|
|
|
model_dim=config.model_embed_dim,
|
|
|
unit_max_seq_len=2048,
|
|
|
target_vocab_info=VocabularyInfo(
|
|
@@ -312,5 +401,7 @@ if __name__ == "__main__":
|
|
|
pretrained_s2t_decoder_path="/fsx-ust/spopuri/datasets/PT_CKPT/S2T/S2T_M4T_V1_V1_cleaned.pt",
|
|
|
pretrained_t2u_path="/fsx-ust/spopuri/datasets/PT_CKPT/T2U/V5_10K_p2_14_80K.pt",
|
|
|
)
|
|
|
+ config = ModelConfig(from_model_config="seamlessM4T_medium")
|
|
|
builder = ModelBuilder(config=config)
|
|
|
- model = ModelBuilder(config=config).build_model()
|
|
|
+ model = builder.build_model()
|
|
|
+ print(model)
|