Эх сурвалжийг харах

Streaming Evaluate CLI (#150)

* Streaming CLI improvements

* Streaming Evaluate CLI

* bump simuleval version requirement

* rebase and update

* revert simuleval version bump

* Remove old scripts

* revert simuleval dependency version change

* Fix mypy issue in pretssel_vocoder

* setting cli defaults

* Logging vocoder load

* change max_len_a to 1
Abinesh Ramakrishnan 1 жил өмнө
parent
commit
0d2c128b4a

+ 1 - 0
setup.py

@@ -37,6 +37,7 @@ setup(
             "m4t_finetune=seamless_communication.cli.m4t.finetune.finetune:main",
             "m4t_prepare_dataset=seamless_communication.cli.m4t.finetune.dataset:main",
             "m4t_audio_to_units=seamless_communication.cli.m4t.audio_to_units.audio_to_units:main",
+            "streaming_evaluate=seamless_communication.cli.streaming.evaluate:main",
         ],
     },
     include_package_data=True,

+ 100 - 24
src/seamless_communication/cli/streaming/evaluate.py

@@ -4,45 +4,121 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+import argparse
+
+from fairseq2.assets import asset_store, download_manager
+from seamless_communication.cli.eval_utils import get_tokenizer
 from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
     SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
 )
-from seamless_communication.streaming.agents.mma_m4t_s2st import MonotonicM4TS2STAgent
-from simuleval.cli import evaluate
+from seamless_communication.streaming.agents.mma_m4t_s2st import (
+    MonotonicM4TS2STAgent,
+    SeamlessS2STAgent,
+)
+from seamless_communication.streaming.agents.mma_m4t_s2t import MonotonicM4TS2TAgent
+from simuleval.evaluator import build_evaluator
+from simuleval.utils.agent import EVALUATION_SYSTEM_LIST, build_system_args
 
 
-if __name__ == "__main__":
-    tgt_lang = "eng"
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        add_help=False,
+        description="Streaming evaluation of Seamless UnitY models",
+        conflict_handler="resolve",
+    )
 
