Browse Source

Seamless Streaming Inference (#169)

* Adding README

* Remove unnecessary try catch block

* Adding ASR task

* Adding  README

* Fix WER scorer

* Change default strip silence behavior

* Fix vocoder sample rate fetching and add disbale standardize audio

* Standardizing audio logic in strip silence

* Addressing comments

* Addressing comments

* miinor README change

* renaming and reorganizing pipeline files
Abinesh Ramakrishnan 1 year ago
parent
commit
732d7bd5a5

+ 15 - 8
src/seamless_communication/cli/eval_utils/compute_metrics.py

@@ -4,23 +4,23 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+import json
 import logging
 import logging
+from pathlib import Path
+from typing import Optional, Tuple, Union
+
 import pandas as pd
 import pandas as pd
 import whisper
 import whisper
-
 from fairseq2.typing import Device
 from fairseq2.typing import Device
 from jiwer import cer, wer
 from jiwer import cer, wer
-from pathlib import Path
 from sacrebleu.metrics.base import Score, Signature
 from sacrebleu.metrics.base import Score, Signature
 from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.chrf import CHRF
 from sacrebleu.metrics.chrf import CHRF
 from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
 from tqdm import tqdm
 from tqdm import tqdm
-from typing import Optional, Tuple, Union
 from whisper import Whisper
 from whisper import Whisper
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
 
 
-
 logging.basicConfig(
 logging.basicConfig(
     level=logging.INFO,
     level=logging.INFO,
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
@@ -190,7 +190,7 @@ def compute_asr_error_rate(
     ref_text_series: pd.Series,
     ref_text_series: pd.Series,
     lang: str,
     lang: str,
     whisper_normalize_text: bool = True,
     whisper_normalize_text: bool = True,
-) -> Tuple[Score, str]:
+) -> Tuple[float, str]:
     """Wraps normalization functions and computes ASR WER/CER score
     """Wraps normalization functions and computes ASR WER/CER score
     Args:
     Args:
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
         hyp_text_series (pd.Series): each line contains s2t model prediction or first pass prediction
@@ -348,17 +348,24 @@ def compute_quality_metrics(
         logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
         logger.info(f"{task} ASR Normalized BLEU:\n{asr_bleu_normalized_json}")
 
 
     if task == "ASR":
     if task == "ASR":
-        _, asr_error_rate_signature = compute_asr_error_rate(
+        asr_error_rate, asr_error_rate_signature = compute_asr_error_rate(
             hyp_text_series=df[pred_text_col_name],
             hyp_text_series=df[pred_text_col_name],
             ref_text_series=df[ref_text_col_name],
             ref_text_series=df[ref_text_col_name],
             lang=tgt_lang,
             lang=tgt_lang,
             whisper_normalize_text=whisper_normalize_text_output,
             whisper_normalize_text=whisper_normalize_text_output,
         )
         )
+        d = {
+            "name": "WER",
+            "score": asr_error_rate,
+            "signature": asr_error_rate_signature,
+        }
+        asr_error_rate_json = json.dumps(d, indent=1, ensure_ascii=False)
+
         filename = "asr_error_rate.json"
         filename = "asr_error_rate.json"
 
 
         with open(output_path / filename, "w") as f:
         with open(output_path / filename, "w") as f:
-            f.write(asr_error_rate_signature)
+            f.write(asr_error_rate_json)
 
 
-        logger.info(f"ASR : {asr_error_rate_signature}")
+        logger.info(f"ASR : {asr_error_rate_json}")
 
 
     return filename
     return filename

+ 45 - 0
src/seamless_communication/cli/streaming/README.md

@@ -0,0 +1,45 @@
+# Evaluating SeamlessStreaming and Seamless models
+SeamlessStreaming is the streaming only model and Seamless is the expressive streaming model.
+
+## Quick start:
+
+Evaluation can be run with the `streaming_evaluate` CLI.
+
+We use the `seamless_streaming_unity` for loading the speech encoder and T2U models, and `seamless_streaming_monotonic_decoder` for loading the text decoder for streaming evaluation. This is already set as defaults for the `streaming_evaluate` CLI, but can be overridden using the `--unity-model-name` and  `--monotonic-decoder-model-name` args if required.
+
+Note that the numbers in the paper use single precision floating point format (fp32) for evaluation by setting `--dtype fp32`.
+
+### S2TT:
+Set the task to `s2tt` for evaluating the speech-to-text translation part of the SeamlessStreaming model.
+
+```bash
+streaming_evaluate --task s2tt --data-file <path_to_data_tsv_file> --audio-root-dir <path_to_audio_root_directory> --output <path_to_evaluation_output_directory> --tgt-lang <3_letter_lang_code>
+```
+
+Note: The `--ref-field` can be used to specify the name of the reference column in the dataset.
+
+### ASR:
+Set the task to `asr` for evaluating the automatic speech recognition part of the SeamlessStreaming model. Make sure to pass the source language as the `--tgt-lang` arg.
+
+```bash
+streaming_evaluate --task s2tt --data-file <path_to_data_tsv_file> --audio-root-dir <path_to_audio_root_directory> --output <path_to_evaluation_output_directory> --tgt-lang <3_letter_source_lang_code> 
+```
+
+### S2ST:
+
+#### SeamlessStreaming:
+
+Set the task to `s2st` for evaluating the speech-to-speech translation part of the SeamlessStreaming model. 
+
+```bash
+streaming_evaluate --task s2st --data-file <path_to_data_tsv_file> --audio-root-dir <path_to_audio_root_directory> --output <path_to_evaluation_output_directory> --tgt-lang <3_letter_lang_code>
+```
+
+#### Seamless:
+The Seamless model is an unified model for streaming expressive speech-to-speech tranlsation. Use the `--expressive` arg for running evaluation of this unified model.
+
+```bash
+streaming_evaluate --task s2st --data-file <path_to_data_tsv_file> --audio-root-dir <path_to_audio_root_directory> --output <path_to_evaluation_output_directory> --tgt-lang <3_letter_lang_code> --expressive
+```
+
+Note: In the current version of our paper, we use vocoder_pretssel_16khz for the evaluation , so in order to reproduce those results please add this arg to the above command: `--vocoder-name vocoder_pretssel_16khz`

+ 26 - 26
src/seamless_communication/cli/streaming/evaluate.py

@@ -5,19 +5,29 @@
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
 import argparse
 import argparse
+import logging
 
 
 from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets import asset_store, download_manager
-from seamless_communication.streaming.agents.mma_m4t_s2st import (
-    MonotonicM4TS2STAgent,
-    SeamlessS2STAgent,
-)
 from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
 from seamless_communication.cli.streaming.scorers.seamless_quality_scorer import (
     SeamlessQualityScorer,
     SeamlessQualityScorer,
 )
 )
-from seamless_communication.streaming.agents.mma_m4t_s2t import MonotonicM4TS2TAgent
+from seamless_communication.streaming.agents.seamless_s2st import SeamlessS2STAgent
+from seamless_communication.streaming.agents.seamless_streaming_s2st import (
+    SeamlessStreamingS2STAgent,
+)
+from seamless_communication.streaming.agents.seamless_streaming_s2t import (
+    SeamlessStreamingS2TAgent,
+)
 from simuleval.evaluator import build_evaluator
 from simuleval.evaluator import build_evaluator
 from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
 from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
 
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
 
 
 def main() -> None:
 def main() -> None:
     parser = argparse.ArgumentParser(
     parser = argparse.ArgumentParser(
@@ -28,7 +38,7 @@ def main() -> None:
 
 
     parser.add_argument(
     parser.add_argument(
         "--task",
         "--task",
-        choices=["s2st", "s2tt"],
+        choices=["s2st", "s2tt", "asr"],
         required=True,
         required=True,
         type=str,
         type=str,
         help="Target language to translate/transcribe into.",
         help="Target language to translate/transcribe into.",
@@ -39,46 +49,33 @@ def main() -> None:
         default=False,
         default=False,
         help="Expressive streaming S2ST inference",
         help="Expressive streaming S2ST inference",
     )
     )
-    parser.add_argument(
-        "--dtype",
-        default="fp16",
-        type=str,
-    )
 
 
     args, _ = parser.parse_known_args()
     args, _ = parser.parse_known_args()
 
 
     model_configs = dict(
     model_configs = dict(
         source_segment_size=320,
         source_segment_size=320,
         device="cuda:0",
         device="cuda:0",
-        dtype=args.dtype,
+        dtype="fp16",
         min_starting_wait_w2vbert=192,
         min_starting_wait_w2vbert=192,
         decision_threshold=0.5,
         decision_threshold=0.5,
         no_early_stop=True,
         no_early_stop=True,
-        max_len_a=1,
-        max_len_b=200,
+        max_len_a=0,
+        max_len_b=100,
     )
     )
 
 
-    if args.dtype == "fp16":
-        model_configs.update(dict(fp16=True))
-
     EVALUATION_SYSTEM_LIST.clear()
     EVALUATION_SYSTEM_LIST.clear()
     eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
     eval_configs = dict(quality_metrics="SEAMLESS_QUALITY_SCORER")
     if args.task == "s2st":
     if args.task == "s2st":
-        model_configs.update(
-            dict(
-                min_unit_chunk_size=50,
-            )
-        )
+        model_configs["min_unit_chunk_size"] = 50
         eval_configs["latency_metrics"] = "StartOffset EndOffset"
         eval_configs["latency_metrics"] = "StartOffset EndOffset"
 
 
         if args.expressive:
         if args.expressive:
             EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
             EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
-            model_configs.update(dict(vocoder_name="vocoder_pretssel"))
         else:
         else:
-            EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2STAgent)
-    elif args.task == "s2tt":
+            EVALUATION_SYSTEM_LIST.append(SeamlessStreamingS2STAgent)
+    elif args.task in ["s2tt", "asr"]:
         assert args.expressive is False, "S2TT inference cannot be expressive."
         assert args.expressive is False, "S2TT inference cannot be expressive."
-        EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2TAgent)
+        EVALUATION_SYSTEM_LIST.append(SeamlessStreamingS2TAgent)
         parser.add_argument(
         parser.add_argument(
             "--unity-model-name",
             "--unity-model-name",
             type=str,
             type=str,
@@ -104,6 +101,9 @@ def main() -> None:
         {**base_config, **model_configs, **eval_configs}, parser
         {**base_config, **model_configs, **eval_configs}, parser
     )
     )
 
 
+    if args.fp16:
+        logger.warn("--fp16 arg will be ignorned, use --dtype instead")
+
     evaluator = build_evaluator(args)
     evaluator = build_evaluator(args)
     evaluator(system)
     evaluator(system)
 
 

+ 15 - 25
src/seamless_communication/cli/streaming/scorers/seamless_quality_scorer.py

@@ -5,23 +5,19 @@
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
 from __future__ import annotations
 from __future__ import annotations
-import pandas
-from fairseq2.typing import Device
-from pathlib import Path
-from typing import Optional
+
 import json
 import json
 from argparse import ArgumentParser, Namespace
 from argparse import ArgumentParser, Namespace
-from typing import Dict
+from pathlib import Path
+from typing import Dict, Optional
 
 
+import pandas
+from fairseq2.typing import Device
+from seamless_communication.cli.eval_utils import compute_quality_metrics
+from simuleval.evaluator.instance import LogInstance
 from simuleval.evaluator.scorers.quality_scorer import (
 from simuleval.evaluator.scorers.quality_scorer import (
-    register_quality_scorer,
     QualityScorer,
     QualityScorer,
-)
-
-from simuleval.evaluator.instance import LogInstance
-
-from seamless_communication.cli.eval_utils import (
-    compute_quality_metrics,
+    register_quality_scorer,
 )
 )
 
 
 
 
@@ -90,19 +86,13 @@ class SeamlessQualityScorer(QualityScorer):  # type: ignore
 
 
     @staticmethod
     @staticmethod
     def add_args(parser: ArgumentParser) -> None:
     def add_args(parser: ArgumentParser) -> None:
-        try:
-            parser.add_argument(
-                "--task", type=str, help="Task to evaluate", required=True
-            )
-            parser.add_argument(
-                "--tgt-lang",
-                type=str,
-                help="Target language to translate/transcribe into.",
-                required=True,
-            )
-        except:
-            pass
-
+        parser.add_argument("--task", type=str, help="Task to evaluate", required=True)
+        parser.add_argument(
+            "--tgt-lang",
+            type=str,
+            help="Target language to translate/transcribe into.",
+            required=True,
+        )
         parser.add_argument(
         parser.add_argument(
             "--whisper-model-name", type=str, help="Whisper model name", default="large"
             "--whisper-model-name", type=str, help="Whisper model name", default="large"
         )
         )

+ 27 - 7
src/seamless_communication/streaming/agents/online_vocoder.py

@@ -5,21 +5,37 @@
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 from __future__ import annotations
 
 
+import logging
 from argparse import ArgumentParser, Namespace
 from argparse import ArgumentParser, Namespace
 from typing import Any, Dict
 from typing import Any, Dict
-import torch
 
 
-from seamless_communication.models.vocoder.vocoder import Vocoder
+import torch
+from seamless_communication.models.vocoder.loader import load_vocoder_model
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
 from simuleval.data.segments import SpeechSegment
 
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
 
 
 class VocoderAgent(TextToSpeechAgent):  # type: ignore
 class VocoderAgent(TextToSpeechAgent):  # type: ignore
-    def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
+    def __init__(self, args: Namespace) -> None:
         super().__init__(args)
         super().__init__(args)
+
+        logger.info(
+            f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
+        )
+        self.vocoder = load_vocoder_model(
+            args.vocoder_name, device=args.device, dtype=args.dtype
+        )
+        self.vocoder.eval()
+
         self.sample_rate = args.sample_rate
         self.sample_rate = args.sample_rate
-        self.vocoder = vocoder
         self.tgt_lang = args.tgt_lang
         self.tgt_lang = args.tgt_lang
         self.speaker_id = args.vocoder_speaker_id
         self.speaker_id = args.vocoder_speaker_id
 
 
@@ -54,6 +70,12 @@ class VocoderAgent(TextToSpeechAgent):  # type: ignore
 
 
     @classmethod
     @classmethod
     def add_args(cls, parser: ArgumentParser) -> None:
     def add_args(cls, parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--vocoder-name",
+            type=str,
+            help="Vocoder name.",
+            default="vocoder_v2",
+        )
         parser.add_argument(
         parser.add_argument(
             "--vocoder-speaker-id",
             "--vocoder-speaker-id",
             type=int,
             type=int,
@@ -64,6 +86,4 @@ class VocoderAgent(TextToSpeechAgent):  # type: ignore
 
 
     @classmethod
     @classmethod
     def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> VocoderAgent:
     def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> VocoderAgent:
-        vocoder = kwargs.get("vocoder", None)
-        assert isinstance(vocoder, Vocoder)
-        return cls(vocoder, args)
+        return cls(args)

+ 31 - 14
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -5,27 +5,46 @@
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 from __future__ import annotations
 
 
+import logging
 from argparse import ArgumentParser, Namespace
 from argparse import ArgumentParser, Namespace
 from typing import Any, Dict, List
 from typing import Any, Dict, List
 
 
 import torch
 import torch
+from fairseq2.assets import asset_store
 from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
 from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
-from seamless_communication.models.generator.vocoder import PretsselVocoder
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 from seamless_communication.models.unity import load_gcmvn_stats
 from seamless_communication.models.unity import load_gcmvn_stats
-from seamless_communication.models.vocoder.vocoder import Vocoder
 from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
 from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
 from simuleval.data.segments import SpeechSegment
 
 
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
 
 
 class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ignore
 class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ignore
-    def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
+    def __init__(self, args: Namespace) -> None:
         super().__init__(args)
         super().__init__(args)
-        self.vocoder = vocoder
+
+        logger.info(
+            f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
+        )
+        assert "pretssel" in args.vocoder_name
+        self.vocoder = load_pretssel_vocoder_model(
+            args.vocoder_name, device=args.device, dtype=args.dtype
+        )
+        self.vocoder.eval()
+
+        vocoder_model_card = asset_store.retrieve_card(args.vocoder_name)
+        self.vocoder_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+
         self.upstream_idx = args.upstream_idx
         self.upstream_idx = args.upstream_idx
         self.sample_rate = args.sample_rate  # input sample rate
         self.sample_rate = args.sample_rate  # input sample rate
-        self.vocoder_sample_rate = args.vocoder_sample_rate  # output sample rate
         self.tgt_lang = args.tgt_lang
         self.tgt_lang = args.tgt_lang
         self.convert_to_fbank = WaveformToFbankConverter(
         self.convert_to_fbank = WaveformToFbankConverter(
             num_mel_bins=80,
             num_mel_bins=80,
@@ -110,23 +129,21 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ign
 
 
     @classmethod
     @classmethod
     def add_args(cls, parser: ArgumentParser) -> None:
     def add_args(cls, parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--vocoder-name",
+            type=str,
+            help="Vocoder name.",
+            default="vocoder_pretssel",
+        )
         parser.add_argument(
         parser.add_argument(
             "--upstream-idx",
             "--upstream-idx",
             type=int,
             type=int,
             default=0,
             default=0,
             help="index of encoder states where states.source contains input audio",
             help="index of encoder states where states.source contains input audio",
         )
         )
-        parser.add_argument(
-            "--vocoder-sample-rate",
-            type=int,
-            default=16000,
-            help="sample rate out of the vocoder",
-        )
 
 
     @classmethod
     @classmethod
     def from_args(
     def from_args(
         cls, args: Namespace, **kwargs: Dict[str, Any]
         cls, args: Namespace, **kwargs: Dict[str, Any]
     ) -> PretsselVocoderAgent:
     ) -> PretsselVocoderAgent:
-        vocoder = kwargs.get("vocoder", None)
-        assert isinstance(vocoder, PretsselVocoder)
-        return cls(vocoder, args)
+        return cls(args)

+ 5 - 30
src/seamless_communication/streaming/agents/mma_m4t_s2st.py → src/seamless_communication/streaming/agents/seamless_streaming_s2st.py

@@ -4,6 +4,7 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+from seamless_communication.streaming.agents.detokenizer import UnitYDetokenizerAgent
 from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
 from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
     OfflineWav2VecBertEncoderAgent,
     OfflineWav2VecBertEncoderAgent,
 )
 )
@@ -16,19 +17,15 @@ from seamless_communication.streaming.agents.online_text_decoder import (
 from seamless_communication.streaming.agents.online_unit_decoder import (
 from seamless_communication.streaming.agents.online_unit_decoder import (
     NARUnitYUnitDecoderAgent,
     NARUnitYUnitDecoderAgent,
 )
 )
-from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.online_vocoder import VocoderAgent
 from seamless_communication.streaming.agents.online_vocoder import VocoderAgent
-from seamless_communication.streaming.agents.pretssel_vocoder import PretsselVocoderAgent
-
-from seamless_communication.streaming.agents.detokenizer import UnitYDetokenizerAgent
+from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.unity_pipeline import (
 from seamless_communication.streaming.agents.unity_pipeline import (
     UnitYAgentPipeline,
     UnitYAgentPipeline,
     UnitYAgentTreePipeline,
     UnitYAgentTreePipeline,
 )
 )
-from simuleval.utils import entrypoint
 
 
 
 
-class MonotonicM4TS2STAgent(UnitYAgentPipeline):
+class SeamlessStreamingS2STAgent(UnitYAgentPipeline):
     pipeline = [
     pipeline = [
         OnlineFeatureExtractorAgent,
         OnlineFeatureExtractorAgent,
         OfflineWav2VecBertEncoderAgent,
         OfflineWav2VecBertEncoderAgent,
@@ -38,17 +35,7 @@ class MonotonicM4TS2STAgent(UnitYAgentPipeline):
     ]
     ]
 
 
 
 
-class SeamlessS2STAgent(UnitYAgentPipeline):
-    pipeline = [
-        OnlineFeatureExtractorAgent,
-        OfflineWav2VecBertEncoderAgent,
-        UnitYMMATextDecoderAgent,
-        NARUnitYUnitDecoderAgent,
-        PretsselVocoderAgent,
-    ]
-
-
-class MonotonicM4TS2STVADAgent(UnitYAgentPipeline):
+class SeamlessStreamingS2STVADAgent(UnitYAgentPipeline):
     pipeline = [
     pipeline = [
         SileroVADAgent,
         SileroVADAgent,
         OnlineFeatureExtractorAgent,
         OnlineFeatureExtractorAgent,
@@ -59,7 +46,7 @@ class MonotonicM4TS2STVADAgent(UnitYAgentPipeline):
     ]
     ]
 
 
 
 
-class MonotonicM4TS2STJointVADAgent(UnitYAgentTreePipeline):
+class SeamlessStreamingS2STJointVADAgent(UnitYAgentTreePipeline):
     pipeline = {
     pipeline = {
         SileroVADAgent: [OnlineFeatureExtractorAgent],
         SileroVADAgent: [OnlineFeatureExtractorAgent],
         OnlineFeatureExtractorAgent: [OfflineWav2VecBertEncoderAgent],
         OnlineFeatureExtractorAgent: [OfflineWav2VecBertEncoderAgent],
@@ -69,15 +56,3 @@ class MonotonicM4TS2STJointVADAgent(UnitYAgentTreePipeline):
         NARUnitYUnitDecoderAgent: [VocoderAgent],
         NARUnitYUnitDecoderAgent: [VocoderAgent],
         VocoderAgent: [],
         VocoderAgent: [],
     }
     }
-
-
-class SeamlessS2STJointVADAgent(UnitYAgentTreePipeline):
-    pipeline = {
-        SileroVADAgent: [OnlineFeatureExtractorAgent],
-        OnlineFeatureExtractorAgent: [OfflineWav2VecBertEncoderAgent],
-        OfflineWav2VecBertEncoderAgent: [UnitYMMATextDecoderAgent],
-        UnitYMMATextDecoderAgent: [UnitYDetokenizerAgent, NARUnitYUnitDecoderAgent],
-        UnitYDetokenizerAgent: [],
-        NARUnitYUnitDecoderAgent: [PretsselVocoderAgent],
-        PretsselVocoderAgent: [],
-    }

+ 3 - 6
src/seamless_communication/streaming/agents/mma_m4t_s2t.py → src/seamless_communication/streaming/agents/seamless_streaming_s2t.py

@@ -16,11 +16,9 @@ from seamless_communication.streaming.agents.online_text_decoder import (
 )
 )
 from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
 from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
-from simuleval.utils import entrypoint
 
 
 
 
-@entrypoint
-class MonotonicM4TS2TDetokAgent(UnitYAgentPipeline):
+class SeamlessStreamingS2TDetokAgent(UnitYAgentPipeline):
     pipeline = [
     pipeline = [
         OnlineFeatureExtractorAgent,
         OnlineFeatureExtractorAgent,
         OfflineWav2VecBertEncoderAgent,
         OfflineWav2VecBertEncoderAgent,
@@ -29,8 +27,7 @@ class MonotonicM4TS2TDetokAgent(UnitYAgentPipeline):
     ]
     ]
 
 
 
 
-@entrypoint
-class MonotonicM4TS2TAgent(UnitYAgentPipeline):
+class SeamlessStreamingS2TAgent(UnitYAgentPipeline):
     pipeline = [
     pipeline = [
         OnlineFeatureExtractorAgent,
         OnlineFeatureExtractorAgent,
         OfflineWav2VecBertEncoderAgent,
         OfflineWav2VecBertEncoderAgent,
@@ -38,7 +35,7 @@ class MonotonicM4TS2TAgent(UnitYAgentPipeline):
     ]
     ]
 
 
 
 
-class MonotonicM4TS2TVADAgent(UnitYAgentPipeline):
+class SeamlessStreamingS2TVADAgent(UnitYAgentPipeline):
     pipeline = [
     pipeline = [
         SileroVADAgent,
         SileroVADAgent,
         OnlineFeatureExtractorAgent,
         OnlineFeatureExtractorAgent,

+ 1 - 23
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -74,12 +74,7 @@ class UnitYPipelineMixin:
             help="Monotonic decoder model name.",
             help="Monotonic decoder model name.",
             default="seamless_streaming_monotonic_decoder",
             default="seamless_streaming_monotonic_decoder",
         )
         )
-        parser.add_argument(
-            "--vocoder-name",
-            type=str,
-            help="Vocoder name.",
-            default="vocoder_v2",
-        )
+
         parser.add_argument(
         parser.add_argument(
             "--sample-rate",
             "--sample-rate",
             default=16000,
             default=16000,
@@ -147,22 +142,6 @@ class UnitYPipelineMixin:
         )
         )
         monotonic_decoder_model.eval()
         monotonic_decoder_model.eval()
 
 
-        vocoder: Optional[Union[PretsselVocoder, Vocoder]] = None
-        if args.vocoder_name is not None and output_modality == Modality.SPEECH:
-            logger.info(
-                f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
-            )
-            if "pretssel" in args.vocoder_name:
-                vocoder = load_pretssel_vocoder_model(
-                    args.vocoder_name, device=args.device, dtype=args.dtype
-                )
-            else:
-                vocoder = load_vocoder_model(
-                    args.vocoder_name, device=args.device, dtype=args.dtype
-                )
-            assert vocoder is not None
-            vocoder.eval()
-
         return {
         return {
             "unity_model": unity_model,
             "unity_model": unity_model,
             "unity_config": unity_config,
             "unity_config": unity_config,
@@ -170,7 +149,6 @@ class UnitYPipelineMixin:
             "monotonic_decoder_config": monotonic_decoder_config,
             "monotonic_decoder_config": monotonic_decoder_config,
             "text_tokenizer": text_tokenizer,
             "text_tokenizer": text_tokenizer,
             "unit_tokenizer": unit_tokenizer,
             "unit_tokenizer": unit_tokenizer,
-            "vocoder": vocoder,
         }
         }
 
 
 
 

+ 35 - 15
src/seamless_communication/streaming/dataloaders/s2tt.py

@@ -55,7 +55,13 @@ class SileroVADSilenceRemover:
             onnx=False,
             onnx=False,
         )
         )
 
 
-    def __call__(self, sample_list: List[float]) -> List[float]:
+    def __call__(self, sample: torch.Tensor, is_standardized: bool) -> List[float]:
+        if not is_standardized:
+            # Standardizing here just for getting silence boundaries
+            standarized_sample_list = F.layer_norm(sample, sample.shape).tolist()
+        else:
+            standarized_sample_list = sample.tolist()
+
         (
         (
             get_speech_timestamps,
             get_speech_timestamps,
             save_audio,
             save_audio,
@@ -64,8 +70,10 @@ class SileroVADSilenceRemover:
             collect_chunks,
             collect_chunks,
         ) = self.utils
         ) = self.utils
         speech_timestamps = get_speech_timestamps(
         speech_timestamps = get_speech_timestamps(
-            sample_list, self.model, sampling_rate=self.sample_rate
+            standarized_sample_list, self.model, sampling_rate=self.sample_rate
         )
         )
+
+        sample_list: List[float] = sample.tolist()
         if len(speech_timestamps) == 0:
         if len(speech_timestamps) == 0:
             return sample_list
             return sample_list
         speech_start_time = speech_timestamps[0]["start"]
         speech_start_time = speech_timestamps[0]["start"]
@@ -75,7 +83,9 @@ class SileroVADSilenceRemover:
 
 
 @register_dataloader("fairseq2_s2tt")
 @register_dataloader("fairseq2_s2tt")
 class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader):  # type: ignore
 class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader):  # type: ignore
