Browse Source

Streaming Evaluate CLI (#150)

* streaming cli improvements

* Streaming Evaluate CLI

* bump simuleval version requirement

* rebase and update

* revert simuleval version bump

* Remove old scripts

* revert simuleval dependency version change

* mypy issue in pretssel_vocoder

* setting cli defaults

* Logging vocoder load

* change max_len_a to 1
Abinesh Ramakrishnan 1 year ago
parent
commit
0d2c128b4a

+ 1 - 0
setup.py

@@ -37,6 +37,7 @@ setup(
             "m4t_finetune=seamless_communication.cli.m4t.finetune.finetune:main",
             "m4t_finetune=seamless_communication.cli.m4t.finetune.finetune:main",
             "m4t_prepare_dataset=seamless_communication.cli.m4t.finetune.dataset:main",
             "m4t_prepare_dataset=seamless_communication.cli.m4t.finetune.dataset:main",
             "m4t_audio_to_units=seamless_communication.cli.m4t.audio_to_units.audio_to_units:main",
             "m4t_audio_to_units=seamless_communication.cli.m4t.audio_to_units.audio_to_units:main",
+            "streaming_evaluate=seamless_communication.cli.streaming.evaluate:main",
         ],
         ],
     },
     },
     include_package_data=True,
     include_package_data=True,

+ 100 - 24
src/seamless_communication/cli/streaming/evaluate.py

@@ -4,45 +4,121 @@
 # This source code is licensed under the license found in the
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 # LICENSE file in the root directory of this source tree.
 
 
+import argparse
+
+from fairseq2.assets import asset_store, download_manager
+from seamless_communication.cli.eval_utils import get_tokenizer
 from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
 from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
     SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
     SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
 )
 )
-from seamless_communication.streaming.agents.mma_m4t_s2st import MonotonicM4TS2STAgent
-from simuleval.cli import evaluate
+from seamless_communication.streaming.agents.mma_m4t_s2st import (
+    MonotonicM4TS2STAgent,
+    SeamlessS2STAgent,
+)
+from seamless_communication.streaming.agents.mma_m4t_s2t import MonotonicM4TS2TAgent
+from simuleval.evaluator import build_evaluator
+from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
 
 
 
 
-if __name__ == "__main__":
-    tgt_lang = "eng"
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        add_help=False,
+        description="Streaming evaluation of Seamless UnitY models",
+        conflict_handler="resolve",
+    )
 
 
-    data_configs = dict(
-        dataloader="fairseq2_s2tt",
-        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
-        data_file="/large_experiments/seamless/ust/annaysun/datasets/s2ut_pt/x2t_v2/dev_fleurs_spa-eng.tsv",
-        tgt_lang=tgt_lang,
-        audio_root_dir="/large_experiments/seamless/ust/data/audio_zips",
-        end_index=10,
+    parser.add_argument(
+        "--task",
+        choices=["s2st", "s2tt"],
+        required=True,
+        type=str,
+        help="Task to evaluate: speech-to-speech (s2st) or speech-to-text (s2tt) translation.",
+    )
+    parser.add_argument(
+        "--expressive",
+        action="store_true",
+        default=False,
+        help="Expressive streaming S2ST inference",
     )
     )
+    parser.add_argument(
+        "--dtype",
+        default="fp16",
+        type=str,
+    )
+
+    args, _ = parser.parse_known_args()
 
 
     model_configs = dict(
     model_configs = dict(
-        agent_class="seamless_communication.streaming.agents.mma_m4t_s2st.MonotonicM4TS2STAgent",
         source_segment_size=320,
         source_segment_size=320,
-        task="s2st",
         device="cuda:0",
         device="cuda:0",
-        dtype="fp16",
+        dtype=args.dtype,
         min_starting_wait_w2vbert=192,
         min_starting_wait_w2vbert=192,
         decision_threshold=0.5,
         decision_threshold=0.5,
-        min_unit_chunk_size=50,
         no_early_stop=True,
         no_early_stop=True,
-        max_len_a=0,
-        max_len_b=100,
+        max_len_a=1,
+        max_len_b=200,
     )
     )
 
 
