
Fix inconsistency between model vocab info and associated tokenizers (inherit directly from the tokenizers) (#126)

Ruslan Mavlyutov 1 year ago
commit
c9f611a0b2

+ 22 - 12
src/seamless_communication/cli/m4t/train/model.py

@@ -39,12 +39,16 @@ class ModelBuilder:
     def __init__(
         self,
         config: ModelConfig,
+        override_s2t_vocabulary_info: Optional[VocabularyInfo] = None,
+        override_t2u_vocabulary_info: Optional[VocabularyInfo] = None,
         dtype: torch.dtype = torch.float16,
         device: torch.device = CPU_DEVICE,
     ):
         self.config = config
         self.dtype = dtype
         self.device = device
+        self.override_s2t_vocabulary_info = override_s2t_vocabulary_info
+        self.override_t2u_vocabulary_info = override_t2u_vocabulary_info
 
     @classmethod
     def _sel_and_upd_prefix(
@@ -287,6 +291,18 @@ class ModelBuilder:
             pp += nn
         return pp
 
+    def _build_vocab_info(
+        self, vocab_size: int, ref_vocab_info: Optional[VocabularyInfo]
+    ) -> VocabularyInfo:
+        assert ref_vocab_info is not None
+        return VocabularyInfo(
+            size=vocab_size,
+            unk_idx=ref_vocab_info.unk_idx,
+            bos_idx=ref_vocab_info.bos_idx,
+            eos_idx=ref_vocab_info.eos_idx,
+            pad_idx=ref_vocab_info.pad_idx,
+        )
+
     def _build_custom_model_config(self) -> UnitYConfig:
         assert self.config.custom_params is not None
         config: CustomModelParams = self.config.custom_params
@@ -341,12 +357,9 @@ class ModelBuilder:
             mt_model_config=NllbConfig(
                 model_dim=config.model_embed_dim,
                 max_seq_len=1024,
-                vocab_info=VocabularyInfo(
-                    size=config.nllb_vocabulary_size,
-                    unk_idx=1,
-                    bos_idx=2,
-                    eos_idx=3,
-                    pad_idx=0,
+                vocab_info=self._build_vocab_info(
+                    vocab_size=config.nllb_vocabulary_size,
+                    ref_vocab_info=self.override_s2t_vocabulary_info,
                 ),
                 num_encoder_layers=config.nllb_encoder_layers,
                 num_decoder_layers=config.nllb_decoder_layers,
@@ -362,12 +375,9 @@ class ModelBuilder:
                 prosody_encoder_dim=0,
                 model_dim=config.model_embed_dim,
                 unit_max_seq_len=2048,
-                target_vocab_info=VocabularyInfo(
-                    size=config.unit_vocabulary_size,
-                    unk_idx=3,
-                    bos_idx=0,
-                    eos_idx=2,
-                    pad_idx=1,
+                target_vocab_info=self._build_vocab_info(
+                    vocab_size=config.unit_vocabulary_size,
+                    ref_vocab_info=self.override_t2u_vocabulary_info,
                 ),
                 num_encoder_layers=config.t2u_encoder_layers,
                 num_decoder_layers=config.t2u_decoder_layers,

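A note on the new helper (not part of the commit itself): _build_vocab_info keeps the vocabulary size given by the training config but inherits the special-token indices from whatever VocabularyInfo is passed in, which is what replaces the hard-coded unk/bos/eos/pad values above. A minimal sketch of that behavior, assuming fairseq2's VocabularyInfo import path and an already-constructed ModelBuilder; model_config and the concrete sizes/indices are illustrative placeholders, not values taken from this repository:

    from fairseq2.data import VocabularyInfo

    # Illustrative reference info; in the commit it comes from a tokenizer's vocab_info.
    ref = VocabularyInfo(size=256102, unk_idx=1, bos_idx=2, eos_idx=3, pad_idx=0)

    builder = ModelBuilder(
        config=model_config,  # placeholder ModelConfig
        override_s2t_vocabulary_info=ref,
        override_t2u_vocabulary_info=ref,
    )

    # The vocab_size argument wins; the special-token indices are copied from ref.
    built = builder._build_vocab_info(vocab_size=256206, ref_vocab_info=ref)
    assert built.size == 256206
    assert (built.unk_idx, built.bos_idx, built.eos_idx, built.pad_idx) == (1, 2, 3, 0)
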
+ 7 - 3
src/seamless_communication/cli/m4t/train/run_eval.py

@@ -248,8 +248,14 @@ def run_evaluation(
     logger.info(f"Device: {device}, float dtype: {float_dtype}")
     audio_zips_root = parameters.train_data.audio.audio_root_dir
     logger.info(f"Audio zip root: {audio_zips_root}")
+    text_tokenizer = _init_text_tokenizer(data_config=parameters.train_data)
+    unit_tokenizer = _init_unit_tokenizer(data_config=parameters.train_data)
     model = _model.ModelBuilder(
-        config=parameters.model, dtype=float_dtype, device=device
+        config=parameters.model,
+        dtype=float_dtype,
+        device=device,
+        override_s2t_vocabulary_info=text_tokenizer.vocab_info,
+        override_t2u_vocabulary_info=unit_tokenizer.vocab_info,
     ).build_model(skip_loading_weights=True)
     logger.info(f"Loading checkpoint from {checkpoint_path}")
     state_dict = torch.load(checkpoint_path, map_location=device)
@@ -262,8 +268,6 @@ def run_evaluation(
     }
     model.load_state_dict(state_dict)
     model.eval()
-    text_tokenizer = _init_text_tokenizer(data_config=parameters.train_data)
-    unit_tokenizer = _init_unit_tokenizer(data_config=parameters.train_data)
     fbank_extractor = WaveformToFbankConverter(
         num_mel_bins=parameters.train_data.audio.fbanks_num_mel_bins or 80,
         waveform_scale=parameters.train_data.audio.fbanks_waveform_scale,

+ 8 - 4
src/seamless_communication/cli/m4t/train/run_training.py

@@ -60,10 +60,6 @@ def run_training(
         parameters.training.float_dtype
     )
     logger.info(f"Device: {device}, float dtype: {float_dtype}")
-    model = _model.ModelBuilder(
-        config=parameters.model, dtype=float_dtype, device=device
-    ).build_model()
-    logger.info(f"Model: {model}")
     train_data = _dataloader.UnityDataLoader(
         config=parameters.train_data,
         rank=rank,
@@ -78,6 +74,14 @@ def run_training(
         target_device=device,
         float_dtype=float_dtype,
     )
+    model = _model.ModelBuilder(
+        config=parameters.model,
+        dtype=float_dtype,
+        device=device,
+        override_s2t_vocabulary_info=train_data.text_tokenizer.vocab_info,
+        override_t2u_vocabulary_info=train_data.unit_tokenizer.vocab_info,
+    ).build_model()
+    logger.info(f"Model: {model}")
     trainer = _trainer.UnitYTrainer(
         model=model,
         params=parameters.training,
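
Both entry points now follow the same wiring: construct the tokenizers (or the data loader that owns them) first, then pass their vocab_info into the builder so the model's vocabularies are inherited from the tokenizers rather than hard-coded. A compact sketch of that pattern, assuming the _init_text_tokenizer/_init_unit_tokenizer helpers already used in run_eval.py; data_config and model_config are placeholders:

    import torch

    # Tokenizers first, so their vocab_info can seed the model's vocabularies.
    text_tokenizer = _init_text_tokenizer(data_config=data_config)
    unit_tokenizer = _init_unit_tokenizer(data_config=data_config)

    model = ModelBuilder(
        config=model_config,
        dtype=torch.float16,
        device=torch.device("cpu"),
        override_s2t_vocabulary_info=text_tokenizer.vocab_info,
        override_t2u_vocabulary_info=unit_tokenizer.vocab_info,
    ).build_model(skip_loading_weights=True)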