-    def __init__(self, data_pipeline: DataPipeline, args: Namespace) -> None:
+    def __init__(
+        self, data_pipeline: DataPipeline, is_standardized: bool, args: Namespace
+    ) -> None:
         self.args = args
         self.args = args
         self.data_file: Path = Path(getattr(self.args, "data_file", ""))
         self.data_file: Path = Path(getattr(self.args, "data_file", ""))
         if not self.data_file.exists():
         if not self.data_file.exists():
@@ -83,10 +93,12 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
         self.start_index: int = getattr(self.args, "start_index", 0)
         self.start_index: int = getattr(self.args, "start_index", 0)
         self.end_index: int = getattr(self.args, "end_index", -1)
         self.end_index: int = getattr(self.args, "end_index", -1)
         self.data_pipeline = data_pipeline
         self.data_pipeline = data_pipeline
+        self.is_standardized = is_standardized
         self.data_itr = iter(self.data_pipeline)
         self.data_itr = iter(self.data_pipeline)
         self.cur_index = self.start_index - 1
         self.cur_index = self.start_index - 1
+        self.no_strip_silence = self.args.no_strip_silence
         self.silence_remover = None
         self.silence_remover = None
-        if self.args.strip_silence:
+        if not self.no_strip_silence:
             logger.warn(
             logger.warn(
                 "Stripping silence in the beginning and the end of audio with SileroVAD."
                 "Stripping silence in the beginning and the end of audio with SileroVAD."
             )
             )