-    eval_configs = dict(
-        output=f"MonotonicM4TS2STAgent_spa-eng_debug",
-        quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
-        latency_metrics="StartOffset EndOffset",
-        whisper_model_size="large-v2",
-        normalize_asr_bleu_references=True,
+    if args.dtype == "fp16":
+        model_configs.update(dict(fp16=True))
+
+    EVALUATION_SYSTEM_LIST.clear()
+    if args.task == "s2st":
+        model_configs.update(
+            dict(
+                min_unit_chunk_size=50,
+            )
+        )
+        eval_configs = dict(
+            quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
+            latency_metrics="StartOffset EndOffset",
+            whisper_model_size="large-v2",
+            normalize_asr_bleu_references=True,
+        )
+        if args.expressive:
+            EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
+            model_configs.update(dict(vocoder_name="vocoder_pretssel"))
+        else:
+            EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2STAgent)
+    elif args.task == "s2tt":
+        EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2TAgent)
+        parser.add_argument(
+            "--unity-model-name",
+            type=str,
+            help="Unity model name.",
+            default="seamless_streaming_unity",
+        )
+        parser.add_argument(
+            "--tgt-lang",
+            default="eng",
+            type=str,
+            help="Target language to translate/transcribe into.",
+        )
+        args, _ = parser.parse_known_args()
+        asset_card = asset_store.retrieve_card(name=args.unity_model_name)
+        tokenizer_uri = asset_card.field("tokenizer").as_uri()
+        tokenizer_path = download_manager.download_tokenizer(
+            tokenizer_uri, asset_card.name, force=False, progress=True
+        )
+        eval_configs = dict(
+            sacrebleu_tokenizer=get_tokenizer(args.tgt_lang),
+            eval_latency_unit="spm",
+            eval_latency_spm_model=tokenizer_path,
+            latency_metrics="AL LAAL",
+        )
+
+    base_config = dict(
+        dataloader="fairseq2_s2tt",
+        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
     )
     )
 
 
-    evaluate(MonotonicM4TS2STAgent, {**data_configs, **model_configs, **eval_configs})
+    system, args = build_system_args(
+        {**base_config, **model_configs, **eval_configs}, parser
+    )
+
+    evaluator = build_evaluator(args)
+    evaluator(system)
+
+
+if __name__ == "__main__":
+    main()

+ 0 - 49
src/seamless_communication/cli/streaming/evaluate_pretssel_vocoder.py

@@ -1,49 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
-    SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
-)
-from seamless_communication.streaming.agents.mma_m4t_s2st import SeamlessS2STAgent
-from simuleval.cli import evaluate
-
-
-if __name__ == "__main__":
-    tgt_lang = "eng"
-
-    data_configs = dict(
-        dataloader="fairseq2_s2tt",
-        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
-        data_file="/large_experiments/seamless/ust/annaysun/datasets/s2ut_pt/x2t_v2/dev_fleurs_spa-eng.tsv",
-        tgt_lang=tgt_lang,
-        audio_root_dir="/large_experiments/seamless/ust/data/audio_zips",
-        end_index=10,
-    )
-
-    model_configs = dict(
-        vocoder_name="vocoder_pretssel_16khz",
-        agent_class="seamless_communication.streaming.agents.mma_m4t_s2st.SeamlessS2STAgent",
-        source_segment_size=320,
-        task="s2st",
-        device="cuda:0",
-        dtype="fp16",
-        min_starting_wait_w2vbert=192,
-        decision_threshold=0.5,
-        min_unit_chunk_size=50,
-        no_early_stop=True,
-        max_len_a=0,
-        max_len_b=100,
-    )
-
-    eval_configs = dict(
-        output=f"SeamlessS2STAgent_spa-eng_debug",
-        quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
-        latency_metrics="StartOffset EndOffset",
-        whisper_model_size="large-v2",
-        normalize_asr_bleu_references=True,
-    )
-
-    evaluate(SeamlessS2STAgent, {**data_configs, **model_configs, **eval_configs})

+ 1 - 1
src/seamless_communication/cli/streaming/scorers/seamless_whisper_asr_bleu.py

@@ -9,7 +9,7 @@ from argparse import ArgumentParser, Namespace
 from typing import Dict, List
 from typing import Dict, List
 
 
 from sacrebleu.metrics.bleu import BLEU
 from sacrebleu.metrics.bleu import BLEU
