Skips loading the vocoder for X2T eval. (#84)

* Fix mypy issues in cli/eval_utils.

* Skip loading vocoder for X2T evaluation.
Kaushik Ram Sadagopan 1 year ago
Parent
Current commit
bfcdf3ba4f

+ 60 - 54
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -6,16 +6,21 @@
 
 import logging
 from pathlib import Path
-from typing import Optional
+from typing import Tuple, Union
 
 import pandas as pd
-import sacrebleu
 import whisper
+
+from fairseq2.typing import Device
 from jiwer import cer, wer
+from sacrebleu.metrics.base import Score, Signature
+from sacrebleu.metrics.bleu import BLEU
+from sacrebleu.metrics.chrf import CHRF
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
+from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
-from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 logging.basicConfig(
     level=logging.INFO,
@@ -26,19 +31,19 @@ logger = logging.getLogger(__name__)
 
 
 def init_whisper_model(
-    device: str,
+    device: Device,
     whisper_model_name: str = "large",
-):
+) -> Whisper:
     return whisper.load_model(name=whisper_model_name, device=device)
 
 
 def transcribe_series(
     audio_paths_series: pd.Series,
-    asr_model,
+    asr_model: Whisper,
     audio_lang: str,
     beam_size: int = 1,
     temperature: float = 0.0,
-):
+) -> pd.Series:
     """Transcribes each audio filepath from series and returns series of transcriptions
     Args:
         audio_paths_series (pd.Series): each line contains path to audio file.
@@ -84,7 +89,9 @@ def transcribe_series(
     return transcriptions_series
 
 
-def whisper_normalize_series(transcription_series: pd.Series, text_lang: str):
+def whisper_normalize_series(
+    transcription_series: pd.Series, text_lang: str
+) -> pd.Series:
     """Normalizes the text series using whisper noramlizer. English has a specific one in whisper package.
     Args:
         transcription_series (pd.Series): Each line contains arbitrary text written in text_lang
@@ -112,12 +119,12 @@ def compute_asr_bleu(
     audio_paths_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    asr_model,
-    whisper_normalize_text: Optional[bool] = True,
-    beam_size: Optional[int] = 1,
-    temperature: Optional[float] = 0.0,
-    return_transcriptions: Optional[bool] = True,
-):
+    asr_model: Whisper,
+    whisper_normalize_text: bool = True,
+    beam_size: int = 1,
+    temperature: float = 0.0,
+    return_transcriptions: bool = True,
+) -> Tuple[Score, Signature, pd.DataFrame]:
     """Wraps functions above to compute corpus-level ASR-BLEU
     ASR decoding hyper-parameters are hard-coded to ensure reproducibility across evaluations
     Args:
@@ -125,10 +132,10 @@ def compute_asr_bleu(
         ref_text_series (pd.Series): each line contains the text reference to compare audio with
         lang (str): the language of both audio and ref_text
         asr_model: whisper ASR model
-        whisper_normalize_text (bool, Optional): normalize both text hypotheses and reference if True. Defaults to True.
-        beam_size (int, Optional): beam_size for whisper generation
-        temperature (float, Optional): Temperature sampling value for whisper generation
-        return_transcriptions (bool, Optional)
+        whisper_normalize_text (bool): normalize both text hypotheses and reference if True. Defaults to True.
+        beam_size (int): beam_size for whisper generation
+        temperature (float): Temperature sampling value for whisper generation
+        return_transcriptions (bool): whether to return the DataFrame of Whisper transcriptions
     """
 
     audio_transcriptions = transcribe_series(
@@ -159,11 +166,11 @@ def compute_asr_bleu(
     return asr_bleu, asr_bleu_signature, transcript_df
 
 
-def get_tokenizer(lang: str, metric: Optional[str] = "bleu"):
+def get_tokenizer(lang: str, metric: str = "bleu") -> str:
     """Get tokenizer for language
     Args:
         lang (str): Three letter code of the language
-        metric (str, Optional): Metric being computed. Valid values are "bleu" and "asr"
+        metric (str): Metric being computed. Valid values are "bleu" and "asr"
     """
     lang_tok_map = {
         "cmn": "char",
@@ -183,8 +190,8 @@ def compute_asr_error_rate(
     hyp_text_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    whisper_normalize_text=True,
-):
+    whisper_normalize_text: bool = True,
+) -> Tuple[Score, str]:
     """Wraps normalization functions and computes ASR WER/CER score
     Args:
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
@@ -208,9 +215,9 @@ def compute_corpus_metric_score(
     hyp_text_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    whisper_normalize_text=True,
-    metric: Optional[str] = "bleu",
-):
+    whisper_normalize_text: bool = True,
+    metric: str = "bleu",
+) -> Tuple[Score, Signature]:
     """Wraps normalization functions and compute corpus-level BLEU/chrF++ score
     Args:
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
@@ -225,12 +232,13 @@ def compute_corpus_metric_score(
         ref_text_series = whisper_normalize_series(ref_text_series, lang)
 
     tokenizer_name = get_tokenizer(lang)
+    corpus_metric_score_metric: Union[BLEU, CHRF]
     if metric == "bleu":
-        corpus_metric_score_metric = sacrebleu.metrics.bleu.BLEU(
+        corpus_metric_score_metric = BLEU(
             lowercase=whisper_normalize_text, tokenize=tokenizer_name
         )  # lowercase applied if we use whisper_normalize_text
     elif metric == "chrF++":
-        corpus_metric_score_metric = sacrebleu.CHRF(word_order=2)
+        corpus_metric_score_metric = CHRF(word_order=2)
 
     corpus_metric_score = corpus_metric_score_metric.corpus_score(
         hyp_text_series.to_list(), [ref_text_series.to_list()]
@@ -242,29 +250,29 @@ def compute_corpus_metric_score(
 
 
 def compute_quality_metrics(
-    output_manifest_tsv_path: str,
-    output_dir: str,
+    output_manifest_tsv_path: Path,
+    output_path: Path,
     tgt_lang: str,
     task: str,
-    device: str,
-    whisper_model_name: Optional[str] = "large",
-    whisper_normalize_text_output: Optional[bool] = False,
-    ref_text_col_name: Optional[str] = "ref_tgt_text",
-    pred_text_col_name: Optional[str] = "pred_tgt_text",
-    pred_audio_col_name: Optional[str] = "pred_tgt_audio",
-):
+    device: Device,
+    whisper_model_name: str = "large",
+    whisper_normalize_text_output: bool = False,
+    ref_text_col_name: str = "ref_tgt_text",
+    pred_text_col_name: str = "pred_tgt_text",
+    pred_audio_col_name: str = "pred_tgt_audio",
+) -> None:
     """Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side
     Args:
-        output_manifest_tsv_path (str): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
-        output_dir (str): Directory to write files with metrics
+        output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
+        output_path (Path): Directory to write files with metrics
         tgt_lang (str): what language we evaluate on
         task (str): Task we are currently evaluating for
-        device (str): Device to use for inference
-        whisper_model_name (str, Optional): Whisper model name. Defaults to "large".
-        whisper_normalize_text_output (bool, Optional): Normalizes text output using whisper_normalizer if set to true
-        ref_text_col_name (str, Optional): Column name in the tsv corresponding to reference target text
-        pred_text_col_name (str, Optional): Column name in the tsv corresponding to predicted target text
-        pred_audio_col_name (str, Optional): Column name in the tsv corresponding to predicted target audio.
+        device (Device): Device to use for inference
+        whisper_model_name (str): Whisper model name. Defaults to "large".
+        whisper_normalize_text_output (bool): Normalizes text output using whisper_normalizer if set to true
+        ref_text_col_name (str): Column name in the tsv corresponding to reference target text
+        pred_text_col_name (str): Column name in the tsv corresponding to predicted target text
+        pred_audio_col_name (str): Column name in the tsv corresponding to predicted target audio.
             Setting this value to None will skip speech metrics
     """
     df = pd.read_csv(
@@ -272,8 +280,8 @@ def compute_quality_metrics(
     )
     task = task.upper()
 
-    if not Path(output_dir).exists():
-        Path(output_dir).mkdir(parents=True, exist_ok=True)
+    if not output_path.exists():
+        output_path.mkdir(parents=True, exist_ok=True)
 
     if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
         metric = "chrF++" if task == "T2TT" else "bleu"
@@ -298,7 +306,7 @@ def compute_quality_metrics(
                 else "s2tt_bleu.json"
             )
             cur_task = "S2TT"
-        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+        with open((output_path / filename).as_posix(), "w") as f:
             f.write(text_metric_json)
 
         logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
@@ -317,7 +325,7 @@ def compute_quality_metrics(
             whisper_normalize_text=True,
         )
         transcripts_df.to_csv(
-            (Path(output_dir) / "whisper_audio_transcriptions.tsv"),
+            (output_path / "whisper_audio_transcriptions.tsv"),
             sep="\t",
             index=False,
             encoding="utf-8",
@@ -331,27 +339,25 @@ def compute_quality_metrics(
         )
 
         with open(
-            (Path(output_dir) / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
+            (output_path / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
             "w",
         ) as f:
             f.write(asr_bleu_normalized_json)
 
-        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+        with open((output_path / filename).as_posix(), "w") as f:
             f.write(text_metric_json)
 
         logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
 
     if task == "ASR":
-        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
+        _, asr_error_rate_signature = compute_asr_error_rate(
             hyp_text_series=df[pred_text_col_name],
             ref_text_series=df[ref_text_col_name],
             lang=tgt_lang,
             whisper_normalize_text=whisper_normalize_text_output,
         )
 
-        with open((Path(output_dir) / "asr_error_rate.json").as_posix(), "w") as f:
+        with open((output_path / "asr_error_rate.json").as_posix(), "w") as f:
             f.write(asr_error_rate_signature)
 
         logger.info(f"ASR : {asr_error_rate_signature}")
-
-    return
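
Since compute_quality_metrics now takes Path and Device objects instead of plain strings, callers are expected to construct those types themselves. Below is a minimal usage sketch under the new signature; the manifest path, output directory, and language code are hypothetical values for illustration only:

    from pathlib import Path

    import torch
    from fairseq2.typing import Device

    from seamless_communication.cli.eval_utils.compute_metrics import (
        compute_quality_metrics,
    )

    # Hypothetical manifest produced by an earlier generation run.
    compute_quality_metrics(
        output_manifest_tsv_path=Path("generate-eval.tsv"),
        output_path=Path("eval_metrics"),
        tgt_lang="fra",
        task="S2TT",
        device=Device("cuda:0" if torch.cuda.is_available() else "cpu"),
        whisper_model_name="large",
    )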

+ 8 - 13
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -12,7 +12,7 @@ import subprocess
 from argparse import Namespace
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torchaudio
@@ -28,7 +28,7 @@ from tqdm import tqdm
 from seamless_communication.cli.eval_utils.compute_metrics import (
     compute_quality_metrics,
 )
-from seamless_communication.cli.predict import (
+from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
@@ -222,7 +222,7 @@ def run_eval(
     translator: Translator,
     text_tokenizer: TextTokenizer,
     ctx: EvalContext,
-    whisper_model_name: Optional[str] = None,
+    whisper_model_name: str,
 ) -> None:
     pipeline = build_data_pipeline(ctx, text_tokenizer)
 
@@ -267,10 +267,7 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (
-                    text_output,
-                    speech_output,
-                ) = translator.predict(
+                (text_output, speech_output,) = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -287,10 +284,7 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (
-                    text_output,
-                    speech_output,
-                ) = adjust_output_for_corrupted_inputs(
+                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,
@@ -323,7 +317,7 @@ def run_eval(
 
     compute_quality_metrics(
         output_manifest_tsv_path=model_outputs_tsv,
-        output_dir=output_path,
+        output_path=output_path,
         tgt_lang=ctx.target_lang,
         task=ctx.task,
         device=ctx.device,
@@ -331,7 +325,7 @@ def run_eval(
     )
 
 
-def main(optional_args: Optional[Dict] = None):
+def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
     parser = argparse.ArgumentParser(
         description="M4T evaluation for tasks supported by Translator."
     )
@@ -398,6 +392,7 @@ def main(optional_args: Optional[Dict] = None):
         device,
         text_tokenizer=text_tokenizer,
         dtype=dtype,
+        output_modality=output_modality,
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
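
The new output_modality=output_modality argument passed to Translator here (see the translator.py diff below) is what actually skips the vocoder for X2T tasks. A hedged sketch of the caller-side wiring; get_modalities_from_task_str is assumed from the repo's conventions and may differ, and the card names are illustrative:

    import torch

    from seamless_communication.inference.translator import Modality, Translator

    task = "S2TT"  # an X2T task: speech in, text out
    # Assumed helper mapping a task string to (input, output) modalities.
    input_modality, output_modality = Translator.get_modalities_from_task_str(task)

    translator = Translator(
        "seamlessM4T_large",   # illustrative model card
        "vocoder_36langs",     # vocoder card; not loaded when output is text
        device=torch.device("cpu"),
        dtype=torch.float32,
        output_modality=output_modality,  # Modality.TEXT here, so no vocoder
    )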

+ 1 - 1
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -22,7 +22,7 @@ from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
 from torch.optim import Adam
 
-from seamless_communication.cli.finetune import dataloader, dist_utils
+from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
 from seamless_communication.models.unity import UnitYModel
 
 logger = logging.getLogger(__name__)

+ 11 - 8
src/seamless_communication/inference/translator.py

@@ -7,7 +7,7 @@ import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -35,7 +35,7 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.models.vocoder import Vocoder, load_vocoder_model
+from seamless_communication.models.vocoder import load_vocoder_model
 
 logging.basicConfig(
     level=logging.INFO,
@@ -78,6 +78,7 @@ class Translator(nn.Module):
         device: Device,
         text_tokenizer: Optional[TextTokenizer] = None,
         dtype: DataType = torch.float16,
+        output_modality: Optional[Modality] = None,
     ):
         super().__init__()
         # Load the model.
@@ -112,11 +113,12 @@ class Translator(nn.Module):
         self.collate = Collater(
             pad_value=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
-        # Load the vocoder.
-        self.vocoder = self.load_model_for_inference(
-            load_vocoder_model, vocoder_name_or_card, device, torch.float32
-        )
-        assert isinstance(self.vocoder, Vocoder)
+        self.vocoder = None
+        if output_modality is None or output_modality == Modality.SPEECH:
+            # Load the vocoder.
+            self.vocoder = self.load_model_for_inference(
+                load_vocoder_model, vocoder_name_or_card, device, torch.float32
+            )
 
     @staticmethod
     def load_model_for_inference(
@@ -186,7 +188,7 @@ class Translator(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor, dict],
+        input: Union[str, Tensor, Dict[str, Any]],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
@@ -286,6 +288,7 @@ class Translator(nn.Module):
             return text_output.sentences, None
         else:
             assert unit_output is not None
+            assert self.vocoder is not None
 
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
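
After this change, self.vocoder may be None, which is why the speech branch of predict now guards with an assert. A small sketch of the resulting behavior for a text-only task (the model and vocoder card names are illustrative, not prescribed by the diff):

    import torch

    from seamless_communication.inference.translator import Modality, Translator

    translator = Translator(
        "seamlessM4T_medium",           # illustrative model card
        "vocoder_36langs",              # card still accepted, but never loaded
        device=torch.device("cpu"),
        dtype=torch.float32,
        output_modality=Modality.TEXT,  # X2T evaluation: skip the vocoder
    )
    assert translator.vocoder is None   # no vocoder weights were fetched

    # Text generation is unaffected; requesting speech output would trip the
    # new "assert self.vocoder is not None" guard instead of failing silently.
    sentences, speech = translator.predict(
        "Bonjour le monde!", "T2TT", tgt_lang="eng", src_lang="fra"
    )
    assert speech is None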