Browse Source

Skips loading the vocoder for X2T eval. (#84)

* Fix mypy issues in cli/eval_utils.

* Skip loading vocoder for X2T evaluation.
Kaushik Ram Sadagopan 1 year ago
parent
commit
bfcdf3ba4f

+ 60 - 54
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -6,16 +6,21 @@
 
 import logging
 from pathlib import Path
-from typing import Optional
+from typing import Tuple, Union
 
 import pandas as pd
-import sacrebleu
 import whisper
+
+from fairseq2.typing import Device
 from jiwer import cer, wer
+from sacrebleu.metrics.base import Score, Signature
+from sacrebleu.metrics.bleu import BLEU
+from sacrebleu.metrics.chrf import CHRF
+from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
+from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
-from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 
 logging.basicConfig(
     level=logging.INFO,
@@ -26,19 +31,19 @@ logger = logging.getLogger(__name__)
 
 
 def init_whisper_model(
-    device: str,
+    device: Device,
     whisper_model_name: str = "large",
-):
+) -> Whisper:
     return whisper.load_model(name=whisper_model_name, device=device)
 
 
 def transcribe_series(
     audio_paths_series: pd.Series,
-    asr_model,
+    asr_model: Whisper,
     audio_lang: str,
     beam_size: int = 1,
     temperature: float = 0.0,
-):
+) -> pd.Series:
     """Transcribes each audio filepath from series and returns series of transcriptions
     Args:
         audio_paths_series (pd.Series): each line contains path to audio file.
@@ -84,7 +89,9 @@ def transcribe_series(
     return transcriptions_series
 
 
-def whisper_normalize_series(transcription_series: pd.Series, text_lang: str):
+def whisper_normalize_series(
+    transcription_series: pd.Series, text_lang: str
+) -> pd.Series:
     """Normalizes the text series using whisper noramlizer. English has a specific one in whisper package.
     Args:
         transcription_series (pd.Series): Each line contains arbitrary text written in text_lang
@@ -112,12 +119,12 @@ def compute_asr_bleu(
     audio_paths_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    asr_model,
-    whisper_normalize_text: Optional[bool] = True,
-    beam_size: Optional[int] = 1,
-    temperature: Optional[float] = 0.0,
-    return_transcriptions: Optional[bool] = True,
-):
+    asr_model: Whisper,
+    whisper_normalize_text: bool = True,
+    beam_size: int = 1,
+    temperature: float = 0.0,
+    return_transcriptions: bool = True,
+) -> Tuple[Score, Signature, pd.DataFrame]:
     """Wraps functions above to compute corpus-level ASR-BLEU
     ASR decoding hyper-parameters are hard coded to ensure reproducibility across evaluations
     Args:
@@ -125,10 +132,10 @@ def compute_asr_bleu(
         ref_text_series (pd.Series): each line contains the text reference to compare audio with
         lang (str): the language of both audio and ref_text
         asr_model: whisper ASR model
-        whisper_normalize_text (bool, Optional): normalize both text hypotheses and reference if True. Defaults to True.
-        beam_size (int, Optional): beam_size for whisper generation
-        temperature (float, Optional): Temperature sampling value for whisper generation
-        return_transcriptions (bool, Optional)
+        whisper_normalize_text (bool): normalize both text hypotheses and reference if True. Defaults to True.
+        beam_size (int): beam_size for whisper generation
+        temperature (float): Temperature sampling value for whisper generation
+        return_transcriptions (bool)
     """
 
     audio_transcriptions = transcribe_series(
@@ -159,11 +166,11 @@ def compute_asr_bleu(
     return asr_bleu, asr_bleu_signature, transcript_df
 
 
-def get_tokenizer(lang: str, metric: Optional[str] = "bleu"):
+def get_tokenizer(lang: str, metric: str = "bleu") -> str:
     """Get tokenizer for language
     Args:
         lang (str): Three letter code of the language
-        metric (str, Optional): Metric being computed. Valid values are "bleu" and "asr"
+        metric (str): Metric being computed. Valid values are "bleu" and "asr"
     """
     lang_tok_map = {
         "cmn": "char",
@@ -183,8 +190,8 @@ def compute_asr_error_rate(
     hyp_text_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    whisper_normalize_text=True,
-):
+    whisper_normalize_text: bool = True,
+) -> Tuple[Score, str]:
     """Wraps normalization functions and computes ASR WER/CER score
     Args:
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
@@ -208,9 +215,9 @@ def compute_corpus_metric_score(
     hyp_text_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
-    whisper_normalize_text=True,
-    metric: Optional[str] = "bleu",
-):
+    whisper_normalize_text: bool = True,
+    metric: str = "bleu",
+) -> Tuple[Score, Signature]:
     """Wraps normalization functions and compute corpus-level BLEU/chrF++ score
     Args:
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
@@ -225,12 +232,13 @@ def compute_corpus_metric_score(
         ref_text_series = whisper_normalize_series(ref_text_series, lang)
 
     tokenizer_name = get_tokenizer(lang)
+    corpus_metric_score_metric: Union[BLEU, CHRF]
     if metric == "bleu":
-        corpus_metric_score_metric = sacrebleu.metrics.bleu.BLEU(
+        corpus_metric_score_metric = BLEU(
             lowercase=whisper_normalize_text, tokenize=tokenizer_name
         )  # lowercase applied if we use whisper_normalize_text
     elif metric == "chrF++":
-        corpus_metric_score_metric = sacrebleu.CHRF(word_order=2)
+        corpus_metric_score_metric = CHRF(word_order=2)
 
     corpus_metric_score = corpus_metric_score_metric.corpus_score(
         hyp_text_series.to_list(), [ref_text_series.to_list()]
@@ -242,29 +250,29 @@ def compute_corpus_metric_score(
 
 
 def compute_quality_metrics(
-    output_manifest_tsv_path: str,
-    output_dir: str,
+    output_manifest_tsv_path: Path,
+    output_path: Path,
     tgt_lang: str,
     task: str,
-    device: str,
-    whisper_model_name: Optional[str] = "large",
-    whisper_normalize_text_output: Optional[bool] = False,
-    ref_text_col_name: Optional[str] = "ref_tgt_text",
-    pred_text_col_name: Optional[str] = "pred_tgt_text",
-    pred_audio_col_name: Optional[str] = "pred_tgt_audio",
-):
+    device: Device,
+    whisper_model_name: str = "large",
+    whisper_normalize_text_output: bool = False,
+    ref_text_col_name: str = "ref_tgt_text",
+    pred_text_col_name: str = "pred_tgt_text",
+    pred_audio_col_name: str = "pred_tgt_audio",
+) -> None:
     """Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side
     Args:
-        output_manifest_tsv_path (str): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
-        output_dir (str): Directory to write files with metrics
+        output_manifest_tsv_path (Path): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
+        output_path (Path): Directory to write files with metrics
         tgt_lang (str): what language we evaluate on
        task (str): Task we are currently evaluating for
-        device (str): Device to use for inference
-        whisper_model_name (str, Optional): Whisper model name. Defaults to "large".
-        whisper_normalize_text_output (bool, Optional): Normalizes text output using whisper_normalizer if set to true
-        ref_text_col_name (str, Optional): Column name in the tsv corresponding to reference target text
-        pred_text_col_name (str, Optional): Column name in the tsv corresponding to predicted target text
-        pred_audio_col_name (str, Optional): Column name in the tsv corresponding to predicted target audio.
+        device (Device): Device to use for inference
+        whisper_model_name (str): Whisper model name. Defaults to "large".
+        whisper_normalize_text_output (bool): Normalizes text output using whisper_normalizer if set to true
+        ref_text_col_name (str): Column name in the tsv corresponding to reference target text
+        pred_text_col_name (str): Column name in the tsv corresponding to predicted target text
+        pred_audio_col_name (str): Column name in the tsv corresponding to predicted target audio.
             Setting this value to none will skip speech metrics
     """
     df = pd.read_csv(
@@ -272,8 +280,8 @@ def compute_quality_metrics(
     )
     task = task.upper()
 
-    if not Path(output_dir).exists():
-        Path(output_dir).mkdir(parents=True, exist_ok=True)
+    if not output_path.exists():
+        output_path.mkdir(parents=True, exist_ok=True)
 
     if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
         metric = "chrF++" if task == "T2TT" else "bleu"
@@ -298,7 +306,7 @@ def compute_quality_metrics(
                 else "s2tt_bleu.json"
             )
             cur_task = "S2TT"
-        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+        with open((output_path / filename).as_posix(), "w") as f:
             f.write(text_metric_json)
 
         logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
@@ -317,7 +325,7 @@ def compute_quality_metrics(
             whisper_normalize_text=True,
         )
         transcripts_df.to_csv(
-            (Path(output_dir) / "whisper_audio_transcriptions.tsv"),
+            (output_path / "whisper_audio_transcriptions.tsv"),
             sep="\t",
             index=False,
             encoding="utf-8",
@@ -331,27 +339,25 @@ def compute_quality_metrics(
         )
 
         with open(
-            (Path(output_dir) / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
+            (output_path / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
             "w",
         ) as f:
             f.write(asr_bleu_normalized_json)
 
-        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+        with open((output_path / filename).as_posix(), "w") as f:
             f.write(text_metric_json)
 
         logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
 
     if task == "ASR":
-        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
+        _, asr_error_rate_signature = compute_asr_error_rate(
             hyp_text_series=df[pred_text_col_name],
             ref_text_series=df[ref_text_col_name],
             lang=tgt_lang,
             whisper_normalize_text=whisper_normalize_text_output,
         )
 
-        with open((Path(output_dir) / "asr_error_rate.json").as_posix(), "w") as f:
+        with open((output_path / "asr_error_rate.json").as_posix(), "w") as f:
             f.write(asr_error_rate_signature)
 
         logger.info(f"ASR : {asr_error_rate_signature}")
-
-    return
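
A minimal caller sketch for the updated compute_quality_metrics signature, assuming fairseq2's Device constructs like torch.device; the manifest path, output directory, and language code below are illustrative, not from this commit.

# Hypothetical caller of the new signature; file names are illustrative.
from pathlib import Path

from fairseq2.typing import Device

from seamless_communication.cli.eval_utils.compute_metrics import (
    compute_quality_metrics,
)

compute_quality_metrics(
    output_manifest_tsv_path=Path("generate-s2tt.tsv"),  # manifest TSV (illustrative)
    output_path=Path("eval_results"),  # renamed from output_dir; created if missing
    tgt_lang="eng",
    task="S2TT",
    device=Device("cpu"),  # a Device object now, not a plain str
)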

+ 8 - 13
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -12,7 +12,7 @@ import subprocess
 from argparse import Namespace
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 
 import torch
 import torchaudio
@@ -28,7 +28,7 @@ from tqdm import tqdm
 from seamless_communication.cli.eval_utils.compute_metrics import (
     compute_quality_metrics,
 )
-from seamless_communication.cli.predict import (
+from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
@@ -222,7 +222,7 @@ def run_eval(
     translator: Translator,
     text_tokenizer: TextTokenizer,
     ctx: EvalContext,
-    whisper_model_name: Optional[str] = None,
+    whisper_model_name: str,
 ) -> None:
     pipeline = build_data_pipeline(ctx, text_tokenizer)
 
@@ -267,10 +267,7 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (
-                    text_output,
-                    speech_output,
-                ) = translator.predict(
+                (text_output, speech_output,) = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -287,10 +284,7 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (
-                    text_output,
-                    speech_output,
-                ) = adjust_output_for_corrupted_inputs(
+                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,
@@ -323,7 +317,7 @@ def run_eval(
 
     compute_quality_metrics(
         output_manifest_tsv_path=model_outputs_tsv,
-        output_dir=output_path,
+        output_path=output_path,
         tgt_lang=ctx.target_lang,
         task=ctx.task,
         device=ctx.device,
@@ -331,7 +325,7 @@ def run_eval(
     )
 
 
-def main(optional_args: Optional[Dict] = None):
+def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
     parser = argparse.ArgumentParser(
         description="M4T evaluation for tasks supported by Translator."
     )
@@ -398,6 +392,7 @@ def main(optional_args: Optional[Dict] = None):
         device,
         text_tokenizer=text_tokenizer,
         dtype=dtype,
+        output_modality=output_modality,
     )
 
     text_generation_opts, unit_generation_opts = set_generation_opts(args)
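
The behavioral core of the commit is threading output_modality into the Translator constructor so X2T (text-output) evaluation never loads the vocoder. A sketch of the idea follows; output_modality_for is a hypothetical helper (the real evaluate.py derives the modality from args.task), and the card names are illustrative.

import torch

from fairseq2.typing import Device

from seamless_communication.inference.translator import Modality, Translator

def output_modality_for(task: str) -> Modality:
    # S2ST/T2ST produce speech; S2TT, T2TT and ASR produce text only.
    return Modality.SPEECH if task.upper() in ("S2ST", "T2ST") else Modality.TEXT

translator = Translator(
    "seamlessM4T_large",  # model card, illustrative
    "vocoder_36langs",    # vocoder card; ignored when the output modality is TEXT
    Device("cpu"),
    dtype=torch.float32,
    output_modality=output_modality_for("S2TT"),  # text task: vocoder load skipped
)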

+ 1 - 1
src/seamless_communication/cli/m4t/finetune/trainer.py

@@ -22,7 +22,7 @@ from fairseq2.optim.lr_scheduler import MyleLR
 from fairseq2.typing import Device
 from torch.optim import Adam
 
-from seamless_communication.cli.finetune import dataloader, dist_utils
+from seamless_communication.cli.m4t.finetune import dataloader, dist_utils
 from seamless_communication.models.unity import UnitYModel
 
 logger = logging.getLogger(__name__)

+ 11 - 8
src/seamless_communication/inference/translator.py

@@ -7,7 +7,7 @@ import logging
 from dataclasses import dataclass
 from enum import Enum, auto
 from pathlib import Path
-from typing import Callable, List, Optional, Tuple, Union, cast
+from typing import Any, Dict, Callable, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
@@ -35,7 +35,7 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.models.vocoder import Vocoder, load_vocoder_model
+from seamless_communication.models.vocoder import load_vocoder_model
 
 logging.basicConfig(
     level=logging.INFO,
@@ -78,6 +78,7 @@ class Translator(nn.Module):
         device: Device,
         text_tokenizer: Optional[TextTokenizer] = None,
         dtype: DataType = torch.float16,
+        output_modality: Optional[Modality] = None,
     ):
         super().__init__()
         # Load the model.
@@ -112,11 +113,12 @@ class Translator(nn.Module):
         self.collate = Collater(
             pad_value=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
-        # Load the vocoder.
-        self.vocoder = self.load_model_for_inference(
-            load_vocoder_model, vocoder_name_or_card, device, torch.float32
-        )
-        assert isinstance(self.vocoder, Vocoder)
+        self.vocoder = None
+        if output_modality is None or output_modality == Modality.SPEECH:
+            # Load the vocoder.
+            self.vocoder = self.load_model_for_inference(
+                load_vocoder_model, vocoder_name_or_card, device, torch.float32
+            )
 
     @staticmethod
     def load_model_for_inference(
@@ -186,7 +188,7 @@ class Translator(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor, dict],
+        input: Union[str, Tensor, Dict[str, Any]],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
@@ -286,6 +288,7 @@ class Translator(nn.Module):
             return text_output.sentences, None
         else:
             assert unit_output is not None
+            assert self.vocoder is not None
 
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
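
End to end: a Translator built with output_modality=Modality.TEXT keeps self.vocoder as None, text tasks behave as before, and the new assert guards the speech branch. A usage sketch, reusing the translator from the previous note (input text and language codes are illustrative):

text_output, speech_output = translator.predict(
    "Bonjour le monde.",  # illustrative input
    "t2tt",
    tgt_lang="eng",
    src_lang="fra",
)
assert speech_output is None  # text-only task: no unit generation or vocoder pass
print(text_output[0])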