-from seamless_communication.cli.eval_utils import get_tokenizer, LANG3_LANG2
+from seamless_communication.cli.eval_utils import LANG3_LANG2, get_tokenizer
 from simuleval.evaluator.instance import LogInstance
 from simuleval.evaluator.instance import LogInstance
 from simuleval.evaluator.scorers.quality_scorer import (
 from simuleval.evaluator.scorers.quality_scorer import (
     WhisperASRSacreBLEUScorer,
     WhisperASRSacreBLEUScorer,

+ 1 - 1
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -125,8 +125,8 @@ class OnlineTextDecoderAgent(GenericAgent):  # type: ignore
         )
         )
         parser.add_argument(
         parser.add_argument(
             "--tgt-lang",
             "--tgt-lang",
+            default="eng",
             type=str,
             type=str,
-            default=None,
         )
         )
 
 
     def policy(self, states: DecoderAgentStates) -> Action:
     def policy(self, states: DecoderAgentStates) -> Action:

+ 22 - 15
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -6,20 +6,20 @@
 from __future__ import annotations
 from __future__ import annotations
 
 
 from argparse import ArgumentParser, Namespace
 from argparse import ArgumentParser, Namespace
-import torch
-from typing import Any, Dict
+from typing import Any, Dict, List
 
 
-from fairseq2.data.audio import WaveformToFbankConverter
+import torch
+from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
+from seamless_communication.models.generator.vocoder import PretsselVocoder
 from seamless_communication.models.unity import load_gcmvn_stats
 from seamless_communication.models.unity import load_gcmvn_stats
 from seamless_communication.models.vocoder.vocoder import Vocoder
 from seamless_communication.models.vocoder.vocoder import Vocoder
-from seamless_communication.models.generator.vocoder import PretsselVocoder
 from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
 from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
 from simuleval.data.segments import SpeechSegment
 
 
 
 
-class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
+class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ignore
     def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
     def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
         super().__init__(args)
         super().__init__(args)
         self.vocoder = vocoder
         self.vocoder = vocoder
@@ -36,13 +36,15 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
             dtype=args.dtype,
             dtype=args.dtype,
         )
         )
 
 
-
         _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
         _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
-        self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=args.device, dtype=args.dtype)
+        self.gcmvn_mean = torch.tensor(
+            _gcmvn_mean, device=args.device, dtype=args.dtype
+        )
         self.gcmvn_std = torch.tensor(_gcmvn_std, device=args.device, dtype=args.dtype)
         self.gcmvn_std = torch.tensor(_gcmvn_std, device=args.device, dtype=args.dtype)
 
 
     def gcmvn_normalize(self, seqs: torch.Tensor) -> torch.Tensor:
     def gcmvn_normalize(self, seqs: torch.Tensor) -> torch.Tensor:
-        return seqs.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+        result: torch.Tensor = seqs.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+        return result
 
 
     @torch.inference_mode()
     @torch.inference_mode()
     def policy(self, states: AgentStates) -> WriteAction:
     def policy(self, states: AgentStates) -> WriteAction:
@@ -66,15 +68,18 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
 
 
         duration *= 2
         duration *= 2
 
 
-        if type(states.upstream_states[self.upstream_idx].source) == list:
-            source = sum(states.upstream_states[self.upstream_idx].source, [])
+        if isinstance(states.upstream_states[self.upstream_idx].source, list):
+            source: List[float] = sum(
+                states.upstream_states[self.upstream_idx].source, []
+            )
         else:
         else:
             source = states.upstream_states[self.upstream_idx].source
             source = states.upstream_states[self.upstream_idx].source
 
 
-        audio_dict = {
-            "waveform": torch.tensor(source, dtype=torch.float32, device=self.device).unsqueeze(1),
+        audio_dict: WaveformToFbankInput = {
+            "waveform": torch.tensor(
+                source, dtype=torch.float32, device=self.device
+            ).unsqueeze(1),
             "sample_rate": self.sample_rate,
             "sample_rate": self.sample_rate,
-            "format": -1,
         }
         }
 
 
         feats = self.convert_to_fbank(audio_dict)["fbank"]
         feats = self.convert_to_fbank(audio_dict)["fbank"]
@@ -115,11 +120,13 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
             "--vocoder-sample-rate",
             "--vocoder-sample-rate",
             type=int,
             type=int,
             default=16000,
             default=16000,
-            help="sample rate out of the vocoder"
+            help="sample rate out of the vocoder",
         )
         )
 
 
     @classmethod
     @classmethod
