
Add evaluation pipeline to compute all metrics (#79)

spopuri 1 year ago
Commit
a1646750f1

+ 356 - 0
scripts/eval_utils/compute_metrics.py

@@ -0,0 +1,356 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+import logging
+import pandas as pd
+import sacrebleu
+import whisper
+from jiwer import wer, cer
+from tqdm import tqdm
+from typing import Optional
+from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
+
+from scripts.eval_utils.lang_mapping import LANG3_LANG2
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def init_whisper_model(
+    device: str,
+    whisper_model_name: str = "large",
+):
+    return whisper.load_model(name=whisper_model_name, device=device)
+
+
+def transcribe_series(
+    audio_paths_series: pd.Series,
+    asr_model,
+    audio_lang: str,
+    beam_size: int = 1,
+    temperature: float = 0.0,
+):
+    """Transcribes each audio filepath from series and returns series of transcriptions
+    Args:
+        audio_paths_series (pd.Series): each line contains path to audio file.
+        asr_model: ASR model to do the transcribing process e.g. Whisper
+        audio_lang (str): what language is used in the given audio, used by ASR model
+        beam_size (int): whisper beam size. Defaults to 1
+        temperature (float): whisper temperature. Defaults to 0.0 to avoid fallback decoding (see details below).
+    Returns:
+        pd.Series: Series where each line has a transcription of corresponding audio from audio_paths_series
+    Whisper model implements decoding with fallback: https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L147
+    The core idea is that decoding at each time step might happen multiple times if at least one criterion to "fall back" i.e.
+    start over is fired. Number of fallback iterations is determined by the schedule of temperature values:
+    https://github.com/openai/whisper/blob/main/whisper/transcribe.py#L41
+    By default this schedule is active and temperature=(0.0, 0.2, 0.4, 0.6, 0.8, 1.0) i.e. even with beam_size 5 it might fell back and
+    turn on sampling by using temperature > 0, in this case the beam search is not used in the fall back iteration.
+    Explicit setting of temperature=0.0 overwrites the schedule and fall back decoding has only one for loop iteration i.e. no fall backs.
+    This allows us to do reproducible evaluation without sample variations. Beware that this might introduce the repetition loops in
+    the transcriptions and lead to worse ASR-BLEU score in the end.
+    """
+
+    if len(audio_lang) == 3:
+        # to make it work with whisper
+        audio_lang = LANG3_LANG2[audio_lang]
+
+    transcriptions = {}
+
+    for idx, audio_path in tqdm(
+        audio_paths_series.items(),
+        desc=f"Transcribing {audio_paths_series.name} column",
+        total=len(audio_paths_series),
+    ):
+        hypo = asr_model.transcribe(
+            audio_path,
+            temperature=temperature,
+            beam_size=beam_size,
+            language=audio_lang,
+        )["text"].strip()
+        transcriptions[idx] = hypo
+
+    transcriptions_series = pd.Series(transcriptions)
+    transcriptions_series.name = f"{audio_paths_series.name}_transcribed"
+
+    return transcriptions_series
+
+
+def whisper_normalize_series(transcription_series: pd.Series, text_lang: str):
+    """Normalizes the text series using whisper noramlizer. English has a specific one in whisper package.
+    Args:
+        transcription_series (pd.Series): Each line contains arbitrary text written in text_lang
+        text_lang (str): Language of the text in series
+    Returns:
+        pd.Series: Series with normalized text
+    """
+    if text_lang == "eng":
+        normalizer = EnglishTextNormalizer()
+    else:
+        normalizer = BasicTextNormalizer()
+
+    norm_transcriptions = {}
+
+    for idx, text in transcription_series.items():
+        norm_transcriptions[idx] = normalizer(text)
+
+    norm_transcriptions_series = pd.Series(norm_transcriptions)
+    norm_transcriptions_series.name = transcription_series.name
+
+    return norm_transcriptions_series
+
+
+def compute_asr_bleu(
+    audio_paths_series: pd.Series,
+    ref_text_series: pd.Series,
+    lang: str,
+    asr_model,
+    whisper_normalize_text: Optional[bool] = True,
+    beam_size: Optional[int] = 1,
+    temperature: Optional[float] = 0.0,
+    return_transcriptions: Optional[bool] = True,
+):
+    """Wraps functions above to compute corpus-level ASR-BLEU
+    ASR decoding hyper-parameters are hard coded to ensure reproducibility across evaluations
+    Args:
+        audio_paths_series (pd.Series): each line contains path to audio
+        ref_text_series (pd.Series): each line contains the text reference to compare audio with
+        lang (str): the language of both audio and ref_text
+        asr_model: whisper ASR model
+        whisper_normalize_text (bool, Optional): normalize both text hypotheses and reference if True. Defaults to True.
+        beam_size (int, Optional): beam_size for whisper generation
+        temperature (float, Optional): Temperature sampling value for whisper generation
+        return_transcriptions (bool, Optional)
+    """
+
+    audio_transcriptions = transcribe_series(
+        audio_paths_series,
+        asr_model,
+        audio_lang=lang,
+        beam_size=beam_size,
+        temperature=temperature,
+    )
+    asr_bleu, asr_bleu_signature = compute_corpus_metric_score(
+        audio_transcriptions, ref_text_series, lang, whisper_normalize_text
+    )
+    asr_bleu_signature.info["whisper_asr_beam_size"] = beam_size
+    asr_bleu_signature.info["whisper_asr_temperature"] = temperature
+    asr_bleu_signature.info["whisper_asr_language"] = lang
+
+    transcript_df = None
+    if return_transcriptions:
+        transcript_df = pd.concat(
+            [
+                audio_paths_series,
+                audio_transcriptions,
+                ref_text_series,
+            ],
+            axis=1,
+            keys=["audio", "transcript", "reference"],
+        )
+    return asr_bleu, asr_bleu_signature, transcript_df
+
+
+def get_tokenizer(lang: str, metric: Optional[str] = "bleu"):
+    """Get tokenizer for language
+    Args:
+        lang (str): Three letter code of the language
+        metric (str, Optional): Metric being computed. Valid values are "bleu" and "error_rate"
+    """
+    lang_tok_map = {
+        "cmn": "char",
+        "jpn": "char",
+        "tha": "char",
+        "lao": "char",
+        "mya": "char",
+    }
+    default = (
+        "13a" if metric == "bleu" else "word"
+    )  # 13a is the sacrebleu default tokenizer for BLEU; word-level tokenization is the default for the ASR error rate
+    tok = lang_tok_map.get(lang, default)
+    return tok
+
+
+def compute_asr_error_rate(
+    hyp_text_series: pd.Series,
+    ref_text_series: pd.Series,
+    lang: str,
+    whisper_normalize_text=True,
+):
+    """Wraps normalization functions and computes ASR WER/CER score
+    Args:
+        hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
+        ref_text_series (pd.Series): _description_
+        lang (str): _description_
+        whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.
+    Returns:
+        (MetricScore, MetricScoreSignature)
+    """
+    if whisper_normalize_text:
+        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
+        ref_text_series = whisper_normalize_series(ref_text_series, lang)
+
+    tokenizer_name = get_tokenizer(lang, metric="error_rate")
+    metric_name = wer if tokenizer_name == "word" else cer
+    metric_score = metric_name(ref_text_series.to_list(), hyp_text_series.to_list())  # jiwer expects (reference, hypothesis)
+    return metric_score, f"{metric_name.__name__} is {metric_score}"
+
+
+def compute_corpus_metric_score(
+    hyp_text_series: pd.Series,
+    ref_text_series: pd.Series,
+    lang: str,
+    whisper_normalize_text=True,
+    metric: Optional[str] = "bleu",
+):
+    """Wraps normalization functions and compute corpus-level BLEU/chrF++ score
+    Args:
+        hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
+        ref_text_series (pd.Series): _description_
+        lang (str): _description_
+        whisper_normalize_text (bool, optional): normalize both text hypotheses and reference if True. Defaults to True.
+    Returns:
+        (MetricScore, MetricScoreSignature)
+    """
+    if whisper_normalize_text:
+        hyp_text_series = whisper_normalize_series(hyp_text_series, lang)
+        ref_text_series = whisper_normalize_series(ref_text_series, lang)
+
+    tokenizer_name = get_tokenizer(lang)
+    if metric == "bleu":
+        corpus_metric_score_metric = sacrebleu.metrics.bleu.BLEU(
+            lowercase=whisper_normalize_text, tokenize=tokenizer_name
+        )  # lowercase applied if we use whisper_normalize_text
+    elif metric == "chrF++":
+        corpus_metric_score_metric = sacrebleu.CHRF(word_order=2)
+
+    corpus_metric_score = corpus_metric_score_metric.corpus_score(
+        hyp_text_series.to_list(), [ref_text_series.to_list()]
+    )
+    corpus_metric_score_signature = corpus_metric_score_metric.get_signature()
+    corpus_metric_score_signature.info["whisper_normalize"] = whisper_normalize_text
+
+    return corpus_metric_score, corpus_metric_score_signature
+
+
+def compute_quality_metrics(
+    output_manifest_tsv_path: str,
+    output_dir: str,
+    tgt_lang: str,
+    task: str,
+    device: str,
+    whisper_model_name: Optional[str] = "large",
+    whisper_normalize_text_output: Optional[bool] = False,
+    ref_text_col_name: Optional[str] = "ref_tgt_text",
+    pred_text_col_name: Optional[str] = "pred_tgt_text",
+    pred_audio_col_name: Optional[str] = "pred_tgt_audio",
+):
+    """Wraps asr and s2t bleu functions to call it with TSV manifest composed on expressivity side
+    Args:
+        output_manifest_tsv_path (str): output manifest which has "ref_text", "hypo_audio", "s2t_out" column names
+        output_dir (str): Directory to write files with metrics
+        tgt_lang (str): what language we evaluate on
+        task (str): Task we are currently evaluating for
+        device (str): Device to use for inference
+        whisper_model_name (str, Optional): Whisper model name. Defaults to "large".
+        whisper_normalize_text_output (bool, Optional): Normalizes text output using whisper_normalizer if set to true
+        ref_text_col_name (str, Optional): Column name in the tsv corresponding to reference target text
+        pred_text_col_name (str, Optional): Column name in the tsv corresponding to predicted target text
+        pred_audio_col_name (str, Optional): Column name in the tsv corresponding to predicted target audio.
+            Setting this value to none will skip speech metrics
+    """
+    df = pd.read_csv(
+        output_manifest_tsv_path, sep="\t", quoting=3, encoding="utf-8", escapechar="\\"
+    )
+    task = task.upper()
+
+    if not Path(output_dir).exists():
+        Path(output_dir).mkdir(parents=True, exist_ok=True)
+
+    if task in ["S2TT", "S2ST", "T2TT"] and pred_text_col_name:
+        metric = "chrF++" if task == "T2TT" else "bleu"
+        text_metric, text_metric_signature = compute_corpus_metric_score(
+            hyp_text_series=df[pred_text_col_name],
+            ref_text_series=df[ref_text_col_name],
+            lang=tgt_lang,
+            whisper_normalize_text=whisper_normalize_text_output,
+            metric=metric,
+        )
+        text_metric_json = text_metric.format(
+            signature=text_metric_signature.format(), is_json=True
+        )
+
+        if task == "T2TT":
+            filename = "t2tt_chrf.json"
+            cur_task = "T2TT"
+        else:
+            filename = (
+                "s2tt_bleu_normalized.json"
+                if whisper_normalize_text_output
+                else "s2tt_bleu.json"
+            )
+            cur_task = "S2TT"
+        with open((Path(output_dir) / filename).as_posix(), "w") as f:
+            f.write(text_metric_json)
+
+        logger.info(f"{cur_task} {metric}:\n{text_metric_json}")
+
+    if task in ["T2ST", "S2ST"]:
+        whisper_model = init_whisper_model(device, whisper_model_name)
+        (
+            asr_bleu_normalized,
+            asr_bleu_normalized_signature,
+            transcripts_df,
+        ) = compute_asr_bleu(
+            audio_paths_series=df[pred_audio_col_name],
+            ref_text_series=df[ref_text_col_name],
+            lang=tgt_lang,
+            asr_model=whisper_model,
+            whisper_normalize_text=True,
+        )
+        transcripts_df.to_csv(
+            (Path(output_dir) / f"whisper_audio_transcriptions.tsv"),
+            sep="\t",
+            index=False,
+            encoding="utf-8",
+            escapechar="\\",
+        )
+
+        asr_bleu_normalized_signature.info["whisper_asr_model"] = whisper_model_name
+
+        asr_bleu_normalized_json = asr_bleu_normalized.format(
+            signature=asr_bleu_normalized_signature.format(), is_json=True
+        )
+
+        with open(
+            (Path(output_dir) / f"{task.lower()}_asr_bleu_normalized.json").as_posix(),
+            "w",
+        ) as f:
+            f.write(asr_bleu_normalized_json)
+
+        logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
+
+    if task == "ASR":
+        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
+            hyp_text_series=df[pred_text_col_name],
+            ref_text_series=df[ref_text_col_name],
+            lang=tgt_lang,
+            whisper_normalize_text=whisper_normalize_text_output,
+        )
+
+        with open((Path(output_dir) / "asr_error_rate.json").as_posix(), "w") as f:
+            f.write(asr_error_rate_signature)
+
+        logger.info(f"ASR : {asr_error_rate_signature}")
+
+    return
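For context, here is a minimal sketch of driving the new `compute_quality_metrics` entry point directly (file names and data are hypothetical, and it assumes the repository root is on `PYTHONPATH`):

```python
import pandas as pd

from scripts.eval_utils.compute_metrics import compute_quality_metrics

# Build a tiny S2TT manifest using the default column names expected by the pipeline.
df = pd.DataFrame(
    {
        "ref_tgt_text": ["bonjour tout le monde"],
        "pred_tgt_text": ["bonjour le monde"],
    }
)
df.to_csv("model-outputs.tsv", sep="\t", index=False)

# For S2TT this computes corpus BLEU on the two text columns and writes
# s2tt_bleu.json into eval_out/; no Whisper model is loaded for this task.
compute_quality_metrics(
    output_manifest_tsv_path="model-outputs.tsv",
    output_dir="eval_out",
    tgt_lang="fra",
    task="S2TT",
    device="cpu",
)
```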

+ 177 - 0
scripts/eval_utils/lang_mapping.py

@@ -0,0 +1,177 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+LANG2_LANG3 = {
+    "en": "eng",
+    "ar": "arb",
+    "as": "asm",
+    "be": "bel",
+    "bg": "bul",
+    "bn": "ben",
+    "ca": "cat",
+    "ckb": "ckb",
+    "cs": "ces",
+    "cy": "cym",
+    "da": "dan",
+    "de": "deu",
+    "el": "ell",
+    "es": "spa",
+    "et": "est",
+    "fa": "pes",
+    "fi": "fin",
+    "fr": "fra",
+    "ga": "gle",
+    "hi": "hin",
+    "hu": "hun",
+    "id": "ind",
+    "it": "ita",
+    "ja": "jpn",
+    "ka": "kat",
+    "ky": "kir",
+    "lg": "lug",
+    "lt": "lit",
+    "lv": "lvs",
+    "mn": "khk",
+    "mr": "mar",
+    "mt": "mlt",
+    "nl": "nld",
+    "pa": "pan",
+    "pl": "pol",
+    "pt": "por",
+    "ro": "ron",
+    "ru": "rus",
+    "sk": "slk",
+    "sl": "slv",
+    "sv": "swe",
+    "sw": "swh",
+    "ta": "tam",
+    "th": "tha",
+    "tr": "tur",
+    "uk": "ukr",
+    "ur": "urd",
+    "uz": "uzn",
+    "vi": "vie",
+    "yue": "yue",
+    "af": "afr",
+    "is": "isl",
+    "lb": "ltz",
+    "no": "nob",
+    "gl": "glg",
+    "kea": "kea",
+    "bs": "bos",
+    "hr": "hrv",
+    "mk": "mkd",
+    "sr": "srp",
+    "hy": "hye",
+    "az": "azj",
+    "kk": "kaz",
+    "ko": "kor",
+    "gu": "guj",
+    "kn": "kan",
+    "ne": "npi",
+    "or": "ory",
+    "sd": "snd",
+    "te": "tel",
+    "ceb": "ceb",
+    "jv": "jav",
+    "ms": "zlm",
+    "ml": "mal",
+    "tl": "tgl",
+    "tl": "fil",
+    "my": "mya",
+    "km": "khm",
+    "lo": "lao",
+    "he": "heb",
+    "ps": "pbt",
+    "tg": "tgk",
+    "am": "amh",
+    "ig": "ibo",
+    "ln": "lin",
+    "nso": "nso",
+    "so": "som",
+    "xh": "xho",
+    "yo": "yor",
+    "zu": "zul",
+    "kam": "kam",
+    "luo": "luo",
+    "ny": "nya",
+    "om": "gaz",
+    "sn": "sna",
+    "umb": "umb",
+    "ga-IE": "gle",
+    "pa": "pan",
+    "sv": "swe",
+    "ast": "ast",
+    "ff": "ful",
+    "mi": "mri",
+    "ha": "hau",
+    "wo": "wol",
+    "oc": "oci",
+    "ilo": "ilo",
+    "ba": "bak",
+    "br": "bre",
+    "fy": "fry",
+    "yi": "yid",
+    "tn": "tsn",
+    "gd": "gla",
+    "ht": "hat",
+    "mg": "mlg",
+    "ns": "nso",
+    "si": "sin",
+    "sq": "sqi",
+    "ss": "ssw",
+    "su": "sun",
+    "zh": "cmn",
+    "ab": "abk",
+    "bas": "bas",
+    "cnh": "cnh",
+    "cv": "chv",
+    "dv": "div",
+    "eo": "epo",
+    "eu": "eus",
+    "fy-NL": "fry",
+    "gn": "grn",
+    "hsb": "hsb",
+    "hy": "hye",
+    "ia": "ina",
+    "kab": "kab",
+    "kmr": "kmr",
+    "mdf": "mdf",
+    "mhr": "mhr",
+    "myv": "myv",
+    "nan-tw": "hbl",
+    "ne": "npi",
+    "nn-NO": "nno",
+    "rm-sursilv": "rm-sursilv",
+    "rm-vallader": "rm-vallader",
+    "rw": "kin",
+    "sah": "sah",
+    "sat": "sat",
+    "sc": "srd",
+    "tig": "tig",
+    "tok": "tok",
+    "tt": "tat",
+    "ug": "uig",
+    "vot": "vot",
+    "mrj": "mrj",
+    "skr": "skr",
+    "ti": "tir",
+    "tw": "twi",
+    "bo": "bod",
+    "fo": "fao",
+    "gv": "glv",
+    "haw": "haw",
+    "la": "lat",
+    "sa": "san",
+    "sco": "sco",
+    "war": "war",
+    "he": "heb",
+    "jw": "jav",
+    "nn": "nno",
+    "tk": "tuk",
+}
+LANG3_LANG2 = {v: k for k, v in LANG2_LANG3.items()}
+

+ 8 - 53
scripts/m4t/evaluate/README.md

@@ -1,63 +1,18 @@
 # Evaluating SeamlessM4T models
-Refer to the [inference tutorial](../predict/README.md) for the supported tasks to run inference with SeamlessM4T models.
+Refer to the [inference tutorial](../predict/README.md) for the supported tasks and language directions to run inference with SeamlessM4T models.
 
 ## Quick start:
+We use the sacreBLEU library for computing BLEU and chrF++ scores, and the [JiWER library](https://github.com/jitsi/jiwer) for computing CER and WER scores.
+
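+For reference, here is a minimal sketch (not part of the CLI) of how these libraries are typically invoked on lists of hypotheses and references:
+
+```python
+from sacrebleu.metrics import BLEU
+from jiwer import wer
+
+hyps = ["bonjour le monde"]        # hypothetical model outputs
+refs = ["bonjour tout le monde"]   # hypothetical references
+
+# Corpus-level BLEU with the default 13a tokenizer.
+print(BLEU(tokenize="13a").corpus_score(hyps, [refs]).score)
+
+# Word error rate; jiwer expects references first, hypotheses second.
+print(wer(refs, hyps))
+```
+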
 Evaluation can be run with the CLI, from the root directory of the repository.
 
-The model can be specified with `--model_name`: `seamlessM4T_v2_large` or `seamlessM4T_large` or `seamlessM4T_medium`
+The model can be specified with `--model_name`: `seamlessM4T_v2_large` or `seamlessM4T_large` or `seamlessM4T_medium` 
 
 ```bash
 m4t_evaluate --data_file <path_to_data_tsv_file> --task <task_name> --tgt_lang <tgt_lang> --output_path <path_to_save_evaluation_output> --ref_field <ref_field_name> --audio_root_dir <path_to_audio_root_directory>
 ```
+## Note
+1. We use raw (unnormalized) references to compute BLEU scores for the S2TT and T2TT tasks.
+2. For the ASR task, the source language needs to be passed as `<tgt_lang>` (see the example below).
+3. The `--src_lang` arg needs to be specified to run evaluation for the T2TT task.
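+
+For example, an ASR evaluation on English data might look like (paths are placeholders):
+
+```bash
+m4t_evaluate --data_file <path_to_data_tsv_file> --task ASR --tgt_lang eng --audio_root_dir <path_to_audio_root_directory> --output_path <path_to_save_evaluation_output>
+```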
 
-### S2TT
-If provided a test_fleurs/dev_fleurs data tsv file, we parse through every example in the file, run model inference and save the first pass text generations and the computed first pass (S2TT) BLEU.
-
-### S2ST and T2ST
-Additionally from S2TT, we also save the unit generations, run vocoder inference to generate the translated audio waveforms and save the .wav files to a directory.
-
-To measure the quality of the translated speech outputs, the audios are first transcribed using Whisper ASR model and BLEU score is computed on these ASR transcriptions comparing them with the ground truth text references.
-
-Whisper large-v2 is used for non-English target languages and medium.en trained on English-only data is used for English due to its superior performance.
-
-```python
-import whisper
-
-model = whisper.load_model('medium.en')
-model = whisper.load_model('large-v2')
-```
-To reproduce the whisper transcriptions and thereby the ASR-BLEU scores, greedy decoding is used with a preset temperature value of 0. Target language information is also passed to the whisper model.
-
-```python
-prediction = model.transcribe(<AUDIO_PATH>, language=<LANGUAGE>, temperature=0, beam_size=1)["text"]
-```
-
-Whisper-normalizer is run on the ground truth <REFERENCES> and the model generated <PREDICTIONS>. ASR-BLEU scores are computed using sacrebleu following the same tokenization as described for S2TT.
-
-```python
-from whisper_normalizer.basic import BasicTextNormalizer
-
-normalizer = EnglishTextNormalizer() ## To be used for English
-normalizer = BasicTextNormalizer()  ## For non-English directions
-```
-
-### T2TT
-Similar to S2TT, raw (unnormalized) references and predictions are used to compute the chrF++ scores for text-to-text translation.
-
-```python
-import sacrebleu
-
-chrf_metric = sacrebleu.CHRF(word_order=2)
-chrf_score = chrf_metric.corpus_score(<REFERENCES>,<PREDICTIONS>)
-```
-
-### ASR
-Similar to Whisper, character-level error rate (CER) metric is used for Mandarin Chinese (cmn), Japanese (jpn), Thai (tha), Lao (lao), and Burmese (mya) languages. Word-level error rate (WER) metric is used for the remaining languages. Whisper-normalizer is applied on the ground truth <REFERENCES> and the model generated <PREDICTIONS>. [JiWER library](https://github.com/jitsi/jiwer) is used to compute these CER and WER scores.
-
-```python
-import jiwer
-
-wer = WER(<REFERENCES>,<PREDICTIONS>) ## WER
-cer = CER(<REFERENCES>,<PREDICTIONS>) ## CER
-
-```

+ 62 - 36
scripts/m4t/evaluate/evaluate.py

@@ -12,12 +12,12 @@ import subprocess
 import torch
 import torchaudio
 
+from argparse import Namespace
 from dataclasses import dataclass
 from pathlib import Path
 from torch import Tensor
 from tqdm import tqdm
-from typing import List, Optional, Tuple
-from sacrebleu.metrics import BLEU
+from typing import List, Optional, Tuple, Dict
 
 from fairseq2.data import Collater, DataPipeline, FileMapper
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
@@ -33,6 +33,9 @@ from seamless_communication.models.inference import (
     Translator,
 )
 from seamless_communication.models.unity import load_unity_text_tokenizer
+from scripts.eval_utils.compute_metrics import (
+    compute_quality_metrics,
+)
 
 logging.basicConfig(
     level=logging.INFO,
@@ -218,7 +221,10 @@ def adjust_output_for_corrupted_inputs(
 
 
 def run_eval(
-    translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
+    translator: Translator,
+    text_tokenizer: TextTokenizer,
+    ctx: EvalContext,
+    whisper_model_name: Optional[str] = None,
 ) -> None:
     pipeline = build_data_pipeline(ctx, text_tokenizer)
 
@@ -232,17 +238,18 @@ def run_eval(
         waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
         waveforms_dir.mkdir(parents=True, exist_ok=True)
 
-    hyps = []
-    refs = []
-
-    with open(
-        output_path / f"text_output-{ctx.data_file.stem}.txt", "w"
-    ) as hyp_file, open(
-        output_path / f"unit_output-{ctx.data_file.stem}.txt", "w"
+    model_outputs_tsv = output_path / f"model-outputs-{ctx.data_file.stem}.txt"
+    unit_outputs_tsv = output_path / f"unit_output-{ctx.data_file.stem}.txt"
+    with open(model_outputs_tsv, "w") as hyp_file, open(
+        unit_outputs_tsv, "w"
     ) if ctx.output_modality == Modality.SPEECH else contextlib.nullcontext(
         itertools.repeat(None)
     ) as unit_file:
         sample_id = 0
+        if ctx.output_modality == Modality.SPEECH:
+            hyp_file.write(f"ref_tgt_text\tpred_tgt_text\tpred_tgt_audio\n")
+        else:
+            hyp_file.write(f"ref_tgt_text\tpred_tgt_text\n")
         for example in pipeline:
             valid_sequences: Optional[Tensor] = None
             if ctx.input_modality == Modality.SPEECH:
@@ -262,7 +269,10 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (text_output, speech_output,) = translator.predict(
+                (
+                    text_output,
+                    speech_output,
+                ) = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -279,56 +289,58 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
+                (
+                    text_output,
+                    speech_output,
+                ) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,
                 )
 
-            hyps += [str(s) for s in text_output]
-            refs += [str(s) for s in example[ctx.ref_field]]
+            hyps = [str(s) for s in text_output]
+            refs = [str(s) for s in example[ctx.ref_field]]
 
             for i in range(len(text_output)):
                 t = text_output[i]
-                hyp_file.write(f"{t}\n")
-
                 if ctx.output_modality == Modality.SPEECH:
                     assert speech_output is not None
                     u = speech_output.units[i]
                     str_units = [str(i) for i in u]
                     unit_file.write(" ".join(str_units) + "\n")
+                    wav_fp = str(waveforms_dir / f"{sample_id}_pred.wav")
                     torchaudio.save(
-                        waveforms_dir / f"{sample_id}_pred.wav",
+                        wav_fp,
                         speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
                         sample_rate=speech_output.sample_rate,
                     )
+                    hyp_file.write(f"{refs[i]}\t{hyps[i]}\t{wav_fp}\n")
+                else:
+                    hyp_file.write(f"{refs[i]}\t{hyps[i]}\n")
 
                 sample_id += 1
                 progress_bar.update(1)
 
     progress_bar.close()
-    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
-
-    assert len(hyps) == len(refs)
-    if len(hyps) > 0:
-        if ctx.target_lang in ("cmn", "jpn", "lao", "mya", "tha"):
-            tokenizer = "char"
-        else:
-            tokenizer = "13a"
-
-        bleu = BLEU(tokenize=tokenizer)
-        score = bleu.corpus_score(hyps, [refs])
-        bleu_filename = output_path / f"{ctx.data_file.stem}_text_output_bleu.json"
-        with open(bleu_filename, "w") as f:
-            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
-        logger.info(score.format(signature=bleu.get_signature()))
+    logger.info(f"Processed {sample_id} samples")
+
+    compute_quality_metrics(
+        output_manifest_tsv_path=model_outputs_tsv,
+        output_dir=output_path,
+        tgt_lang=ctx.target_lang,
+        task=ctx.task,
+        device=ctx.device,
+        whisper_model_name=whisper_model_name,
+    )
 
 
-def main():
+def main(optional_args: Optional[Dict] = None):
     parser = argparse.ArgumentParser(
         description="M4T evaluation for tasks supported by Translator."
     )
-    parser.add_argument("data_file", type=str, help="Data file (.tsv) to be evaluated.")
+    parser.add_argument(
+        "--data_file", type=str, help="Data file (.tsv) to be evaluated."
+    )
 
     parser = add_inference_arguments(parser)
     parser.add_argument(
@@ -349,7 +361,21 @@ def main():
         help="Reference target text field to compute the BLEU score against.",
         default="tgt_text",
     )
-    args = parser.parse_args()
+    parser.add_argument(
+        "--whisper_model_name",
+        type=str,
+        help="Whisper model to be used for ASR-BLEU scoring",
+        default="large",
+    )
+    args, unknown = parser.parse_known_args()
+    default_args = vars(args)
+    default_args.update(optional_args) if optional_args else default_args
+    args = Namespace(**default_args)
+
+    if not args.data_file or not args.task or not args.tgt_lang:
+        raise Exception(
+            "Please provide required arguments for evaluation - data_file, task, tgt_lang"
+        )
 
     input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
 
@@ -407,7 +433,7 @@ def main():
     # fmt: on
     logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")
 
-    run_eval(translator, text_tokenizer, ctx)
+    run_eval(translator, text_tokenizer, ctx, args.whisper_model_name)
 
 
 if __name__ == "__main__":
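The new `optional_args` hook also makes it possible to drive the evaluation from Python rather than the CLI. A minimal sketch (values are hypothetical; arguments not overridden keep their argparse defaults, and the repository root is assumed to be on `PYTHONPATH`):

```python
from scripts.m4t.evaluate.evaluate import main

# The dict is merged into the parsed CLI defaults before validation, so
# data_file, task and tgt_lang must be supplied here or on the command line.
main(
    optional_args={
        "data_file": "test_fleurs.tsv",
        "task": "S2TT",
        "tgt_lang": "fra",
    }
)
```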

+ 6 - 2
scripts/m4t/predict/predict.py

@@ -26,9 +26,9 @@ logger = logging.getLogger(__name__)
 
 
 def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
-    parser.add_argument("task", type=str, help="Task type")
+    parser.add_argument("--task", type=str, help="Task type")
     parser.add_argument(
-        "tgt_lang", type=str, help="Target language to translate/transcribe into."
+        "--tgt_lang", type=str, help="Target language to translate/transcribe into."
     )
     parser.add_argument(
         "--src_lang",
@@ -178,6 +178,10 @@ def main():
 
     parser = add_inference_arguments(parser)
     args = parser.parse_args()
+    if not args.task or not args.tgt_lang:
+        raise Exception(
+            "Please provide required arguments for evaluation -  task, tgt_lang"
+        )
 
     if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
         raise ValueError("output_path must be provided to save the generated audio")