-    data_configs = dict(
-        dataloader="fairseq2_s2tt",
-        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
-        data_file="/large_experiments/seamless/ust/annaysun/datasets/s2ut_pt/x2t_v2/dev_fleurs_spa-eng.tsv",
-        tgt_lang=tgt_lang,
-        audio_root_dir="/large_experiments/seamless/ust/data/audio_zips",
-        end_index=10,
+    parser.add_argument(
+        "--task",
+        choices=["s2st", "s2tt"],
+        required=True,
+        type=str,
+        help="Target language to translate/transcribe into.",
+    )
+    parser.add_argument(
+        "--expressive",
+        action="store_true",
+        default=False,
+        help="Expressive streaming S2ST inference",
     )
+    parser.add_argument(
+        "--dtype",
+        default="fp16",
+        type=str,
+    )
+
+    args, _ = parser.parse_known_args()
 
     model_configs = dict(
-        agent_class="seamless_communication.streaming.agents.mma_m4t_s2st.MonotonicM4TS2STAgent",
         source_segment_size=320,
-        task="s2st",
         device="cuda:0",
-        dtype="fp16",
+        dtype=args.dtype,
         min_starting_wait_w2vbert=192,
         decision_threshold=0.5,
-        min_unit_chunk_size=50,
         no_early_stop=True,
-        max_len_a=0,
-        max_len_b=100,
+        max_len_a=1,
+        max_len_b=200,
     )
 
-    eval_configs = dict(
-        output=f"MonotonicM4TS2STAgent_spa-eng_debug",
-        quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
-        latency_metrics="StartOffset EndOffset",
-        whisper_model_size="large-v2",
-        normalize_asr_bleu_references=True,
+    if args.dtype == "fp16":
+        model_configs.update(dict(fp16=True))
+
+    EVALUATION_SYSTEM_LIST.clear()
+    if args.task == "s2st":
+        model_configs.update(
+            dict(
+                min_unit_chunk_size=50,
+            )
+        )
+        eval_configs = dict(
+            quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
+            latency_metrics="StartOffset EndOffset",
+            whisper_model_size="large-v2",
+            normalize_asr_bleu_references=True,
+        )
+        if args.expressive:
+            EVALUATION_SYSTEM_LIST.append(SeamlessS2STAgent)
+            model_configs.update(dict(vocoder_name="vocoder_pretssel"))
+        else:
+            EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2STAgent)
+    elif args.task == "s2tt":
+        EVALUATION_SYSTEM_LIST.append(MonotonicM4TS2TAgent)
+        parser.add_argument(
+            "--unity-model-name",
+            type=str,
+            help="Unity model name.",
+            default="seamless_streaming_unity",
+        )
+        parser.add_argument(
+            "--tgt-lang",
+            default="eng",
+            type=str,
+            help="Target language to translate/transcribe into.",
+        )
+        args, _ = parser.parse_known_args()
+        asset_card = asset_store.retrieve_card(name=args.unity_model_name)
+        tokenizer_uri = asset_card.field("tokenizer").as_uri()
+        tokenizer_path = download_manager.download_tokenizer(
+            tokenizer_uri, asset_card.name, force=False, progress=True
+        )
+        eval_configs = dict(
+            sacrebleu_tokenizer=get_tokenizer(args.tgt_lang),
+            eval_latency_unit="spm",
+            eval_latency_spm_model=tokenizer_path,
+            latency_metrics="AL LAAL",
+        )
+
+    base_config = dict(
+        dataloader="fairseq2_s2tt",
+        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
     )
 
-    evaluate(MonotonicM4TS2STAgent, {**data_configs, **model_configs, **eval_configs})
+    system, args = build_system_args(
+        {**base_config, **model_configs, **eval_configs}, parser
+    )
+
+    evaluator = build_evaluator(args)
+    evaluator(system)
+
+
+if __name__ == "__main__":
+    main()

+ 0 - 49
src/seamless_communication/cli/streaming/evaluate_pretssel_vocoder.py

@@ -1,49 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates
-# All rights reserved.
-#
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
-    SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
-)
-from seamless_communication.streaming.agents.mma_m4t_s2st import SeamlessS2STAgent
-from simuleval.cli import evaluate
-
-
-if __name__ == "__main__":
-    tgt_lang = "eng"
-
-    data_configs = dict(
-        dataloader="fairseq2_s2tt",
-        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
-        data_file="/large_experiments/seamless/ust/annaysun/datasets/s2ut_pt/x2t_v2/dev_fleurs_spa-eng.tsv",
-        tgt_lang=tgt_lang,
-        audio_root_dir="/large_experiments/seamless/ust/data/audio_zips",
-        end_index=10,
-    )
-
-    model_configs = dict(
-        vocoder_name="vocoder_pretssel_16khz",
-        agent_class="seamless_communication.streaming.agents.mma_m4t_s2st.SeamlessS2STAgent",
-        source_segment_size=320,
-        task="s2st",
-        device="cuda:0",
-        dtype="fp16",
-        min_starting_wait_w2vbert=192,
-        decision_threshold=0.5,
-        min_unit_chunk_size=50,
-        no_early_stop=True,
-        max_len_a=0,
-        max_len_b=100,
-    )
-
-    eval_configs = dict(
-        output=f"SeamlessS2STAgent_spa-eng_debug",
-        quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
-        latency_metrics="StartOffset EndOffset",
-        whisper_model_size="large-v2",
-        normalize_asr_bleu_references=True,
-    )
-
-    evaluate(SeamlessS2STAgent, {**data_configs, **model_configs, **eval_configs})

+ 1 - 1
src/seamless_communication/cli/streaming/scorers/seamless_whisper_asr_bleu.py

@@ -9,7 +9,7 @@ from argparse import ArgumentParser, Namespace
 from typing import Dict, List
 
 from sacrebleu.metrics.bleu import BLEU
-from seamless_communication.cli.eval_utils import get_tokenizer, LANG3_LANG2
+from seamless_communication.cli.eval_utils import LANG3_LANG2, get_tokenizer
 from simuleval.evaluator.instance import LogInstance
 from simuleval.evaluator.scorers.quality_scorer import (
     WhisperASRSacreBLEUScorer,

+ 1 - 1
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -125,8 +125,8 @@ class OnlineTextDecoderAgent(GenericAgent):  # type: ignore
         )
         parser.add_argument(
             "--tgt-lang",
+            default="eng",
             type=str,
-            default=None,
         )
 
     def policy(self, states: DecoderAgentStates) -> Action:

+ 22 - 15
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -6,20 +6,20 @@
 from __future__ import annotations
 
 from argparse import ArgumentParser, Namespace
-import torch
-from typing import Any, Dict
+from typing import Any, Dict, List
 
-from fairseq2.data.audio import WaveformToFbankConverter
+import torch
+from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
+from seamless_communication.models.generator.vocoder import PretsselVocoder
 from seamless_communication.models.unity import load_gcmvn_stats
 from seamless_communication.models.vocoder.vocoder import Vocoder
-from seamless_communication.models.generator.vocoder import PretsselVocoder
 from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
 from simuleval.agents import AgentStates, TextToSpeechAgent
 from simuleval.agents.actions import ReadAction, WriteAction
 from simuleval.data.segments import SpeechSegment
 
 
-class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
+class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):  # type: ignore
     def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
         super().__init__(args)
         self.vocoder = vocoder
