@@ -47,13 +47,17 @@ class ModelBuilder:
         self.device = device

     @classmethod
-    def _sel_and_upd_prefix(cls, kv: Dict[str, Any], prefix: str, new_prefix: str = "") -> Dict[str, Any]:
+    def _sel_and_upd_prefix(
+        cls, kv: Dict[str, Any], prefix: str, new_prefix: str = ""
+    ) -> Dict[str, Any]:
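+        """Select the entries of `kv` whose key starts with `prefix` and
+        replace that prefix with `new_prefix` in the returned keys, e.g.
+        prefix="decoder." maps {"decoder.weight": w, "encoder.bias": b}
+        to {"weight": w}.
+        """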
         # fmt: off
         return {new_prefix + k[len(prefix):]: v for k, v in kv.items() if k.startswith(prefix)}
         # fmt: on

     @classmethod
-    def _load_pretrained_w2v2_encoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
+    def _load_pretrained_w2v2_encoder(
+        cls, model: UnitYModel, checkpoint_path: str
+    ) -> None:
         """Load w2v2 encoder model trained in fairseq1"""
         logger.info(f"Loading w2v2 weights from {checkpoint_path}")
         state_dict = torch.load(checkpoint_path)["model"]
@@ -90,7 +94,9 @@ class ModelBuilder:
         model.speech_encoder.inner.load_state_dict(enc_state_dict, strict=True)  # type: ignore
         logger.info(f"Loaded w2v2 encoder from {checkpoint_path}")

-        enc_fronted_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder_frontend.")  # noqa
+        enc_fronted_state_dict = cls._sel_and_upd_prefix(  # noqa
+            kv=state_dict, prefix="encoder_frontend."
+        )
         # TODO: reconcile discrepancies between fr1 and fr2 model designs
         # fr1-based w2v2 checkpoints with conv positional encoders use relpos self attention
         # this is not compatible with the fr2 model design
@@ -98,24 +104,36 @@ class ModelBuilder:
         # logger.info(f"Loaded w2v2 encoder frontend from {checkpoint_path}")

     @classmethod
-    def _load_pretrained_s2t_decoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
+    def _load_pretrained_s2t_decoder(
+        cls, model: UnitYModel, checkpoint_path: str
+    ) -> None:
         """Load NLLB decoder trained in fairseq1"""
         logger.info(f"Loading s2t decoder weights from {checkpoint_path}")
         try:
             state_dict = torch.load(checkpoint_path)["model"]
         except ModuleNotFoundError:
-            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
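+            # fairseq1 checkpoints typically pickle omegaconf config objects,
+            # so torch.load fails when omegaconf is not installed.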
+            logger.info(
+                "If seeing `No module named 'omegaconf'`, run `pip install omegaconf`"
+            )
             raise
         decoder_prefix = "decoder."
-        shared_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix)
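+        # fairseq1 stores the shared decoder under "shared_decoder.*"; move it
+        # to "decoder.*" so that NllbLoader's fairseq key map applies below.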
+        shared_state_dict = cls._sel_and_upd_prefix(
+            kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix
+        )
         shared_state_dict = convert_model_state_dict(
             state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
         )
         for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
             del shared_state_dict[rm_key]
-        decoder_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix=decoder_prefix, new_prefix="")
-        frontend_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="decoder_frontend.", new_prefix="")
-        proj_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="final_proj.", new_prefix="")
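+        # Split the converted dict into the three pieces loaded separately below:
+        # the decoder stack, the decoder frontend, and the final projection.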
+        decoder_state = cls._sel_and_upd_prefix(
+            kv=shared_state_dict, prefix=decoder_prefix, new_prefix=""
+        )
+        frontend_state = cls._sel_and_upd_prefix(
+            kv=shared_state_dict, prefix="decoder_frontend.", new_prefix=""
+        )
+        proj_state = cls._sel_and_upd_prefix(
+            kv=shared_state_dict, prefix="final_proj.", new_prefix=""
+        )
         model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
         logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
         model.text_decoder.load_state_dict(decoder_state, strict=True)
@@ -124,20 +142,30 @@ class ModelBuilder:
         logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")

     @classmethod
-    def _load_pretrained_t2u(cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str) -> None:
+    def _load_pretrained_t2u(
+        cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str
+    ) -> None:
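+        """Load t2u (text-to-unit) model trained in fairseq1"""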
         logger.info(f"Loading t2u weights from {checkpoint_path}")
         t2u_model = model.t2u_model
         assert t2u_model is not None
         try:
             state_dict = torch.load(checkpoint_path)["model"]
         except ModuleNotFoundError:
-            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
+            logger.info(
+                "If seeing `No module named 'omegaconf'`, run `pip install omegaconf`"
+            )
             raise
-        state_dict = {k.replace("encoder.", "synthesizer_encoder."): v for k, v in state_dict.items()}
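+        # The fairseq key map expects the t2u encoder under "synthesizer_encoder.*",
+        # while this checkpoint stores it under "encoder.*"; rename before converting.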
+        state_dict = {
+            k.replace("encoder.", "synthesizer_encoder."): v
+            for k, v in state_dict.items()
+        }
         state_dict = convert_model_state_dict(
-            state_dict=state_dict, key_map=UnitYLoader._fairseq_key_map(config=model_config)
+            state_dict=state_dict,
+            key_map=UnitYLoader._fairseq_key_map(config=model_config),
+        )
+        t2u_state_dict = cls._sel_and_upd_prefix(
+            kv=state_dict, prefix="t2u_model.", new_prefix=""
         )
-        t2u_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="t2u_model.", new_prefix="")
         t2u_model.load_state_dict(t2u_state_dict)
         logger.info(f"Loaded t2u weights from {checkpoint_path}")

@@ -148,7 +176,9 @@ class ModelBuilder:
         logger.info("Initializing model")
         if config.from_model is not None:
             logger.info(f"Loading model and weights from `{config.from_model}`")
-            return load_unity_model(config.from_model, device=self.device, dtype=self.dtype)
+            return load_unity_model(
+                config.from_model, device=self.device, dtype=self.dtype
+            )

         if config.from_model_config is not None:
             logger.info(f"Loading Unity config from `{config.from_model_config}`")
@@ -157,21 +187,40 @@ class ModelBuilder:
             logger.info("Creating custom Unity config")
             model_config = self._build_custom_model_config()
         else:
-            raise ValueError("One of params from_model, from_model_config or custom_params has to be set")
+            raise ValueError(
+                "One of params from_model, from_model_config or custom_params has to be set"
+            )
         logger.info("Building model")
-        model = create_unity_model(config=model_config, dtype=self.dtype, device=self.device)
+        model = create_unity_model(
+            config=model_config, dtype=self.dtype, device=self.device
+        )

         if self.config.pretrained_w2v2_path is not None:
             self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)

         if self.config.pretrained_s2t_decoder_path is not None:
-            self._load_pretrained_s2t_decoder(model, self.config.pretrained_s2t_decoder_path)
+            self._load_pretrained_s2t_decoder(
+                model, self.config.pretrained_s2t_decoder_path
+            )

         if self.config.pretrained_t2u_path is not None:
-            self._load_pretrained_t2u(model, model_config, self.config.pretrained_t2u_path)
+            self._load_pretrained_t2u(
+                model, model_config, self.config.pretrained_t2u_path
+            )

+        logger.info(f"Number of model params: {self._get_num_model_params(model)}")
         return model

+    @classmethod
+    def _get_num_model_params(cls, model: torch.nn.Module) -> int:
+        """Return the total number of elements across all model parameters."""
+        return sum(p.numel() for p in model.parameters())
+
     def _build_custom_model_config(self) -> UnitYConfig:
         config = self.config.custom_params
         assert config is not None