@@ -113,12 +125,12 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
         return self.end_index - self.start_index
         return self.end_index - self.start_index
 
 
     def get_source(self, index: Optional[int] = None) -> List[float]:
     def get_source(self, index: Optional[int] = None) -> List[float]:
-        source: List[float] = (
-            self.item["audio"]["data"]["waveform"]["seqs"].squeeze().tolist()
-        )
+        squeezed_item = self.item["audio"]["data"]["waveform"]["seqs"].squeeze()
 
 
-        if self.silence_remover is not None:
-            source = self.silence_remover(source)
+        if not self.no_strip_silence and self.silence_remover is not None:
+            source = self.silence_remover(squeezed_item, self.is_standardized)
+        else:
+            source = squeezed_item.tolist()
 
 
         return source
         return source
 
 
@@ -168,10 +180,13 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
             selector="audio.data",
             selector="audio.data",
         )
         )
 
 
-        pipeline_builder.map(
-            lambda x: F.layer_norm(x, x.shape),
-            selector="audio.data.waveform",
-        )
+        is_standardized = False
+        if args.standardize_audio:
+            pipeline_builder.map(
+                lambda x: F.layer_norm(x, x.shape),
+                selector="audio.data.waveform",
+            )
+            is_standardized = True
 
 
         collate = Collater(pad_value=0, pad_to_multiple=1)
         collate = Collater(pad_value=0, pad_to_multiple=1)
 
 
@@ -181,7 +196,7 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
 
 
         data_pipeline = pipeline_builder.and_return()
         data_pipeline = pipeline_builder.and_return()
 
 
-        return cls(data_pipeline, args)
+        return cls(data_pipeline, is_standardized, args)
 
 
     @staticmethod
     @staticmethod
     def add_args(parser: ArgumentParser) -> None:
     def add_args(parser: ArgumentParser) -> None:
@@ -222,8 +237,13 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
             help="Output directory. Required if using iterable dataloader.",
             help="Output directory. Required if using iterable dataloader.",
         )
         )
         parser.add_argument(
         parser.add_argument(
-            "--strip-silence",
+            "--no-strip-silence",
             action="store_true",
             action="store_true",
             default=False,
             default=False,
             help="Strip silence in the beginning and the end of audio.",
             help="Strip silence in the beginning and the end of audio.",
         )
         )
+        parser.add_argument(
+            "--standardize-audio",
+            action="store_true",
+            help="Standardize audio.",
+        )