
Run black on the entire repo, fix mypy issues in translator.py, fix API bug in unity/generator.py (#62)

Kaushik Ram Sadagopan 1 year ago
parent commit 4314c30490

+ 32 - 8
scripts/m4t/train/run_training.py

@@ -47,21 +47,35 @@ def init_parser() -> argparse.ArgumentParser:
     return parser
 
 
-def run_training(parameters: WorkflowParams, work_dir: str, checkpoint_dir: str) -> None:
+def run_training(
+    parameters: WorkflowParams, work_dir: str, checkpoint_dir: str
+) -> None:
     logger.info(f"Workflow params: {parameters}")
     rank, world_size = dist_utils.get_rank(), dist_utils.get_world_size()
     logger.info(f"Rank: {rank}, world_size: {world_size}")
     assert torch.cuda.device_count() > 0, "GPU is not available"
     device = torch.device("cuda")
-    float_dtype = _trainer.UnitYTrainer._get_float_dtype(parameters.training.float_dtype)
+    float_dtype = _trainer.UnitYTrainer._get_float_dtype(
+        parameters.training.float_dtype
+    )
     logger.info(f"Device: {device}, float dtype: {float_dtype}")
-    model = _model.ModelBuilder(config=parameters.model, dtype=float_dtype, device=device).build_model()
+    model = _model.ModelBuilder(
+        config=parameters.model, dtype=float_dtype, device=device
+    ).build_model()
     logger.info(f"Model: {model}")
     train_data = _dataloader.UnityDataLoader(
-        config=parameters.train_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
+        config=parameters.train_data,
+        rank=rank,
+        world_size=world_size,
+        target_device=device,
+        float_dtype=float_dtype,
     )
     eval_data = _dataloader.UnityDataLoader(
-        config=parameters.eval_data, rank=rank, world_size=world_size, target_device=device, float_dtype=float_dtype
+        config=parameters.eval_data,
+        rank=rank,
+        world_size=world_size,
+        target_device=device,
+        float_dtype=float_dtype,
     )
     trainer = _trainer.UnitYTrainer(
         model=model,
@@ -75,7 +89,13 @@ def run_training(parameters: WorkflowParams, work_dir: str, checkpoint_dir: str)
 
 
 def get_loggers() -> List[logging.Logger]:
-    return [logger, _trainer.logger, _dataloader.logger, _model.logger, dist_utils.logger]
+    return [
+        logger,
+        _trainer.logger,
+        _dataloader.logger,
+        _model.logger,
+        dist_utils.logger,
+    ]
 
 
 def set_file_output_for_loggers(log_filename: str) -> None:
@@ -91,7 +111,9 @@ def main() -> None:
     dist_utils.init_distributed(get_loggers())
     is_master = dist_utils.is_main_process()
     with open(args.params, "r") as fp_in:
-        parameters = WorkflowParams.deserialize(yaml.load(fp_in, Loader=yaml.FullLoader))
+        parameters = WorkflowParams.deserialize(
+            yaml.load(fp_in, Loader=yaml.FullLoader)
+        )
     ts = str(int(time.time()))
     work_dir = args.wd
     checkpoint_dir = os.path.join(work_dir, "checkpoints")
@@ -108,7 +130,9 @@ def main() -> None:
     logger.info(f"Set logging to {log_path}")
     set_file_output_for_loggers(log_path)
     try:
-        run_training(parameters=parameters, work_dir=work_dir, checkpoint_dir=checkpoint_dir)
+        run_training(
+            parameters=parameters, work_dir=work_dir, checkpoint_dir=checkpoint_dir
+        )
     except Exception:
        # make sure that the stack trace will be logged to log files
         logger.exception("Training failed")

+ 6 - 2
scripts/m4t/train/run_with_slurm.py

@@ -128,10 +128,14 @@ def main() -> None:
 
     assert job_name is not None
     assert len(job_name.split()) == 1, "spaces in job name not allowed"
-    assert partitions and len(partitions.split()) == 1, "spaces in partitions not allowed"
+    assert (
+        partitions and len(partitions.split()) == 1
+    ), "spaces in partitions not allowed"
     assert os.path.exists(params_file), "config file is missing"
     training_script_path = os.path.join(os.path.dirname(__file__), "run_training.py")
-    assert os.path.exists(training_script_path), f"Can't find training script {training_script_path}"
+    assert os.path.exists(
+        training_script_path
+    ), f"Can't find training script {training_script_path}"
     assert num_nodes > 0
     if not os.path.exists(work_dir):
         logger.info(f"Creating workdir {work_dir}")

+ 2 - 5
scripts/m4t/train/trainer.py

@@ -67,10 +67,7 @@ class UnitYTrainWrapper(nn.Module):
         )
         text_logits = self.model.final_proj(text_decoder_out)
         # t2u
-        (
-            unit_encoder_out,
-            unit_encoder_padding_mask,
-        ) = self.t2u.encode(
+        (unit_encoder_out, unit_encoder_padding_mask,) = self.t2u.encode(
             text_decoder_output=text_decoder_out,
             text_decoder_padding_mask=text_decoder_padding_mask,
         )
@@ -380,7 +377,7 @@ class UnitYTrainer:
         to_strip = ["module.", "model."]
         for prefix in to_strip:
             if key.startswith(prefix):
-                key = key[len(prefix):]
+                key = key[len(prefix) :]
         return key
 
     def _get_state(self) -> Dict[str, Any]:
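
The slice change in this hunk is black's PEP 8 spacing: a space goes before the colon when a slice bound is a complex expression such as `len(prefix)`. For context, a self-contained sketch of what the surrounding `UnitYTrainer` helper does; the standalone function name here is illustrative:

```python
def strip_checkpoint_prefixes(key: str) -> str:
    # Strip wrapper prefixes in order: "module." (added by DDP) first,
    # then "model.", so "module.model.encoder.weight" reduces to
    # "encoder.weight".
    for prefix in ("module.", "model."):
        if key.startswith(prefix):
            key = key[len(prefix) :]
    return key


assert strip_checkpoint_prefixes("module.model.encoder.weight") == "encoder.weight"
assert strip_checkpoint_prefixes("encoder.weight") == "encoder.weight"
```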

+ 1 - 1
src/seamless_communication/models/inference/ngram_repeat_block_processor.py

@@ -45,7 +45,7 @@ class NGramRepeatBlockProcessor(LogitsProcessor):
         :returns:
             modified lprobs tensor with banned tokens set to -inf
         """
-        banned_tokens = [[] for _ in range(batch_size * beam_size)]
+        banned_tokens: List[List[int]] = [[] for _ in range(batch_size * beam_size)]
 
         if step_nr + 2 - self.no_repeat_ngram_size >= 0:
             cpu_tokens: List[List[int]] = seqs.cpu().tolist()
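
The added `List[List[int]]` annotation is a standard mypy fix: the element type of an empty-list comprehension cannot be inferred, so mypy otherwise reports "Need type annotation" for the variable. A minimal reproduction, assuming any recent mypy:

```python
from typing import List

batch_size, beam_size = 2, 4

# Without the annotation, mypy reports: Need type annotation for
# "banned_tokens"; with it, appending ints type-checks cleanly.
banned_tokens: List[List[int]] = [[] for _ in range(batch_size * beam_size)]
banned_tokens[0].append(42)
```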

+ 11 - 4
src/seamless_communication/models/inference/translator.py

@@ -61,14 +61,17 @@ class Translator(nn.Module):
         # Load the model.
         if device == torch.device("cpu"):
             dtype = torch.float32
-        self.model: UnitYModel = self.load_model_for_inference(
+        self.model = self.load_model_for_inference(
             load_unity_model, model_name_or_card, device, dtype
         )
+        assert isinstance(self.model, UnitYModel)
+
         self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
+
+        self.unit_tokenizer: Optional[UnitTokenizer] = None
         if self.model.t2u_model is not None:
             self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
-        else:
-            self.unit_tokenizer = None
+
         self.device = device
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
         self.convert_to_fbank = WaveformToFbankConverter(
@@ -83,9 +86,10 @@ class Translator(nn.Module):
             pad_value=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
         # Load the vocoder.
-        self.vocoder: Vocoder = self.load_model_for_inference(
+        self.vocoder = self.load_model_for_inference(
             load_vocoder_model, vocoder_name_or_card, device, torch.float32
         )
+        assert isinstance(self.vocoder, Vocoder)
 
     @staticmethod
     def load_model_for_inference(
@@ -223,11 +227,14 @@ class Translator(nn.Module):
                 raise ValueError("src_lang must be specified for T2ST, T2TT tasks.")
 
             text = input
+            assert isinstance(text, str)
+
             self.token_encoder = self.text_tokenizer.create_encoder(
                 task="translation", lang=src_lang, mode="source", device=self.device
             )
             src = self.collate(self.token_encoder(text))
 
+        assert isinstance(self.model, UnitYModel)
         result = self.get_prediction(
             self.model,
             self.text_tokenizer,
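
The translator.py changes replace explicit attribute annotations with `assert isinstance(...)` checks, the usual way to satisfy mypy when a loader is typed to return a broad base class: the assert narrows the static type and also validates the loaded object at runtime. A minimal sketch of the pattern, with hypothetical names:

```python
from typing import Optional

import torch.nn as nn


class UnitYModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.t2u_model: Optional[nn.Module] = None


def load_model_for_inference(name: str) -> nn.Module:
    # Hypothetical loader typed to return the broad base class.
    return UnitYModel()


model = load_model_for_inference("seamlessM4T_large")
# The assert narrows `model` from nn.Module to UnitYModel for mypy and
# fails fast at runtime if the loader returned something else.
assert isinstance(model, UnitYModel)
print(model.t2u_model)  # subclass-specific attribute now type-checks
```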

+ 15 - 4
src/seamless_communication/models/tokenizer.py

@@ -27,7 +27,12 @@ class SPMTokenizer(TextTokenizer):
     langs: Set[str]
     prepend_target_langtok_to_target: bool
 
-    def __init__(self, pathname: PathLike, langs: Sequence[str], prepend_target_langtok_to_target: bool = True) -> None:
+    def __init__(
+        self,
+        pathname: PathLike,
+        langs: Sequence[str],
+        prepend_target_langtok_to_target: bool = True,
+    ) -> None:
         """
         :param pathname:
             The pathname of the SentencePiece model file.
@@ -79,18 +84,24 @@ class SPMTokenizer(TextTokenizer):
         assert lang is not None
 
         if lang not in self.langs:
-            raise ValueError(f"`lang` must be a supported language, but is '{lang}' instead.")
+            raise ValueError(
+                f"`lang` must be a supported language, but is '{lang}' instead."
+            )
 
         if mode is None or mode == "source":
             prefix_tokens = []
             suffix_tokens = ["</s>"]
         elif mode == "target":
             prefix_tokens = (
-                ["</s>"] + [self._lang_tok_to_internal(lang)] if self.prepend_target_langtok_to_target else []
+                ["</s>"] + [self._lang_tok_to_internal(lang)]
+                if self.prepend_target_langtok_to_target
+                else []
             )
             suffix_tokens = ["</s>"]
         else:
-            raise ValueError(f"`mode` must be 'source' or 'target', but is '{mode}' instead.")
+            raise ValueError(
+                f"`mode` must be 'source' or 'target', but is '{mode}' instead."
+            )
 
         return SentencePieceEncoder(
             self.model,
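
A side benefit of black's reformatting in the `target` branch: it makes the operator precedence visible. The conditional expression binds looser than `+`, so the whole `["</s>"] + [lang_tok]` concatenation is the if-branch, and the prefix list is entirely empty (no `</s>`) when the flag is off. A quick check:

```python
lang_tok = "__eng__"

for flag in (True, False):
    # The ternary binds looser than "+": the whole concatenation is the
    # if-branch, so flag=False yields [] rather than ["</s>"].
    prefix_tokens = ["</s>"] + [lang_tok] if flag else []
    print(flag, prefix_tokens)
# True ['</s>', '__eng__']
# False []
```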

+ 1 - 1
src/seamless_communication/models/unity/generator.py

@@ -236,7 +236,7 @@ class UnitYGenerator:
             unit_seqs = unit_decoder_output.logits.argmax(dim=2)
             # Apply the padding mask to the generated units.
             unit_seqs = apply_padding_mask(
-                unit_seqs, decoder_padding_mask, fill_value=unit_decoder_output.pad_idx
+                unit_seqs, decoder_padding_mask, pad_value=unit_decoder_output.pad_idx
             )
 
         # Convert to speech units.
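
The generator.py change is the API bug from the commit title: the call passed the keyword `fill_value=` where the helper expects `pad_value=`, which would presumably fail (or mis-fill) at call time. A sketch of what such a masking helper does, under the assumption that positions marked `True` in the mask are padding and get overwritten with the pad index; this is an illustration, not the library's actual implementation:

```python
import torch


def apply_padding_mask_sketch(
    seqs: torch.Tensor, padding_mask: torch.Tensor, pad_value: int
) -> torch.Tensor:
    # Illustrative stand-in: positions where padding_mask is True are
    # treated as padding and overwritten with pad_value.
    return seqs.masked_fill(padding_mask, pad_value)


units = torch.tensor([[5, 9, 7], [3, 1, 2]])
mask = torch.tensor([[False, False, True], [False, True, True]])
print(apply_padding_mask_sketch(units, mask, pad_value=0))
# tensor([[5, 9, 0],
#         [3, 0, 0]])
```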