
M4T training scripts follow-ups

mavlyutov 1 year ago
parent
commit
5c35042f34

+ 16 - 3
scripts/m4t/train/configs.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import yaml
 
 from dataclasses import dataclass
 from typing import Dict, Any, Union, get_origin, get_args, List, Literal, Optional
@@ -23,7 +24,7 @@ class Config:
 
     @classmethod
     def _is_config(cls, type_like: Any) -> bool:
-        """ checks if type_like class is a subclass of Config"""
+        """Checks if type_like class is a subclass of Config"""
         try:
             if issubclass(type_like, Config):
                 return True
@@ -33,7 +34,7 @@ class Config:
 
     @classmethod
     def _is_optional_config(cls, type_like: Any) -> bool:
-        """ checks if type_like == Optional[subclass of Config] """
+        """Checks if type_like == Optional[subclass of Config]"""
         if not get_origin(type_like) == Union:
             return False
         args = [arg for arg in get_args(type_like) if arg is not type(None)]
@@ -47,7 +48,11 @@ class Config:
             # Optional[Config]
             if cls._is_optional_config(field_desc.type):
                 if non_null:
-                    type_arg = [arg for arg in get_args(field_desc.type) if arg is not type(None)][0]
+                    type_arg = [
+                        arg
+                        for arg in get_args(field_desc.type)
+                        if arg is not type(None)
+                    ][0]
                     kwargs[key] = type_arg.deserialize(asdict[key])
                 else:
                     kwargs[key] = None
@@ -63,6 +68,14 @@ class Config:
                 kwargs[key] = asdict.get(key)
         return cls(**kwargs)
 
+    @classmethod
+    def from_string(cls, serialized_config: str):
+        return cls.deserialize(yaml.load(serialized_config, Loader=yaml.FullLoader))
+
+    @classmethod
+    def from_file(cls, config_path: str):
+        with open(config_path) as fp:
+            return cls.deserialize(yaml.load(fp, Loader=yaml.FullLoader))
+
 
 @dataclass
 class TextTokenizationConfig(Config):

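For context, a minimal usage sketch of the new from_string / from_file class methods, assuming a hypothetical flat ToyTrainingConfig subclass and an illustrative file path (neither appears in this diff):

import dataclasses

from configs import Config  # assuming scripts/m4t/train is importable


@dataclasses.dataclass
class ToyTrainingConfig(Config):  # hypothetical flat config, for illustration only
    learning_rate: float
    eval_steps: int


# parse a YAML payload held in memory
cfg = ToyTrainingConfig.from_string("learning_rate: 0.0001\neval_steps: 5000")

# parse a YAML recipe from a path on disk
cfg = ToyTrainingConfig.from_file("path/to/recipe.yaml")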
+ 68 - 19
scripts/m4t/train/model.py

@@ -47,13 +47,17 @@ class ModelBuilder:
         self.device = device
 
     @classmethod
-    def _sel_and_upd_prefix(cls, kv: Dict[str, Any], prefix: str, new_prefix: str = "") -> Dict[str, Any]:
+    def _sel_and_upd_prefix(
+        cls, kv: Dict[str, Any], prefix: str, new_prefix: str = ""
+    ) -> Dict[str, Any]:
         # fmt: off
         return {new_prefix + k[len(prefix):]: v for k, v in kv.items() if k.startswith(prefix)}
         # fmt: on
 
     @classmethod
-    def _load_pretrained_w2v2_encoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
+    def _load_pretrained_w2v2_encoder(
+        cls, model: UnitYModel, checkpoint_path: str
+    ) -> None:
         """Load w2v2 encoder model trained in fairseq1"""
         logger.info(f"Loading w2v2 weights from {checkpoint_path}")
         state_dict = torch.load(checkpoint_path)["model"]
@@ -90,7 +94,9 @@ class ModelBuilder:
         model.speech_encoder.inner.load_state_dict(enc_state_dict, strict=True)  # type: ignore
         logger.info(f"Loaded w2v2 encoder from {checkpoint_path}")
 
-        enc_fronted_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="encoder_frontend.")  # noqa
+        enc_frontend_state_dict = cls._sel_and_upd_prefix(  # noqa
+            kv=state_dict, prefix="encoder_frontend."
+        )
         # TODO: reconcile discrepancies between fr1 and fr2 model designs
         #  fr1-based w2v2 checkpoints with conv positional encoders use relpos self attention
         #   this is not compatible with the fr2 model design
@@ -98,24 +104,36 @@ class ModelBuilder:
         # logger.info(f"Loaded w2v2 encoder frontend from {checkpoint_path}")
 
     @classmethod
-    def _load_pretrained_s2t_decoder(cls, model: UnitYModel, checkpoint_path: str) -> None:
+    def _load_pretrained_s2t_decoder(
+        cls, model: UnitYModel, checkpoint_path: str
+    ) -> None:
         """Load NLLB decoder trained in fairseq1"""
         logger.info(f"Loading s2t decoder weights from {checkpoint_path}")
         try:
             state_dict = torch.load(checkpoint_path)["model"]
         except ModuleNotFoundError:
-            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
+            logger.info(
+                "If seeing `No module named 'omegaconf'`, run `pip install omegaconf`"
+            )
             raise
         decoder_prefix = "decoder."
-        shared_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix)
+        shared_state_dict = cls._sel_and_upd_prefix(
+            kv=state_dict, prefix="shared_decoder.", new_prefix=decoder_prefix
+        )
         shared_state_dict = convert_model_state_dict(
             state_dict=shared_state_dict, key_map=NllbLoader._fairseq_key_map()
         )
         for rm_key in ["decoder.embed_positions._float_tensor", "decoder.version"]:
             del shared_state_dict[rm_key]
-        decoder_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix=decoder_prefix, new_prefix="")
-        frontend_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="decoder_frontend.", new_prefix="")
-        proj_state = cls._sel_and_upd_prefix(kv=shared_state_dict, prefix="final_proj.", new_prefix="")
+        decoder_state = cls._sel_and_upd_prefix(
+            kv=shared_state_dict, prefix=decoder_prefix, new_prefix=""
+        )
+        frontend_state = cls._sel_and_upd_prefix(
+            kv=shared_state_dict, prefix="decoder_frontend.", new_prefix=""
+        )
+        proj_state = cls._sel_and_upd_prefix(
+            kv=shared_state_dict, prefix="final_proj.", new_prefix=""
+        )
         model.text_decoder_frontend.load_state_dict(frontend_state, strict=True)
         logger.info(f"Loaded s2t decoder frontend weights from {checkpoint_path}")
         model.text_decoder.load_state_dict(decoder_state, strict=True)
@@ -124,20 +142,30 @@ class ModelBuilder:
         logger.info(f"Loaded s2t decoder final_proj weights from {checkpoint_path}")
 
     @classmethod
-    def _load_pretrained_t2u(cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str) -> None:
+    def _load_pretrained_t2u(
+        cls, model: UnitYModel, model_config: UnitYConfig, checkpoint_path: str
+    ) -> None:
         logger.info(f"Loading t2u weights from {checkpoint_path}")
         t2u_model = model.t2u_model
         assert t2u_model is not None
         try:
             state_dict = torch.load(checkpoint_path)["model"]
         except ModuleNotFoundError:
-            logger.info("If seeing `No module named 'omegaconf'`, run `pip install omegaconf`")
+            logger.info(
+                "If seeing `No module named 'omegaconf'`, run `pip install omegaconf`"
+            )
             raise
-        state_dict = {k.replace("encoder.", "synthesizer_encoder."): v for k, v in state_dict.items()}
+        state_dict = {
+            k.replace("encoder.", "synthesizer_encoder."): v
+            for k, v in state_dict.items()
+        }
         state_dict = convert_model_state_dict(
-            state_dict=state_dict, key_map=UnitYLoader._fairseq_key_map(config=model_config)
+            state_dict=state_dict,
+            key_map=UnitYLoader._fairseq_key_map(config=model_config),
+        )
+        t2u_state_dict = cls._sel_and_upd_prefix(
+            kv=state_dict, prefix="t2u_model.", new_prefix=""
         )
-        t2u_state_dict = cls._sel_and_upd_prefix(kv=state_dict, prefix="t2u_model.", new_prefix="")
         t2u_model.load_state_dict(t2u_state_dict)
         logger.info(f"Loaded t2u weights from {checkpoint_path}")
 
@@ -148,7 +176,9 @@ class ModelBuilder:
         logger.info("Initializing model")
         if config.from_model is not None:
             logger.info(f"Loading model and weights from `{config.from_model}`")
-            return load_unity_model(config.from_model, device=self.device, dtype=self.dtype)
+            return load_unity_model(
+                config.from_model, device=self.device, dtype=self.dtype
+            )
 
         if config.from_model_config is not None:
             logger.info(f"Loading Unity config from `{config.from_model_config}`")
@@ -157,21 +187,40 @@ class ModelBuilder:
             logger.info("Creating custom Unity config")
             model_config = self._build_custom_model_config()
         else:
-            raise ValueError("One of params from_model, from_model_config or custom_params has to be set")
+            raise ValueError(
+                "One of params from_model, from_model_config or custom_params has to be set"
+            )
         logger.info("Building model")
-        model = create_unity_model(config=model_config, dtype=self.dtype, device=self.device)
+        model = create_unity_model(
+            config=model_config, dtype=self.dtype, device=self.device
+        )
 
         if self.config.pretrained_w2v2_path is not None:
             self._load_pretrained_w2v2_encoder(model, self.config.pretrained_w2v2_path)
 
         if self.config.pretrained_s2t_decoder_path is not None:
-            self._load_pretrained_s2t_decoder(model, self.config.pretrained_s2t_decoder_path)
+            self._load_pretrained_s2t_decoder(
+                model, self.config.pretrained_s2t_decoder_path
+            )
 
         if self.config.pretrained_t2u_path is not None:
-            self._load_pretrained_t2u(model, model_config, self.config.pretrained_t2u_path)
+            self._load_pretrained_t2u(
+                model, model_config, self.config.pretrained_t2u_path
+            )
 
+        logger.info(f"Number of model params: {self._get_num_model_params(model)}")
         return model
 
+    @classmethod
+    def _get_num_model_params(cls, model: torch.nn.Module) -> int:
+        return sum(p.numel() for p in model.parameters())
+
     def _build_custom_model_config(self) -> UnitYConfig:
         config = self.config.custom_params
         assert config is not None

+ 1 - 1
scripts/m4t/train/recipes/asr_small.yaml

@@ -36,7 +36,7 @@ model:
     model_embed_dim: 768
     nllb_decoder_layers: 3
     nllb_encoder_layers: 1
-    nllb_vocabulary_size: 256102
+    nllb_vocabulary_size: 20010
     t2u_decoder_layers: 1
     t2u_encoder_layers: 1
     unit_vocabulary_size: 10082

+ 2 - 2
scripts/m4t/train/recipes/asr_small_wh_transc.yaml

@@ -36,7 +36,7 @@ model:
     model_embed_dim: 768
     nllb_decoder_layers: 3
     nllb_encoder_layers: 1
-    nllb_vocabulary_size: 256102
+    nllb_vocabulary_size: 20010
     t2u_decoder_layers: 1
     t2u_encoder_layers: 1
     unit_vocabulary_size: 10082
@@ -86,7 +86,7 @@ train_data:
     num_units: null
   unit_tokenizer_name: seamlessM4T_large
 training:
-  eval_steps: 1000 
+  eval_steps: 5000
   float_dtype: bf16
   label_smoothing: 0.2
   learning_rate: 0.0001

+ 2 - 2
scripts/m4t/train/recipes/large_M4T_v1.yaml

@@ -77,10 +77,10 @@ train_data:
     num_units: null
   unit_tokenizer_name: seamlessM4T_large
 training:
-  eval_steps: 1000
+  eval_steps: 5000
   float_dtype: bf16
   label_smoothing: 0.2
-  learning_rate: 0.00005
+  learning_rate: 0.0001
   log_steps: 200
   max_epochs: 100
   patience: 10

+ 10 - 2
scripts/m4t/train/trainer.py

@@ -372,13 +372,21 @@ class UnitYTrainer:
         self._release_memory(batch)
 
     def _release_memory(self, batch: dataloader.MultimodalSeqsBatch) -> None:
-        """ Explicitly release large memory consumers """
+        """Explicitly release large memory consumers"""
         del batch
 
+    def _strip_state_key_prefixes(self, key: str) -> str:
+        """Removes state_dict key prefixes associated with model wrappers"""
+        to_strip = ["module.", "model."]
+        for prefix in to_strip:
+            if key.startswith(prefix):
+                key = key[len(prefix):]
+        return key
+
     def _get_state(self) -> Dict[str, Any]:
         model_state_dict = self.model.state_dict()
         model_state_dict = {
-            key.replace("module.model.", ""): value
+            self._strip_state_key_prefixes(key): value
             for key, value in model_state_dict.items()
         }
         return model_state_dict
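
For comparison with the old key.replace("module.model.", "") call, a standalone sketch of the new prefix stripping on illustrative keys; DistributedDataParallel prepends "module.", and the assumption here is that a local trainer wrapper may additionally prepend "model.":

def strip_state_key_prefixes(key: str) -> str:
    # standalone copy of UnitYTrainer._strip_state_key_prefixes, for illustration only
    for prefix in ["module.", "model."]:
        if key.startswith(prefix):
            key = key[len(prefix):]
    return key


assert strip_state_key_prefixes("module.model.encoder.weight") == "encoder.weight"  # old code handled this
assert strip_state_key_prefixes("module.encoder.weight") == "encoder.weight"        # only DDP wrapping
assert strip_state_key_prefixes("model.encoder.weight") == "encoder.weight"         # only the local wrapper
assert strip_state_key_prefixes("encoder.weight") == "encoder.weight"               # unwrapped keys pass through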