
m4t evaluation script (#57)

* Add an eval script that computes S2T (first-pass) BLEU and generates audio waveforms from the vocoder for any data_file.

* translator's predict() now supports dict input; revamped the evaluation script.

* Fix mypy issues in translator, Collater API change, BLEU tokenizer scheme.

* Create and update evaluate/README.md

* Add m4t_evaluate convenience console script.

* Skip corrupted input tensors without breaking eval.

* Apply changes to translator.predict() everywhere.

* Make the corrupted-input fix compatible with all batch sizes.

* Modify translator.predict() to take in text and unit generation opts from upstream.

* Fix bug in evals with text input, moving the corruption-handling logic to the speech input_modality.

* Revert predict API changes, centralizing inference args in predict.py.

* Cosmetic changes.
Kaushik Ram Sadagopan 1 year ago
parent
commit
eba940ace3

+ 16 - 7
demo/app.py

@@ -13,6 +13,7 @@ import torchaudio
 from huggingface_hub import hf_hub_download
 from seamless_communication.models.inference.translator import Translator
 
+
 DESCRIPTION = """# SeamlessM4T
 
 [SeamlessM4T](https://github.com/facebookresearch/seamless_communication) is designed to provide high-quality
@@ -321,6 +322,7 @@ def predict(
     target_language: str,
 ) -> tuple[tuple[int, np.ndarray] | None, str]:
     task_name = task_name.split()[0]
+
     source_language_code = (
         LANGUAGE_NAME_TO_CODE[source_language] if source_language else None
     )
@@ -345,17 +347,24 @@ def predict(
         torchaudio.save(input_data, new_arr, sample_rate=int(AUDIO_SAMPLE_RATE))
     else:
         input_data = input_text
-    text_out, wav, sr = translator.predict(
-        input=input_data,
-        task_str=task_name,
-        tgt_lang=target_language_code,
+
+    assert input_data is not None
+    text_output, speech_output = translator.predict(
+        input_data,
+        task_name,
+        target_language_code,
         src_lang=source_language_code,
-        ngram_filtering=True,
+        unit_generation_ngram_filtering=True,
     )
     if task_name in ["S2ST", "T2ST"]:
-        return (sr, wav.cpu().detach().numpy()), text_out
+        assert speech_output is not None
+
+        return (
+            speech_output.sample_rate,
+            speech_output.audio_wavs[0].cpu().detach().numpy(),
+        ), str(text_output[0])
     else:
-        return None, text_out
+        return None, str(text_output[0])
 
 
 def process_s2st_example(

+ 1 - 0
requirements.txt

@@ -1,6 +1,7 @@
 pre-commit
 datasets
 torchaudio
+tqdm
 soundfile
 librosa
 fairseq2==0.2.*

+ 12 - 8
docs/m4t/eval_README.md → scripts/m4t/evaluate/README.md

@@ -1,17 +1,21 @@
-## Evaluation protocols for various SeamlessM4T tasks
-Refer to the [inference tutorial](../../scripts/m4t/predict/README.md) for detailed guidance on how to run inference using SeamlessM4T models. In this tutorial, the evaluation protocol used for all tasks supported by SeamlessM4T is briefly described.
+# Evaluating SeamlessM4T models
+Refer to the [inference tutorial](../predict/README.md) for the tasks supported when running inference with SeamlessM4T models.
 
-### S2TT
-[Sacrebleu library](https://github.com/mjpost/sacrebleu) is used to compute the BLEU scores. To be consistent with Whisper, a character-level (*char*) tokenizer for Mandarin Chinese (cmn), Japanese (jpn), Thai (tha), Lao (lao), and Burmese (mya) is used. The default *13a* tokenizer is used for other languages. Raw (unnormalized) references and predictions are used for computing the scores.
+## Quick start:
+Evaluation can be run with the CLI, from the root directory of the repository.
 
-```python
-import sacrebleu
+The model can be specified with `--model_name`: `seamlessM4T_v2_large`, `seamlessM4T_large`, or `seamlessM4T_medium`.
 
-bleu_metric = sacrebleu.BLEU(tokenize=<TOKENIZER>)
-bleu_score = bleu_metric.corpus_score(<PREDICTIONS>, [<REFERENCES>])
+```bash
+m4t_evaluate <path_to_data_tsv_file> <task_name> <tgt_lang> --output_path <path_to_save_audio> --ref_field <ref_field_name> --audio_root_dir <path_to_audio_root_directory>
 ```
 
+### S2TT
+Given a test_fleurs/dev_fleurs data TSV file, we parse every example in the file, run model inference, and save the first-pass text generations along with the computed first-pass (S2TT) BLEU score.
+
 ### S2ST and T2ST
+In addition to the S2TT outputs, we also save the unit generations, run vocoder inference to generate the translated audio waveforms, and save the .wav files to a directory.
+
 To measure the quality of the translated speech outputs, the audio is first transcribed using the Whisper ASR model, and the BLEU score is computed by comparing these ASR transcriptions with the ground-truth text references.
 
 Whisper large-v2 is used for non-English target languages, while medium.en (trained on English-only data) is used for English due to its superior performance.

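For reference, a minimal sketch of the scoring described above, assuming the openai-whisper package for the ASR step and placeholder file paths; the tokenizer choice mirrors `evaluate.py` (character-level for cmn, jpn, lao, mya and tha, *13a* otherwise):

```python
import whisper  # assumption: the openai-whisper package provides the ASR model
from sacrebleu.metrics import BLEU

# Placeholder inputs: predicted .wav files from the S2ST eval and ground-truth references.
pred_wavs = ["0_pred.wav", "1_pred.wav"]
refs = ["first reference sentence", "second reference sentence"]
target_lang = "spa"  # example target language code

# ASR-BLEU: transcribe the translated audio, then score against the text references.
asr_model = whisper.load_model("medium.en" if target_lang == "eng" else "large-v2")
hyps = [asr_model.transcribe(wav)["text"].strip() for wav in pred_wavs]

# Same tokenizer scheme as the S2TT BLEU computed by evaluate.py.
tokenizer = "char" if target_lang in ("cmn", "jpn", "lao", "mya", "tha") else "13a"
bleu = BLEU(tokenize=tokenizer)
print(bleu.corpus_score(hyps, [refs]))
```
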
+ 5 - 0
scripts/m4t/evaluate/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 407 - 0
scripts/m4t/evaluate/evaluate.py

@@ -0,0 +1,407 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import itertools
+import logging
+import subprocess
+import torch
+import torchaudio
+
+from dataclasses import dataclass
+from pathlib import Path
+from torch import Tensor
+from tqdm import tqdm
+from typing import List, Optional, Tuple
+from sacrebleu.metrics import BLEU
+
+from fairseq2.data import Collater, DataPipeline, FileMapper
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
+from fairseq2.data.typing import StringLike
+from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.typing import Device, DataType
+
+from m4t_scripts.predict import add_inference_arguments, set_generation_opts
+from seamless_communication.models.inference import (
+    BatchedSpeechOutput,
+    Modality,
+    Translator,
+)
+from seamless_communication.models.unity import load_unity_text_tokenizer
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvalContext:
+    task: str
+    """String representing the task. Valid choices are
+    "S2ST", "S2TT", "T2ST", "T2TT", "ASR"."""
+
+    input_modality: Modality
+    """The input modality of the task."""
+
+    output_modality: Modality
+    """The output modality of the task."""
+
+    model_name: str
+    """The name of the S2T UnitY model."""
+
+    data_file: Path
+    """The pathname of the test TSV data file."""
+
+    audio_root_dir: Path
+    """The pathname of the directory under which
+    audio files are stored."""
+
+    target_lang: str
+    """The target translation language."""
+
+    source_lang: Optional[str]
+    """The source language."""
+
+    batch_size: int
+    """The batch size for model input."""
+
+    device: Device
+    """The device on which to run inference."""
+
+    dtype: DataType
+    """The data type with which to run inference."""
+
+    output_path: Path
+    """The pathname of the output directory to save
+    the evaluation results."""
+
+    ref_field: str
+    """The reference target text field to compute
+    the BLEU score against."""
+
+    text_generation_opts: SequenceGeneratorOptions
+    """Text generation hyperparameters."""
+
+    unit_generation_opts: Optional[SequenceGeneratorOptions]
+    """Unit generation hyperparameters, not applicable
+    for the NAR T2U decoder."""
+
+    unit_generation_ngram_filtering: bool
+    """If True, removes consecutive repeating ngrams
+    from the decoded unit output."""
+
+
+def count_lines(filename: Path) -> int:
+    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
+    return int(result.stdout.decode().split()[0])
+
+
+def build_data_pipeline(
+    ctx: EvalContext,
+    text_tokenizer: TextTokenizer,
+) -> DataPipeline:
+    with open(ctx.data_file, "r") as f:
+        header = f.readline().strip("\n").split("\t")
+        first_example = f.readline().strip("\n").split("\t")
+
+    # TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
+    n_parallel = 4
+
+    split_tsv = StrSplitter(names=header)
+
+    pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
+
+    if ctx.input_modality == Modality.SPEECH:
+        map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)
+
+        pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+
+        decode_audio = AudioDecoder(dtype=torch.float32, device=ctx.device)
+
+        convert_to_fbank = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15,
+            channel_last=True,
+            standardize=True,
+            device=ctx.device,
+            dtype=ctx.dtype,
+        )
+
+        pipeline_builder.map(
+            [decode_audio, convert_to_fbank],
+            selector="audio.data",
+            num_parallel_calls=n_parallel,
+        )
+    else:
+        if "src_lang" in header:
+            source_lang = first_example[header.index("src_lang")]
+            ctx.source_lang = source_lang
+        elif ctx.source_lang is None:
+            raise ValueError(
+                (
+                    "'src_lang' is missing in the data_file"
+                    "header and in the arguments."
+                )
+            )
+
+        token_encoder = text_tokenizer.create_encoder(
+            task="translation", lang=source_lang, mode="source", device=ctx.device
+        )
+        pipeline_builder.map(
+            [token_encoder],
+            selector="src_text",
+            num_parallel_calls=n_parallel,
+        )
+
+    pipeline_builder.bucket(bucket_size=ctx.batch_size)
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+def adjust_output_for_corrupted_inputs(
+    valid_sequences: Tensor,
+    text_output: List[StringLike],
+    speech_output: Optional[BatchedSpeechOutput],
+) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
+    adjusted_text_output: List[StringLike] = []
+    adjusted_speech_output: Optional[BatchedSpeechOutput] = None
+
+    if speech_output is not None:
+        assert (
+            len(text_output)
+            == len(speech_output.units)
+            == len(speech_output.audio_wavs)
+        )
+        adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+
+    batch_counter = 0
+    for is_valid in valid_sequences:
+        if is_valid:
+            adjusted_text_output.append(text_output[batch_counter])
+            if speech_output is not None:
+                assert adjusted_speech_output is not None
+                adjusted_speech_output.units.append(speech_output.units[batch_counter])
+                adjusted_speech_output.audio_wavs.append(
+                    speech_output.audio_wavs[batch_counter]
+                )
+            batch_counter += 1
+        else:
+            # For the corrupted inputs, we save the following dummy outputs:
+            # empty string for text, empty list for units, 1 second of silence for audio.
+            adjusted_text_output.append("")
+            if adjusted_speech_output is not None:
+                sample_rate = adjusted_speech_output.sample_rate
+                adjusted_speech_output.units.append([])
+                adjusted_speech_output.audio_wavs.append(
+                    torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
+                )
+    return (
+        adjusted_text_output,
+        adjusted_speech_output,
+    )
+
+
+def run_eval(
+    translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
+) -> None:
+    pipeline = build_data_pipeline(ctx, text_tokenizer)
+
+    total_steps = count_lines(ctx.data_file) - 1
+    progress_bar = tqdm(total=total_steps)
+
+    output_path = ctx.output_path / ctx.data_file.stem
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    if ctx.output_modality == Modality.SPEECH:
+        waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
+        waveforms_dir.mkdir(parents=True, exist_ok=True)
+
+    hyps = []
+    refs = []
+
+    with open(
+        output_path / f"text_output-{ctx.data_file.stem}.txt", "w"
+    ) as hyp_file, open(
+        output_path / f"unit_output-{ctx.data_file.stem}.txt", "w"
+    ) if ctx.output_modality == Modality.SPEECH else contextlib.nullcontext(
+        itertools.repeat(None)
+    ) as unit_file:
+        sample_id = 0
+        for example in pipeline:
+            valid_sequences: Optional[Tensor] = None
+            if ctx.input_modality == Modality.SPEECH:
+                src = example["audio"]["data"]["fbank"]
+                # Skip corrupted audio tensors.
+                valid_sequences = ~torch.any(
+                    torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
+                )
+                if not valid_sequences.all():
+                    logger.warning(
+                        f"Sample IDs {sample_id} to {sample_id + ctx.batch_size} has some corrupted input."
+                    )
+                    src["seqs"] = src["seqs"][valid_sequences]
+                    src["seq_lens"] = src["seq_lens"][valid_sequences]
+            else:
+                src = example["src_text"]
+
+            # Skip performing inference when the input is entirely corrupted.
+            if src["seqs"].numel() > 0:
+                (text_output, speech_output,) = translator.predict(
+                    src,
+                    ctx.task,
+                    ctx.target_lang,
+                    src_lang=ctx.source_lang,
+                    text_generation_opts=ctx.text_generation_opts,
+                    unit_generation_opts=ctx.unit_generation_opts,
+                    unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
+                )
+            else:
+                text_output = []
+                if ctx.output_modality == Modality.SPEECH:
+                    speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+                else:
+                    speech_output = None
+
+            if valid_sequences is not None and not valid_sequences.all():
+                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
+                    valid_sequences,
+                    text_output,
+                    speech_output,
+                )
+
+            hyps += [str(s) for s in text_output]
+            refs += [str(s) for s in example[ctx.ref_field]]
+
+            for i in range(len(text_output)):
+                t = text_output[i]
+                hyp_file.write(f"{t}\n")
+
+                if ctx.output_modality == Modality.SPEECH:
+                    assert speech_output is not None
+                    u = speech_output.units[i]
+                    str_units = [str(i) for i in u]
+                    unit_file.write(" ".join(str_units) + "\n")
+                    torchaudio.save(
+                        waveforms_dir / f"{sample_id}_pred.wav",
+                        speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                        sample_rate=speech_output.sample_rate,
+                    )
+
+                sample_id += 1
+                progress_bar.update(1)
+
+    progress_bar.close()
+    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
+
+    assert len(hyps) == len(refs)
+    if len(hyps) > 0:
+        if ctx.target_lang in ("cmn", "jpn", "lao", "mya", "tha"):
+            tokenizer = "char"
+        else:
+            tokenizer = "13a"
+
+        bleu = BLEU(tokenize=tokenizer)
+        score = bleu.corpus_score(hyps, [refs])
+        bleu_filename = output_path / f"{ctx.data_file.stem}_text_output_bleu.json"
+        with open(bleu_filename, "w") as f:
+            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
+        logger.info(score.format(signature=bleu.get_signature()))
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="M4T evaluation for tasks supported by Translator."
+    )
+    parser.add_argument("data_file", type=str, help="Data file (.tsv) to be evaluated.")
+
+    parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=str,
+        help="Root directory for the audio filenames in the data file.",
+        required=True,
+    )
+    parser.add_argument(
+        "--ref_field",
+        type=str,
+        help="Reference target text field to compute the BLEU score against.",
+        default="tgt_text",
+    )
+    args = parser.parse_args()
+
+    input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+
+    # TODO: Avoid loading the T2U model, vocoder when the output
+    # modality is text.
+    translator = Translator(
+        args.model_name,
+        args.vocoder_name,
+        device,
+        text_tokenizer=text_tokenizer,
+        dtype=dtype,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    # fmt: off
+    ctx = EvalContext(
+        task=args.task,
+        input_modality=input_modality,
+        output_modality=output_modality,
+        model_name=args.model_name,
+        data_file=Path(args.data_file),
+        audio_root_dir=Path(args.audio_root_dir),
+        target_lang=args.tgt_lang,
+        source_lang=args.src_lang,
+        batch_size=args.batch_size,
+        device=device,
+        dtype=dtype,
+        ref_field=args.ref_field,
+        text_generation_opts=text_generation_opts,
+        unit_generation_opts=unit_generation_opts,
+        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+        output_path=Path(args.output_path),
+    )
+    # fmt: on
+    logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")
+
+    run_eval(translator, text_tokenizer, ctx)
+
+
+if __name__ == "__main__":
+    main()

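The corrupted-input skipping in `run_eval` above boils down to a per-sequence NaN check over the fbank batch; a standalone sketch with dummy tensors:

```python
import torch

# Dummy batch of fbank features with shape (batch, time, mel_bins); one sequence is corrupted.
seqs = torch.randn(3, 50, 80)
seqs[1, 10, 5] = float("nan")
seq_lens = torch.tensor([50, 50, 50])

# A sequence is kept only if it contains no NaN anywhere.
valid_sequences = ~torch.any(torch.any(torch.isnan(seqs), dim=1), dim=1)
print(valid_sequences)  # tensor([ True, False,  True])

if not valid_sequences.all():
    seqs = seqs[valid_sequences]
    seq_lens = seq_lens[valid_sequences]
print(seqs.shape)  # torch.Size([2, 50, 80])
```
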
+ 8 - 12
scripts/m4t/predict/README.md

@@ -7,8 +7,6 @@ SeamlessM4T models currently support five tasks:
 - Text-to-text translation (T2TT)
 - Automatic speech recognition (ASR)
 
-
-
 ## Quick start:
 Inference is run with the CLI, from the root directory of the repository.
 
@@ -70,16 +68,16 @@ translator = Translator("seamlessM4T_large", "vocoder_36langs", torch.device("cu
 Now `predict()` can be used to run inference as many times as needed on any of the supported tasks.
 
 Given an input audio with `<path_to_input_audio>` or an input text `<input_text>` in `<src_lang>`,
-we can translate into `<tgt_lang>` as follows:
+we first set the `text_generation_opts` and `unit_generation_opts`, and then translate into `<tgt_lang>` as follows:
 
 ## S2ST and T2ST:
 
 ```python
 # S2ST
-translated_text, wav, sr = translator.predict(<path_to_input_audio>, "s2st", <tgt_lang>)
+text_output, speech_output = translator.predict(<path_to_input_audio>, "s2st", <tgt_lang>, text_generation_opts=text_generation_opts, unit_generation_opts=unit_generation_opts)
 
 # T2ST
-translated_text, wav, sr = translator.predict(<input_text>, "t2st", <tgt_lang>, src_lang=<src_lang>)
+text_output, speech_output = translator.predict(<input_text>, "t2st", <tgt_lang>, src_lang=<src_lang>, text_generation_opts=text_generation_opts, unit_generation_opts=unit_generation_opts)
 
 ```
 Note that `<src_lang>` must be specified for T2ST.
@@ -87,27 +85,25 @@ Note that `<src_lang>` must be specified for T2ST.
 The generated units are synthesized and the output audio file is saved with:
 
 ```python
-wav, sr = translator.synthesize_speech(<speech_units>, <tgt_lang>)
-
 # Save the translated audio generation.
 torchaudio.save(
     <path_to_save_audio>,
-    wav[0].cpu(),
-    sample_rate=sr,
+    speech_output.audio_wavs[0][0].cpu(),
+    sample_rate=speech_output.sample_rate,
 )
 ```
 ## S2TT, T2TT and ASR:
 
 ```python
 # S2TT
-translated_text, _, _ = translator.predict(<path_to_input_audio>, "s2tt", <tgt_lang>)
+text_output, _ = translator.predict(<path_to_input_audio>, "s2tt", <tgt_lang>, text_generation_opts=text_generation_opts, unit_generation_opts=None)
 
 # ASR
 # This is equivalent to S2TT with `<tgt_lang>=<src_lang>`.
-transcribed_text, _, _ = translator.predict(<path_to_input_audio>, "asr", <src_lang>)
+text_output, _ = translator.predict(<path_to_input_audio>, "asr", <src_lang>, text_generation_opts=text_generation_opts, unit_generation_opts=None)
 
 # T2TT
-translated_text, _, _ = translator.predict(<input_text>, "t2tt", <tgt_lang>, src_lang=<src_lang>)
+text_output, _ = translator.predict(<input_text>, "t2tt", <tgt_lang>, src_lang=<src_lang>, text_generation_opts=text_generation_opts, unit_generation_opts=None)
 
 ```
 Note that `<src_lang>` must be specified for T2TT.

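The `text_generation_opts` and `unit_generation_opts` referenced in the snippets above can be built the same way `set_generation_opts` in `predict.py` does; a minimal sketch using the default values from that script:

```python
from fairseq2.generation import SequenceGeneratorOptions
from seamless_communication.models.inference import NGramRepeatBlockProcessor

# Defaults matching the CLI arguments in predict.py: beam size 5 and
# soft max sequence lengths of (1, 200) for text, (25, 50) for units.
text_generation_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(1, 200))
unit_generation_opts = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50))

# Optionally block hypotheses with repeating ngrams during incremental decoding.
text_generation_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
unit_generation_opts.logits_processor = NGramRepeatBlockProcessor(no_repeat_ngram_size=4)
```
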
+ 5 - 0
scripts/m4t/predict/__init__.py

@@ -3,3 +3,8 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+
+from m4t_scripts.predict.predict import (
+    add_inference_arguments as add_inference_arguments,
+)
+from m4t_scripts.predict.predict import set_generation_opts as set_generation_opts

+ 154 - 20
scripts/m4t/predict/predict.py

@@ -7,7 +7,15 @@ import argparse
 import logging
 import torch
 import torchaudio
-from seamless_communication.models.inference import Translator
+
+from argparse import Namespace
+from fairseq2.generation import SequenceGeneratorOptions
+from seamless_communication.models.inference import (
+    NGramRepeatBlockProcessor,
+    Translator,
+)
+from typing import Tuple
+
 
 logging.basicConfig(
     level=logging.INFO,
@@ -17,11 +25,7 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-def main():
-    parser = argparse.ArgumentParser(
-        description="M4T inference on supported tasks using Translator."
-    )
-    parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
+def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
     parser.add_argument("task", type=str, help="Task type")
     parser.add_argument(
         "tgt_lang", type=str, help="Target language to translate/transcribe into."
@@ -41,19 +45,138 @@ def main():
     parser.add_argument(
         "--model_name",
         type=str,
-        help="Base model name (`seamlessM4T_medium`, `seamlessM4T_large`)",
-        default="seamlessM4T_large",
+        help=(
+            "Base model name (`seamlessM4T_medium`, "
+            "`seamlessM4T_large`, `seamlessM4T_v2_large`)"
+        ),
+        default="seamlessM4T_v2_large",
+    )
+    parser.add_argument(
+        "--vocoder_name",
+        type=str,
+        help="Vocoder model name",
+        default="vocoder_commercial",
+    )
+    # Text generation args.
+    parser.add_argument(
+        "--text_generation_beam_size",
+        type=int,
+        help="Beam size for incremental text decoding.",
+        default=5,
+    )
+    parser.add_argument(
+        "--text_generation_max_len_a",
+        type=int,
+        help="`a` in `ax + b` for incremental text decoding.",
+        default=1,
+    )
+    parser.add_argument(
+        "--text_generation_max_len_b",
+        type=int,
+        help="`b` in `ax + b` for incremental text decoding.",
+        default=200,
+    )
+    parser.add_argument(
+        "--text_generation_ngram_blocking",
+        type=bool,
+        help=(
+            "Enable ngram_repeat_block for incremental text decoding."
+            "This blocks hypotheses with repeating ngram tokens."
+        ),
+        default=False,
+    )
+    parser.add_argument(
+        "--no_repeat_ngram_size",
+        type=int,
+        help="Size of ngram repeat block for both text & unit decoding.",
+        default=4,
+    )
+    # Unit generation args.
+    parser.add_argument(
+        "--unit_generation_beam_size",
+        type=int,
+        help=(
+            "Beam size for incremental unit decoding"
+            "not applicable for the NAR T2U decoder."
+        ),
+        default=5,
     )
     parser.add_argument(
-        "--vocoder_name", type=str, help="Vocoder name", default="vocoder_36langs"
+        "--unit_generation_max_len_a",
+        type=int,
+        help=(
+            "`a` in `ax + b` for incremental unit decoding"
+            "not applicable for the NAR T2U decoder."
+        ),
+        default=25,
     )
     parser.add_argument(
-        "--ngram-filtering",
+        "--unit_generation_max_len_b",
+        type=int,
+        help=(
+            "`b` in `ax + b` for incremental unit decoding"
+            "not applicable for the NAR T2U decoder."
+        ),
+        default=50,
+    )
+    parser.add_argument(
+        "--unit_generation_ngram_blocking",
         type=bool,
-        help="Enable ngram_repeat_block (currently hardcoded to 4, during decoding) and ngram filtering over units (postprocessing)",
+        help=(
+            "Enable ngram_repeat_block for incremental unit decoding."
+            "This blocks hypotheses with repeating ngram tokens."
+        ),
         default=False,
     )
+    parser.add_argument(
+        "--unit_generation_ngram_filtering",
+        type=bool,
+        help=(
+            "If True, removes consecutive repeated ngrams"
+            "from the decoded unit output."
+        ),
+        default=False,
+    )
+    return parser
+
+
+def set_generation_opts(
+    args: Namespace,
+) -> Tuple[SequenceGeneratorOptions, SequenceGeneratorOptions]:
+    # Set text, unit generation opts.
+    text_generation_opts = SequenceGeneratorOptions(
+        beam_size=args.text_generation_beam_size,
+        soft_max_seq_len=(
+            args.text_generation_max_len_a,
+            args.text_generation_max_len_b,
+        ),
+    )
+    if args.text_generation_ngram_blocking:
+        text_generation_opts.logits_processor = NGramRepeatBlockProcessor(
+            no_repeat_ngram_size=args.no_repeat_ngram_size
+        )
+
+    unit_generation_opts = SequenceGeneratorOptions(
+        beam_size=args.unit_generation_beam_size,
+        soft_max_seq_len=(
+            args.unit_generation_max_len_a,
+            args.unit_generation_max_len_b,
+        ),
+    )
+    if args.unit_generation_ngram_blocking:
+        unit_generation_opts.logits_processor = NGramRepeatBlockProcessor(
+            no_repeat_ngram_size=args.no_repeat_ngram_size
+        )
+    return text_generation_opts, unit_generation_opts
 
+
+def main():
+    parser = argparse.ArgumentParser(
+        description="M4T inference on supported tasks using Translator."
+    )
+    parser.add_argument("input", type=str, help="Audio WAV file path or text input.")
+
+    parser = add_inference_arguments(parser)
     args = parser.parse_args()
 
     if args.task.upper() in {"S2ST", "T2ST"} and args.output_path is None:
@@ -62,29 +185,40 @@ def main():
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
         dtype = torch.float16
-        logger.info(f"Running inference on the GPU in {dtype}.")
     else:
         device = torch.device("cpu")
         dtype = torch.float32
-        logger.info(f"Running inference on the CPU in {dtype}.")
 
-    translator = Translator(args.model_name, args.vocoder_name, device, dtype)
-    translated_text, wav, sr = translator.predict(
+    logger.info(f"Running inference on {device=} with {dtype=}.")
+
+    translator = Translator(args.model_name, args.vocoder_name, device, dtype=dtype)
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    text_output, speech_output = translator.predict(
         args.input,
         args.task,
         args.tgt_lang,
         src_lang=args.src_lang,
-        ngram_filtering=args.ngram_filtering,
+        text_generation_opts=text_generation_opts,
+        unit_generation_opts=unit_generation_opts,
+        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
     )
 
-    if wav is not None and sr is not None:
+    if speech_output is not None:
         logger.info(f"Saving translated audio in {args.tgt_lang}")
         torchaudio.save(
             args.output_path,
-            wav[0].to(torch.float32).cpu(),
-            sample_rate=sr,
+            speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
+            sample_rate=speech_output.sample_rate,
         )
-    logger.info(f"Translated text in {args.tgt_lang}: {translated_text}")
+    logger.info(f"Translated text in {args.tgt_lang}: {text_output[0]}")
 
 
 if __name__ == "__main__":

+ 1 - 0
setup.py

@@ -63,6 +63,7 @@ setup(
     extras_require={"dev": default_requirements + dev_requirements},
     entry_points={
         "console_scripts": [
+            "m4t_evaluate=m4t_scripts.evaluate.evaluate:main",
             "m4t_predict=m4t_scripts.predict.predict:main",
             "m4t_finetune=m4t_scripts.finetune.finetune:main",
             "m4t_prepare_dataset=m4t_scripts.finetune.dataset:main",

+ 8 - 0
src/seamless_communication/models/inference/__init__.py

@@ -3,4 +3,12 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
+from seamless_communication.models.inference.ngram_repeat_block_processor import (
+    NGramRepeatBlockProcessor,
+)
+from seamless_communication.models.inference.translator import (
+    BatchedSpeechOutput as BatchedSpeechOutput,
+)
+from seamless_communication.models.inference.translator import Modality as Modality
+from seamless_communication.models.inference.translator import Task as Task
 from seamless_communication.models.inference.translator import Translator as Translator

+ 102 - 70
src/seamless_communication/models/inference/translator.py

@@ -3,30 +3,31 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from dataclasses import dataclass
+from enum import Enum, auto
 from pathlib import Path
-from typing import Callable, Optional, Tuple, Union
+from torch import Tensor
+from typing import Callable, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
+
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
 from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
-from fairseq2.data.text.text_tokenizer import TextTokenizer
+from fairseq2.data.text import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceToTextOutput, SequenceGeneratorOptions
 from fairseq2.memory import MemoryBlock
 from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
-from torch import Tensor
-from enum import Enum, auto
-from seamless_communication.models.inference.ngram_repeat_block_processor import (
-    NGramRepeatBlockProcessor,
-)
+
 
 from seamless_communication.models.unity import (
     UnitTokenizer,
     UnitYGenerator,
     UnitYModel,
+    UnitYNART2UModel,
     UnitYT2UModel,
     load_unity_model,
     load_unity_text_tokenizer,
@@ -49,12 +50,25 @@ class Modality(Enum):
     TEXT = "text"
 
 
+@dataclass
+class BatchedSpeechOutput:
+    units: List[List[int]]
+    """The batched list of generated units."""
+
+    audio_wavs: List[Tensor]
+    """The batched list of audio waveforms."""
+
+    sample_rate: int = 16000
+    """Sample rate of the audio waveforms."""
+
+
 class Translator(nn.Module):
     def __init__(
         self,
         model_name_or_card: Union[str, AssetCard],
         vocoder_name_or_card: Union[str, AssetCard],
         device: Device,
+        text_tokenizer: Optional[TextTokenizer] = None,
         dtype: DataType = torch.float16,
     ):
         super().__init__()
@@ -66,7 +80,12 @@ class Translator(nn.Module):
         )
         assert isinstance(self.model, UnitYModel)
 
-        self.text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
+        if text_tokenizer is None:
+            self.text_tokenizer: TextTokenizer = load_unity_text_tokenizer(
+                model_name_or_card
+            )
+        else:
+            self.text_tokenizer = text_tokenizer
 
         self.unit_tokenizer: Optional[UnitTokenizer] = None
         if self.model.t2u_model is not None:
@@ -112,40 +131,23 @@ class Translator(nn.Module):
         input_modality: Modality,
         output_modality: Modality,
         tgt_lang: str,
-        ngram_filtering: bool = False,
-        text_max_len_a: int = 1,
-        text_max_len_b: int = 200,
-        unit_max_len_a: Optional[int] = None,
-        unit_max_len_b: Optional[int] = None,
+        text_generation_opts: SequenceGeneratorOptions,
+        unit_generation_opts: Optional[SequenceGeneratorOptions],
+        unit_generation_ngram_filtering: bool = False,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
-        if unit_max_len_a is None:
-            # need to adjust this for T2ST since src_len is smaller for text.
-            if input_modality == Modality.TEXT:
-                unit_max_len_a = 25
-            else:
-                unit_max_len_a = 1
-
-        text_opts = SequenceGeneratorOptions(
-            beam_size=5, soft_max_seq_len=(text_max_len_a, text_max_len_b)
-        )
-        unit_opts = SequenceGeneratorOptions(
-            beam_size=5, soft_max_seq_len=(unit_max_len_a, unit_max_len_b or 50)
-        )
+        # We disregard unit generation opts for text-only output or the NAR T2U decoder.
+        if output_modality != Modality.SPEECH or isinstance(
+            model.t2u_model, UnitYNART2UModel
+        ):
+            unit_generation_opts = None
 
-        if ngram_filtering:
-            text_opts.logits_processor = NGramRepeatBlockProcessor(
-                no_repeat_ngram_size=4
-            )
-            unit_opts.logits_processor = NGramRepeatBlockProcessor(
-                no_repeat_ngram_size=4
-            )
         generator = UnitYGenerator(
             model,
             text_tokenizer,
             tgt_lang,
             unit_tokenizer if output_modality == Modality.SPEECH else None,
-            text_opts=text_opts,
-            unit_opts=unit_opts,
+            text_opts=text_generation_opts,
+            unit_opts=unit_generation_opts,
         )
         seqs, padding_mask = get_seqs_and_padding_mask(src)
         return generator(
@@ -153,10 +155,16 @@ class Translator(nn.Module):
             padding_mask,
             input_modality.value,
             output_modality.value,
-            ngram_filtering=ngram_filtering,
+            ngram_filtering=unit_generation_ngram_filtering,
         )
 
-    def get_modalities_from_task(self, task: Task) -> Tuple[Modality, Modality]:
+    @staticmethod
+    def get_modalities_from_task_str(task_str: str) -> Tuple[Modality, Modality]:
+        try:
+            task = Task[task_str.upper()]
+        except KeyError:
+            raise ValueError(f"Unsupported task: {task_str}")
+
         if task == Task.S2ST:
             return Modality.SPEECH, Modality.SPEECH
         # ASR is treated as S2TT with src_lang == tgt_lang
@@ -170,18 +178,20 @@ class Translator(nn.Module):
     @torch.inference_mode()
     def predict(
         self,
-        input: Union[str, Tensor],
+        input: Union[str, Tensor, dict],
         task_str: str,
         tgt_lang: str,
         src_lang: Optional[str] = None,
+        text_generation_opts: SequenceGeneratorOptions = SequenceGeneratorOptions(
+            beam_size=5, soft_max_seq_len=(1, 200)
+        ),
+        unit_generation_opts: Optional[
+            SequenceGeneratorOptions
+        ] = SequenceGeneratorOptions(beam_size=5, soft_max_seq_len=(25, 50)),
         spkr: Optional[int] = -1,
-        ngram_filtering: bool = False,
         sample_rate: int = 16000,
-        text_max_len_a: int = 1,
-        text_max_len_b: int = 200,
-        unit_max_len_a: Optional[int] = None,
-        unit_max_len_b: Optional[int] = None,
-    ) -> Tuple[StringLike, Optional[Tensor], Optional[int]]:
+        unit_generation_ngram_filtering: bool = False,
+    ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
 
@@ -194,22 +204,27 @@ class Translator(nn.Module):
             Target language to decode into.
         :param src_lang:
             Source language of input, only required for T2ST, T2TT tasks.
+        :param text_generation_opts:
+            Text generation hyperparameters for incremental decoding.
+        :param unit_generation_opts:
+            Unit generation hyperparameters for incremental decoding.
         :param spkr:
             Speaker id for vocoder.
+        :param unit_generation_ngram_filtering:
+            If True, removes consecutive repeated ngrams
+            from the decoded unit output.
 
         :returns:
-            - Translated text.
-            - Generated output audio waveform corresponding to the translated text.
-            - Sample rate of output audio waveform.
+            - Batched list of translated text.
+            - Translated BatchedSpeechOutput.
         """
-        try:
-            task = Task[task_str.upper()]
-        except KeyError:
-            raise ValueError(f"Unsupported task: {task_str}")
-
-        input_modality, output_modality = self.get_modalities_from_task(task)
+        input_modality, output_modality = self.get_modalities_from_task_str(task_str)
 
-        if input_modality == Modality.SPEECH:
+        if isinstance(input, dict):
+            assert "seqs" in input
+            assert "seq_lens" in input
+            src = cast(SequenceData, input)
+        elif input_modality == Modality.SPEECH:
             audio = input
             if isinstance(audio, str):
                 with Path(audio).open("rb") as fb:
@@ -235,34 +250,51 @@ class Translator(nn.Module):
             src = self.collate(self.token_encoder(text))
 
         assert isinstance(self.model, UnitYModel)
-        result = self.get_prediction(
+        text_output, unit_output = self.get_prediction(
             self.model,
             self.text_tokenizer,
             self.unit_tokenizer,
             src,
             input_modality,
             output_modality,
-            tgt_lang=tgt_lang,
-            ngram_filtering=ngram_filtering,
-            text_max_len_a=text_max_len_a,
-            text_max_len_b=text_max_len_b,
-            unit_max_len_a=unit_max_len_a,
-            unit_max_len_b=unit_max_len_b,
+            tgt_lang,
+            text_generation_opts,
+            unit_generation_opts,
+            unit_generation_ngram_filtering=unit_generation_ngram_filtering,
         )
 
-        text_out = result[0]
-        unit_out = result[1]
         if output_modality == Modality.TEXT:
-            return text_out.sentences[0], None, None
+            return text_output.sentences, None
         else:
+            assert unit_output is not None
+
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
                 # in the unit sequence. tgt_lang is fed as an argument to the vocoder.
-                units = unit_out.units[:, 1:]
+                units = unit_output.units[:, 1:]
             else:
-                units = unit_out.units
+                units = unit_output.units
 
-            # TODO: batch_size set to 1 for now, implement batching.
-            units = units[0].cpu().numpy().tolist()
-            wav_out = self.vocoder(units, tgt_lang, spkr, dur_prediction=True)
-            return text_out.sentences[0], wav_out, sample_rate
+            audio_wavs = []
+            speech_units = []
+            for i in range(len(unit_output.units)):
+                u = units[i].cpu().numpy().tolist()
+                index_of_first_one = next(
+                    (index for index, value in enumerate(u) if value == 1), len(u)
+                )
+                u = u[:index_of_first_one]
+                speech_units.append(u)
+                # TODO: Implement batched inference for vocoder.
+                translated_audio_wav = self.vocoder(
+                    u, tgt_lang, spkr, dur_prediction=True
+                )
+                audio_wavs.append(translated_audio_wav)
+
+            return (
+                text_output.sentences,
+                BatchedSpeechOutput(
+                    units=speech_units,
+                    audio_wavs=audio_wavs,
+                    sample_rate=sample_rate,
+                ),
+            )

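A minimal sketch of the unit post-processing and the new `BatchedSpeechOutput` container introduced above; the unit values are dummies, and unit id 1 is treated as the EOS-like unit at which the sequence is trimmed, matching the code:

```python
import torch

from seamless_communication.models.inference import BatchedSpeechOutput

# Trim the generated units at the first occurrence of unit id 1, as predict() does.
units = [12, 87, 43, 1, 5, 9]
index_of_first_one = next((i for i, v in enumerate(units) if v == 1), len(units))
trimmed_units = units[:index_of_first_one]  # [12, 87, 43]

# Bundle units and waveforms; here a dummy 1-second silent waveform stands in
# for the vocoder output. sample_rate defaults to 16000.
speech_output = BatchedSpeechOutput(
    units=[trimmed_units],
    audio_wavs=[torch.zeros(16000).unsqueeze(0).unsqueeze(0)],
)
print(speech_output.sample_rate)  # 16000
```
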
+ 3 - 0
src/seamless_communication/models/unity/generator.py

@@ -168,6 +168,9 @@ class UnitYGenerator:
             The type of modality to encode.
         :param output_modality:
             The type of modality to decode.
+        :param ngram_filtering:
+            If True, removes consecutive repeated ngrams
+            from the decoded unit output.
 
         :returns:
             - The output of the text generator.