Browse Source

Introduce expressivity_predict and change pretssel_inference to expressivity_evaluate. (#251)

* Change pretssel_inference to expressivity_evaluate.

* Implement expressivity_predict to run SeamlessExpressive inference.

* Don't hardcode --model_name, --vocoder_name in expressivity_predict.

* Revert addition of --gated-model-dir to streaming/evaluate.
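
Example usage of the two entry points (a minimal sketch; the audio path, model directory, and language code are placeholders, and the flags mirror the README changes below):

```bash
# Single-utterance inference with the new expressivity_predict command
expressivity_predict input.wav --tgt_lang spa \
  --model_name seamless_expressivity --vocoder_name vocoder_pretssel \
  --gated-model-dir /path/to/SeamlessExpressive --output_path out.wav

# Batched evaluation with the renamed expressivity_evaluate command
expressivity_evaluate input.tsv --task s2st --tgt_lang spa \
  --gated-model-dir /path/to/SeamlessExpressive \
  --audio_root_dir "" --output_path tmp/
```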
Kaushik Ram Sadagopan 1 year ago
parent
commit
6ab3787931

+ 19 - 12
README.md

@@ -95,20 +95,10 @@ For running S2TT/ASR natively (without Python) using GGML, please refer to [the
 > [!NOTE]
 > Please check the [section](#seamlessexpressive-models) on how to download the model.
 
-Below is the script for efficient batched inference.
+Here’s an example of using the CLI from the root directory to run inference.
 
 ```bash
-export MODEL_DIR="/path/to/SeamlessExpressive/model"
-export TEST_SET_TSV="input.tsv" # Your dataset in a TSV file, with headers "id", "audio"
-export TGT_LANG="spa" # Target language to translate into; options include "fra", "deu", "eng" ("cmn" and "ita" are experimental)
-export OUTPUT_DIR="tmp/" # Output directory for generated text/unit/waveform
-export TGT_TEXT_COL="tgt_text" # The column in your ${TEST_SET_TSV} holding reference target text used to calculate the BLEU score. You can skip this argument.
-export DFACTOR="1.0" # Duration factor for model inference to scale the predicted duration (preddur=DFACTOR*preddur) at each position, which affects the output speech rate. A greater value means a slower speech rate (defaults to 1.0). See the expressive evaluation README for details on the duration factor we used.
-python src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py \
-  ${TEST_SET_TSV} --gated-model-dir ${MODEL_DIR} --task s2st --tgt_lang ${TGT_LANG}\
-  --audio_root_dir "" --output_path ${OUTPUT_DIR} --ref_field ${TGT_TEXT_COL} \
-  --model_name seamless_expressivity --vocoder_name vocoder_pretssel \
-  --text_unk_blocking True --duration_factor ${DFACTOR}
+expressivity_predict <path_to_input_audio> --tgt_lang <tgt_lang> --model_name seamless_expressivity --vocoder_name vocoder_pretssel --output_path <path_to_save_audio>
 ```
 
 ### SeamlessStreaming and Seamless Inference
@@ -166,6 +156,23 @@ Please check out above [section](#seamlessexpressive-models) on how to acquire `
 ### SeamlessM4T Evaluation
 To reproduce our results, or to evaluate using the same metrics over your own test sets, please check out the [README here](src/seamless_communication/cli/m4t/evaluate).
 ### SeamlessExpressive Evaluation
+
+Below is the script for efficient batched evaluation.
+
+```bash
+export MODEL_DIR="/path/to/SeamlessExpressive/model"
+export TEST_SET_TSV="input.tsv" # Your dataset in a TSV file, with headers "id", "audio"
+export TGT_LANG="spa" # Target language to translate into; options include "fra", "deu", "eng" ("cmn" and "ita" are experimental)
+export OUTPUT_DIR="tmp/" # Output directory for generated text/unit/waveform
+export TGT_TEXT_COL="tgt_text" # The column in your ${TEST_SET_TSV} holding reference target text used to calculate the BLEU score. You can skip this argument.
+export DFACTOR="1.0" # Duration factor for model inference to scale the predicted duration (preddur=DFACTOR*preddur) at each position, which affects the output speech rate. A greater value means a slower speech rate (defaults to 1.0). See the expressive evaluation README for details on the duration factor we used.
+expressivity_evaluate ${TEST_SET_TSV} \
+  --gated-model-dir ${MODEL_DIR} --task s2st --tgt_lang ${TGT_LANG} \
+  --audio_root_dir "" --output_path ${OUTPUT_DIR} --ref_field ${TGT_TEXT_COL} \
+  --model_name seamless_expressivity --vocoder_name vocoder_pretssel \
+  --text_unk_blocking True --duration_factor ${DFACTOR}
+```
+
 Please check out this [README section](docs/expressive/README.md#automatic-evaluation)
 
 ### SeamlessStreaming and Seamless Evaluation

+ 5 - 9
demo/expressive/app.py

@@ -13,7 +13,7 @@ import gradio as gr
 import torch
 import torchaudio
 from fairseq2.assets import InProcAssetMetadataProvider, asset_store
-from fairseq2.data import Collater, SequenceData, VocabularyInfo
+from fairseq2.data import Collater
 from fairseq2.data.audio import (
     AudioDecoder,
     WaveformToFbankConverter,
@@ -23,19 +23,15 @@ from fairseq2.data.audio import (
 from seamless_communication.inference import SequenceGeneratorOptions
 from fairseq2.generation import NGramRepeatBlockProcessor
 from fairseq2.memory import MemoryBlock
-from fairseq2.typing import DataType, Device
 from huggingface_hub import snapshot_download
-from seamless_communication.inference import BatchedSpeechOutput, Translator, SequenceGeneratorOptions
-from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+from seamless_communication.inference import Translator, SequenceGeneratorOptions
 from seamless_communication.models.unity import (
-    UnitTokenizer,
     load_gcmvn_stats,
-    load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from torch.nn import Module
-from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import PretsselGenerator
+from seamless_communication.cli.expressivity.predict.pretssel_generator import PretsselGenerator
 
+from typing import Tuple
 from utils import LANGUAGE_CODE_TO_NAME
 
 DESCRIPTION = """\
@@ -183,7 +179,7 @@ def run(
     input_audio_path: str,
     source_language: str,
     target_language: str,
-) -> tuple[str, str]:
+) -> Tuple[str, str]:
     target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
     source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
 

+ 2 - 0
setup.py

@@ -39,6 +39,8 @@ setup(
             "m4t_finetune=seamless_communication.cli.m4t.finetune.finetune:main",
             "m4t_prepare_dataset=seamless_communication.cli.m4t.finetune.dataset:main",
             "m4t_audio_to_units=seamless_communication.cli.m4t.audio_to_units.audio_to_units:main",
+            "expressivity_evaluate=seamless_communication.cli.expressivity.evaluate.evaluate:main",
+            "expressivity_predict=seamless_communication.cli.expressivity.predict.predict:main",
             "streaming_evaluate=seamless_communication.cli.streaming.evaluate:main",
         ],
     },

+ 5 - 0
src/seamless_communication/cli/expressivity/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.

+ 5 - 0
src/seamless_communication/cli/expressivity/data/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.

+ 3 - 4
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py → src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -22,11 +22,10 @@ from fairseq2.data.audio import (
 )
 from fairseq2.data.text import StrSplitter, read_text
 from fairseq2.typing import DataType, Device
-from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
 from tqdm import tqdm
 
-from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
+from seamless_communication.cli.expressivity.predict.pretssel_generator import (
     PretsselGenerator,
 )
 from seamless_communication.cli.m4t.evaluate.evaluate import (
@@ -121,7 +120,7 @@ def main() -> None:
     )
 
     parser = add_inference_arguments(parser)
-    param = parser.add_argument(
+    parser.add_argument(
         "--gated-model-dir",
         type=Path,
         required=False,
@@ -246,7 +245,7 @@ def main() -> None:
                 prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
                 text_output, unit_output = translator.predict(
                     src,
-                    args.task,
+                    "s2st",
                     args.tgt_lang,
                     src_lang=args.src_lang,
                     text_generation_opts=text_generation_opts,

+ 0 - 3
src/seamless_communication/cli/expressivity/evaluate/run_asr_bleu.py

@@ -5,12 +5,9 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 from fire import Fire
-import pandas as pd
-import csv
 from seamless_communication.cli.eval_utils.compute_metrics import (
     compute_quality_metrics,
 )
-import os
 from fairseq2.typing import Device
 from pathlib import Path
 

+ 5 - 0
src/seamless_communication/cli/expressivity/predict/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.

+ 183 - 0
src/seamless_communication/cli/expressivity/predict/predict.py

@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import argparse
+import logging
+import torch
+import torchaudio
+from pathlib import Path
+
+from fairseq2.data import SequenceData
+from fairseq2.data.audio import WaveformToFbankConverter
+
+from seamless_communication.cli.expressivity.predict.pretssel_generator import (
+    PretsselGenerator,
+)
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import Translator
+from seamless_communication.models.unity import (
+    load_gcmvn_stats,
+    load_unity_unit_tokenizer,
+)
+from seamless_communication.store import add_gated_assets
+
+
+AUDIO_SAMPLE_RATE = 16000
+
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def remove_prosody_tokens_from_text(text: str) -> str:
+    # Filter out prosody tokens; the only ones are emphasis '*' and pause '='.
+    text = text.replace("*", "").replace("=", "")
+    text = " ".join(text.split())
+    return text
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference.")
+    parser.add_argument("input", type=str, help="Audio WAV file path.")
+
+    parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--gated-model-dir",
+        type=Path,
+        required=False,
+        help="SeamlessExpressive model directory.",
+    )
+    parser.add_argument(
+        "--duration_factor",
+        type=float,
+        help="The duration factor for NAR T2U model.",
+        default=1.0,
+    )
+    args = parser.parse_args()
+
+    if not args.tgt_lang or args.output_path is None:
+        raise Exception(
+            "--tgt_lang and --output_path must be provided for SeamlessExpressive inference."
+        )
+
+    if args.gated_model_dir:
+        add_gated_assets(args.gated_model_dir)
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    logger.info(f"Running inference on {device=} with {dtype=}.")
+
+    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
+
+    translator = Translator(
+        args.model_name,
+        vocoder_name_or_card=None,
+        device=device,
+        dtype=dtype,
+    )
+
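+    # The PRETSSEL generator vocodes predicted units into expressive speech;
+    # it uses the unit tokenizer's vocab info.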
+    pretssel_generator = PretsselGenerator(
+        args.vocoder_name,
+        vocab_info=unit_tokenizer.vocab_info,
+        device=device,
+        dtype=dtype,
+    )
+
+    fbank_extractor = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=False,
+        device=device,
+        dtype=dtype,
+    )
+
+    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
+    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
+    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
+
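+    # Load the input audio, resample it to 16 kHz, and transpose to the
+    # channel-last layout expected by the fbank converter.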
+    wav, sample_rate = torchaudio.load(args.input)
+    wav = torchaudio.functional.resample(wav, orig_freq=sample_rate, new_freq=AUDIO_SAMPLE_RATE)
+    wav = wav.transpose(0, 1)
+
+    data = fbank_extractor(
+        {
+            "waveform": wav,
+            "sample_rate": AUDIO_SAMPLE_RATE,
+        }
+    )
+    fbank = data["fbank"]
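+    # Normalize with global CMVN stats for the prosody encoder input, and with
+    # utterance-level mean/std for the translator's fbank input.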
+    gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+    std, mean = torch.std_mean(fbank, dim=0)
+    fbank = fbank.subtract(mean).divide(std)
+
+    src = SequenceData(
+        seqs=fbank.unsqueeze(0),
+        seq_lens=torch.LongTensor([fbank.shape[0]]),
+        is_ragged=False,
+    )
+    src_gcmvn = SequenceData(
+        seqs=gcmvn_fbank.unsqueeze(0),
+        seq_lens=torch.LongTensor([gcmvn_fbank.shape[0]]),
+        is_ragged=False,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    text_output, unit_output = translator.predict(
+        src,
+        "s2st",
+        args.tgt_lang,
+        text_generation_opts=text_generation_opts,
+        unit_generation_opts=unit_generation_opts,
+        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+        duration_factor=args.duration_factor,
+        prosody_encoder_input=src_gcmvn,
+    )
+
+    assert unit_output is not None
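+    # Vocode the predicted units into a waveform, conditioned on the source
+    # utterance's prosody (the gcmvn fbank).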
+    speech_output = pretssel_generator.predict(
+        unit_output.units,
+        tgt_lang=args.tgt_lang,
+        prosody_encoder_input=src_gcmvn,
+    )
+
+    logger.info(f"Saving expressive translated audio in {args.tgt_lang}")
+    torchaudio.save(
+        args.output_path,
+        speech_output.audio_wavs[0][0].to(torch.float32).cpu(),
+        sample_rate=speech_output.sample_rate,
+    )
+
+    text_out = remove_prosody_tokens_from_text(str(text_output[0]))
+
+    logger.info(f"Translated text in {args.tgt_lang}: {text_out}")
+
+
+if __name__ == "__main__":
+    main()

+ 0 - 0
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference_helper.py → src/seamless_communication/cli/expressivity/predict/pretssel_generator.py


+ 2 - 2
src/seamless_communication/cli/streaming/README.md

@@ -39,9 +39,9 @@ streaming_evaluate --task s2st --data-file <path_to_data_tsv_file> --audio-root-
 The Seamless model is a unified model for streaming expressive speech-to-speech translation. Use the `--expressive` arg to run evaluation of this unified model.
 
 ```bash
-streaming_evaluate --task s2st --data-file <path_to_data_tsv_file> --audio-root-dir <path_to_audio_root_directory> --output <path_to_evaluation_output_directory> --tgt-lang <3_letter_lang_code> --expressive
+streaming_evaluate --task s2st --data-file <path_to_data_tsv_file> --audio-root-dir <path_to_audio_root_directory> --output <path_to_evaluation_output_directory> --tgt-lang <3_letter_lang_code> --expressive --gated-model-dir <path_to_vocoder_checkpoints_dir>
 ```
 
-The Seamless model uses `vocoder_pretssel` which is a 24KHz version (`vocoder_pretssel`) by default. In the current version of our paper, we use 16KHz version (`vocoder_pretssel_16khz`) for the evaluation , so in order to reproduce those results please add this arg to the above command: `--vocoder-name vocoder_pretssel_16khz`.
+By default, the Seamless model uses `vocoder_pretssel`, a 24kHz vocoder. In the current version of our paper, we use the 16kHz version (`vocoder_pretssel_16khz`) for evaluation, so to reproduce those results please add this arg to the above command: `--vocoder-name vocoder_pretssel_16khz`.
 
 `vocoder_pretssel` or `vocoder_pretssel_16khz` checkpoints are gated; please check out [this section](/README.md#seamlessexpressive-models) to acquire these checkpoints. Also, make sure to add `--gated-model-dir <path_to_vocoder_checkpoints_dir>`.

+ 3 - 1
src/seamless_communication/cli/streaming/evaluate.py

@@ -8,8 +8,9 @@ import argparse
 import logging
 
 from fairseq2.assets import asset_store, download_manager
+
 from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
-    SeamlessQualityScorer,
+    SeamlessQualityScorer as SeamlessQualityScorer,
 )
 from seamless_communication.streaming.agents.seamless_s2st import SeamlessS2STAgent
 from seamless_communication.streaming.agents.seamless_streaming_s2st import (
@@ -18,6 +19,7 @@ from seamless_communication.streaming.agents.seamless_streaming_s2st import (
 from seamless_communication.streaming.agents.seamless_streaming_s2t import (
     SeamlessStreamingS2TAgent,
 )
+
 from simuleval.cli import evaluate
 
 logging.basicConfig(

+ 1 - 1
src/seamless_communication/models/unit_extractor/unit_extractor.py

@@ -53,7 +53,7 @@ class UnitExtractor(nn.Module):
         assert isinstance(wav2vec2_model, Wav2Vec2Model)
         self.model = Wav2Vec2LayerOutputModel(wav2vec2_model)
         self.decode_audio = AudioDecoder(dtype=torch.float32, device=device)
-        self.collate = Collater(pad_value=2, pad_to_multiple=2)
+        self.collate = Collater(pad_value=1, pad_to_multiple=2)
         self.kmeans_model = KmeansModel(kmeans_uri, device, dtype)
         self.device = device
         self.dtype = dtype

+ 0 - 2
src/seamless_communication/store.py

@@ -6,8 +6,6 @@
 
 from pathlib import Path
 
-import torch
-
 from fairseq2.assets import InProcAssetMetadataProvider, asset_store
 
 

+ 1 - 1
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -144,7 +144,7 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ign
 
     @classmethod
     def add_args(cls, parser: ArgumentParser) -> None:
-        param = parser.add_argument(
+        parser.add_argument(
             "--gated-model-dir",
             type=Path,
             required=False,