-    def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> PretsselVocoderAgent:
+    def from_args(
+        cls, args: Namespace, **kwargs: Dict[str, Any]
+    ) -> PretsselVocoderAgent:
         vocoder = kwargs.get("vocoder", None)
         vocoder = kwargs.get("vocoder", None)
         assert isinstance(vocoder, PretsselVocoder)
         assert isinstance(vocoder, PretsselVocoder)
         return cls(vocoder, args)
         return cls(vocoder, args)

+ 14 - 4
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -7,11 +7,13 @@ from __future__ import annotations
 
 
 import logging
 import logging
 from argparse import ArgumentParser, Namespace
 from argparse import ArgumentParser, Namespace
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 
 import torch
 import torch
 from fairseq2.assets import asset_store
 from fairseq2.assets import asset_store
 from seamless_communication.inference.translator import Modality, Translator
 from seamless_communication.inference.translator import Modality, Translator
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+from seamless_communication.models.generator.vocoder import PretsselVocoder
 from seamless_communication.models.monotonic_decoder import (
 from seamless_communication.models.monotonic_decoder import (
     load_monotonic_decoder_config,
     load_monotonic_decoder_config,
     load_monotonic_decoder_model,
     load_monotonic_decoder_model,
@@ -23,7 +25,7 @@ from seamless_communication.models.unity import (
     load_unity_unit_tokenizer,
     load_unity_unit_tokenizer,
 )
 )
 from seamless_communication.models.vocoder.loader import load_vocoder_model
 from seamless_communication.models.vocoder.loader import load_vocoder_model
-from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+from seamless_communication.models.vocoder.vocoder import Vocoder
 from seamless_communication.streaming.agents.common import (
 from seamless_communication.streaming.agents.common import (
     AgentStates,
     AgentStates,
     EarlyStoppingMixin,
     EarlyStoppingMixin,
@@ -85,8 +87,13 @@ class UnitYPipelineMixin:
         )
         )
         parser.add_argument(
         parser.add_argument(
             "--dtype",
             "--dtype",
+            choices=["fp16", "fp32"],
             default="fp16",
             default="fp16",
             type=str,
             type=str,
+            help=(
+                "Choose between half-precision (fp16) and single precision (fp32) floating point formats."
+                + " Prefer this over the fp16 flag."
+            ),
         )
         )
 
 
     @classmethod
     @classmethod
@@ -140,8 +147,11 @@ class UnitYPipelineMixin:
         )
         )
         monotonic_decoder_model.eval()
         monotonic_decoder_model.eval()
 
 
-        vocoder = None
+        vocoder: Optional[Union[PretsselVocoder, Vocoder]] = None
         if args.vocoder_name is not None and output_modality == Modality.SPEECH:
         if args.vocoder_name is not None and output_modality == Modality.SPEECH:
+            logger.info(
+                f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
+            )
             if "pretssel" in args.vocoder_name:
             if "pretssel" in args.vocoder_name:
                 vocoder = load_pretssel_vocoder_model(
                 vocoder = load_pretssel_vocoder_model(
                     args.vocoder_name, device=args.device, dtype=args.dtype
                     args.vocoder_name, device=args.device, dtype=args.dtype
@@ -150,7 +160,7 @@ class UnitYPipelineMixin:
                 vocoder = load_vocoder_model(
                 vocoder = load_vocoder_model(
                     args.vocoder_name, device=args.device, dtype=args.dtype
                     args.vocoder_name, device=args.device, dtype=args.dtype
                 )
                 )
-
+            assert vocoder is not None
             vocoder.eval()
             vocoder.eval()
 
 
         return {
         return {

+ 10 - 1
src/seamless_communication/streaming/dataloaders/s2tt.py

@@ -210,7 +210,16 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
             help="Source segment size, For text the unit is # token, for speech is ms",
             help="Source segment size, For text the unit is # token, for speech is ms",
         )
         )
         parser.add_argument(
         parser.add_argument(
-            "--tgt-lang", type=str, help="Target language to translate/transcribe into."
+            "--tgt-lang",
+            default="eng",
+            type=str,
+            help="Target language to translate/transcribe into.",
+        )
+        parser.add_argument(
+            "--output",
+            type=str,
+            required=True,
+            help="Output directory. Required if using iterable dataloader.",
         )
         )
         parser.add_argument(
         parser.add_argument(
             "--strip-silence",
             "--strip-silence",