Эх сурвалжийг харах

Add support for gated models (#208)

Can Balioglu 1 жил өмнө
parent
commit
e430d557c6

+ 7 - 3
README.md

@@ -79,17 +79,21 @@ Please refer to the [inference README](src/seamless_communication/cli/m4t/predic
 For running S2TT/ASR natively (without Python) using GGML, please refer to unity.cpp section below.
 
 ### SeamlessExpressive Inference
-Below are the script for efficient batched inference.
+> [!NOTE]
+> Please check the [section](#seamlessexpressive-models) on how to download the model.
+
+Below is the script for efficient batched inference.
 
 ```bash
+export MODEL_DIR="/path/to/SeamlessExpressive/model"
 export TEST_SET_TSV="input.tsv" # Your dataset in a TSV file, with headers "id", "audio"
 export TGT_LANG="spa" # Target language to translate into, options including "fra", "deu", "eng" ("cmn" and "ita" are experimental)
 export OUTPUT_DIR="tmp/" # Output directory for generated text/unit/waveform
 export TGT_TEXT_COL="tgt_text" # The column in your ${TEST_SET_TSV} for reference target text to calculate BLEU score. You can skip this argument.
 export DFACTOR="1.0" # Duration factor for model inference to tune predicted duration (preddur=DFACTOR*preddur) per each position which affects output speech rate. Greater value means slower speech rate (default to 1.0). See expressive evaluation README for details on duration factor we used.
 python src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py \
-  ${TEST_SET_TSV} --task s2st --tgt_lang ${TGT_LANG} --audio_root_dir "" \
-  --output_path ${OUTPUT_DIR} --ref_field ${TGT_TEXT_COL} \
+  ${TEST_SET_TSV} --gated-model-dir ${MODEL_DIR} --task s2st --tgt_lang ${TGT_LANG} \
+  --audio_root_dir "" --output_path ${OUTPUT_DIR} --ref_field ${TGT_TEXT_COL} \
   --model_name seamless_expressivity --vocoder_name vocoder_pretssel \
   --unit_generation_beam_size 1 --duration_factor ${DFACTOR}
 ```

+ 11 - 1
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -42,6 +42,7 @@ from seamless_communication.models.unity import (
     load_gcmvn_stats,
     load_unity_unit_tokenizer,
 )
+from seamless_communication.store import add_gated_assets
 
 logging.basicConfig(
     level=logging.INFO,
@@ -109,12 +110,18 @@ def build_data_pipeline(
 
 
 def main() -> None:
-    parser = argparse.ArgumentParser(description="Running PretsselModel inference")
+    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference")
     parser.add_argument(
         "data_file", type=Path, help="Data file (.tsv) to be evaluated."
     )
 
     parser = add_inference_arguments(parser)
+    param = parser.add_argument(
+        "--gated-model-dir",
+        type=Path,
+        required=False,
+        help="SeamlessExpressive model directory.",
+    )
     parser.add_argument(
         "--batch_size",
         type=int,
@@ -147,6 +154,9 @@ def main() -> None:
     )
     args = parser.parse_args()
 
+    if args.gated_model_dir:
+        add_gated_assets(args.gated_model_dir)
+
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
         dtype = torch.float16

+ 1 - 1
src/seamless_communication/cli/streaming/README.md

@@ -44,4 +44,4 @@ streaming_evaluate --task s2st --data-file <path_to_data_tsv_file> --audio-root-
 
 The Seamless model uses `vocoder_pretssel`, which is a 24KHz version, by default. In the current version of our paper, we use the 16KHz version (`vocoder_pretssel_16khz`) for the evaluation, so in order to reproduce those results please add this arg to the above command: `--vocoder-name vocoder_pretssel_16khz`.
 
-Also, to acquire `vocoder_pretssel` or `vocoder_pretssel_16khz` checkpoints, please check out [this section](../../README.md#seamlessexpressive-models).
+The `vocoder_pretssel` and `vocoder_pretssel_16khz` checkpoints are gated; please check out [this section](/README.md#seamlessexpressive-models) to acquire these checkpoints. Also, make sure to add `--gated-model-dir <path_to_vocoder_checkpoints_dir>` to the above command.

+ 32 - 0
src/seamless_communication/store.py

@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from pathlib import Path
+
+import torch
+
+from fairseq2.assets import InProcAssetMetadataProvider, asset_store
+
+
+def add_gated_assets(model_dir: Path) -> None:
+    asset_store.env_resolvers.append(lambda: "gated")
+
+    gated_metadata = [
+        {
+            "name": "seamless_expressivity@gated",
+            "checkpoint": f"/{model_dir}/m2m_expressive_unity.pt",
+        },
+        {
+            "name": "vocoder_pretssel@gated",
+            "checkpoint": f"/{model_dir}/pretssel_melhifigan_wm.pt",
+        },
+        {
+            "name": "vocoder_pretssel_16khz@gated",
+            "checkpoint": f"/{model_dir}/pretssel_melhifigan_wm-16khz.pt",
+        },
+    ]
+
+    asset_store.metadata_providers.append(InProcAssetMetadataProvider(gated_metadata))

+ 11 - 0
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -7,6 +7,7 @@ from __future__ import annotations
 
 import logging
 from argparse import ArgumentParser, Namespace
+from pathlib import Path
 from typing import Any, Dict, List
 
 import torch
@@ -14,6 +15,7 @@ from fairseq2.assets import asset_store
 from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
 from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 from seamless_communication.models.unity import load_gcmvn_stats
+from seamless_communication.store import add_gated_assets
 from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents.actions import ReadAction, WriteAction
@@ -31,6 +33,9 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ign
     def __init__(self, args: Namespace) -> None:
         super().__init__(args)
 
+        if args.gated_model_dir:
+            add_gated_assets(args.gated_model_dir)
+
         logger.info(
             f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
         )
@@ -129,6 +134,12 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ign
 
     @classmethod
     def add_args(cls, parser: ArgumentParser) -> None:
+        param = parser.add_argument(
+            "--gated-model-dir",
+            type=Path,
+            required=False,
+            help="SeamlessExpressive model directory.",
+        )
         parser.add_argument(
             "--vocoder-name",
             type=str,