@@ -36,13 +36,15 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
             dtype=args.dtype,
         )
 
-
         _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
-        self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=args.device, dtype=args.dtype)
+        self.gcmvn_mean = torch.tensor(
+            _gcmvn_mean, device=args.device, dtype=args.dtype
+        )
         self.gcmvn_std = torch.tensor(_gcmvn_std, device=args.device, dtype=args.dtype)
 
     def gcmvn_normalize(self, seqs: torch.Tensor) -> torch.Tensor:
-        return seqs.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+        result: torch.Tensor = seqs.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+        return result
 
     @torch.inference_mode()
     def policy(self, states: AgentStates) -> WriteAction:
@@ -66,15 +68,18 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
 
         duration *= 2
 
-        if type(states.upstream_states[self.upstream_idx].source) == list:
-            source = sum(states.upstream_states[self.upstream_idx].source, [])
+        if isinstance(states.upstream_states[self.upstream_idx].source, list):
+            source: List[float] = sum(
+                states.upstream_states[self.upstream_idx].source, []
+            )
         else:
             source = states.upstream_states[self.upstream_idx].source
 
-        audio_dict = {
-            "waveform": torch.tensor(source, dtype=torch.float32, device=self.device).unsqueeze(1),
+        audio_dict: WaveformToFbankInput = {
+            "waveform": torch.tensor(
+                source, dtype=torch.float32, device=self.device
+            ).unsqueeze(1),
             "sample_rate": self.sample_rate,
-            "format": -1,
         }
 
         feats = self.convert_to_fbank(audio_dict)["fbank"]
@@ -115,11 +120,13 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
             "--vocoder-sample-rate",
             type=int,
             default=16000,
-            help="sample rate out of the vocoder"
+            help="sample rate out of the vocoder",
         )
 
     @classmethod
-    def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> PretsselVocoderAgent:
+    def from_args(
+        cls, args: Namespace, **kwargs: Dict[str, Any]
+    ) -> PretsselVocoderAgent:
         vocoder = kwargs.get("vocoder", None)
         assert isinstance(vocoder, PretsselVocoder)
         return cls(vocoder, args)

+ 14 - 4
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -7,11 +7,13 @@ from __future__ import annotations
 
 import logging
 from argparse import ArgumentParser, Namespace
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 from fairseq2.assets import asset_store
 from seamless_communication.inference.translator import Modality, Translator
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+from seamless_communication.models.generator.vocoder import PretsselVocoder
 from seamless_communication.models.monotonic_decoder import (
     load_monotonic_decoder_config,
     load_monotonic_decoder_model,
@@ -23,7 +25,7 @@ from seamless_communication.models.unity import (
     load_unity_unit_tokenizer,
 )
 from seamless_communication.models.vocoder.loader import load_vocoder_model
-from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+from seamless_communication.models.vocoder.vocoder import Vocoder
 from seamless_communication.streaming.agents.common import (
     AgentStates,
     EarlyStoppingMixin,
@@ -85,8 +87,13 @@ class UnitYPipelineMixin:
         )
         parser.add_argument(
             "--dtype",
+            choices=["fp16", "fp32"],
             default="fp16",
             type=str,
+            help=(
+                "Choose between half-precision (fp16) and single precision (fp32) floating point formats."
+                + " Prefer this over the fp16 flag."
+            ),
         )
 
     @classmethod
@@ -140,8 +147,11 @@ class UnitYPipelineMixin:
         )
         monotonic_decoder_model.eval()
 
-        vocoder = None
+        vocoder: Optional[Union[PretsselVocoder, Vocoder]] = None
         if args.vocoder_name is not None and output_modality == Modality.SPEECH:
+            logger.info(
+                f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
+            )
             if "pretssel" in args.vocoder_name:
                 vocoder = load_pretssel_vocoder_model(
                     args.vocoder_name, device=args.device, dtype=args.dtype
@@ -150,7 +160,7 @@ class UnitYPipelineMixin:
                 vocoder = load_vocoder_model(
                     args.vocoder_name, device=args.device, dtype=args.dtype
                 )
-
+            assert vocoder is not None
             vocoder.eval()
 
         return {

+ 10 - 1
src/seamless_communication/streaming/dataloaders/s2tt.py

@@ -210,7 +210,16 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
             help="Source segment size, For text the unit is # token, for speech is ms",
         )
         parser.add_argument(
-            "--tgt-lang", type=str, help="Target language to translate/transcribe into."
+            "--tgt-lang",
+            default="eng",
+            type=str,
+            help="Target language to translate/transcribe into.",
+        )
+        parser.add_argument(
+            "--output",
+            type=str,
+            required=True,
+            help="Output directory. Required if using iterable dataloader.",
         )
         parser.add_argument(
             "--strip-silence",