Explorar o código

Vocoder SimulEval Agent and getting online S2ST parity. (#135)

* Vocoder Agent

* cleanup and other changes

* Additional changes

* Fixing input, output of vocoder inference.

* Add whisper asr scorer and verify S2ST parity.

* remove prefix_tgt_lang and tokenizer args (#139)

* Fix a bunch of mypy issues.

* add tgt-lang arg

* update decoder_features in ngram, add tgt-lang

* remove extra tgt-lang arg

---------

Co-authored-by: Kaushik Ram Sadagopan <kaushikram2811@gmail.com>
Co-authored-by: Anna Sun <13106449+annasun28@users.noreply.github.com>
Abinesh Ramakrishnan hai 1 ano
pai
achega
a1b5d918eb

+ 3 - 0
src/seamless_communication/cli/eval_utils/__init__.py

@@ -14,3 +14,6 @@ from seamless_communication.cli.eval_utils.compute_metrics import (
 from seamless_communication.cli.eval_utils.lang_mapping import (
     LANG2_LANG3 as LANG2_LANG3,
 )
+from seamless_communication.cli.eval_utils.lang_mapping import (
+    LANG3_LANG2 as LANG3_LANG2,
+)

+ 0 - 5
src/seamless_communication/cards/unity_sans_decoder.yaml → src/seamless_communication/cli/streaming/__init__.py

@@ -3,8 +3,3 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
-name: unity_sans_decoder
-base: unity_nllb-100
-model_arch: base_v2
-checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/unity_sans_decoder.pt"

+ 48 - 0
src/seamless_communication/cli/streaming/evaluate.py

@@ -0,0 +1,48 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
+    SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
+)
+from seamless_communication.streaming.agents import MonotonicM4TS2STAgent
+from simuleval.cli import evaluate
+
+
+if __name__ == "__main__":
+    tgt_lang = "eng"
+
+    data_configs = dict(
+        dataloader="fairseq2_s2tt",
+        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
+        data_file="/large_experiments/seamless/ust/annaysun/datasets/s2ut_pt/x2t_v2/dev_fleurs_spa-eng.tsv",
+        tgt_lang=tgt_lang,
+        audio_root_dir="/large_experiments/seamless/ust/data/audio_zips",
+        end_index=10,
+    )
+
+    model_configs = dict(
+        agent_class="seamless_communication.streaming.agents.mma_m4t_s2st.MonotonicM4TS2STAgent",
+        source_segment_size=320,
+        task="s2st",
+        device="cuda:0",
+        dtype="fp16",
+        min_starting_wait_w2vbert=192,
+        decision_threshold=0.5,
+        min_unit_chunk_size=50,
+        no_early_stop=True,
+        max_len_a=0,
+        max_len_b=100,
+    )
+
+    eval_configs = dict(
+        output=f"MonotonicM4TS2STAgent_spa-eng_debug",
+        quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
+        latency_metrics="StartOffset EndOffset",
+        whisper_model_size="large-v2",
+        normalize_asr_bleu_references=True,
+    )
+
+    evaluate(MonotonicM4TS2STAgent, {**data_configs, **model_configs, **eval_configs})

+ 0 - 5
src/seamless_communication/cards/monotonic_decoder.yaml → src/seamless_communication/cli/streaming/scorers/__init__.py

@@ -3,8 +3,3 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
-name: monotonic_decoder
-model_type: monotonic_decoder
-model_arch: dense_1b
-checkpoint: "file:///large_experiments/seamless/ust/krs/fairseq2_checkpoints/monotonic_decoder.pt"

+ 84 - 0
src/seamless_communication/cli/streaming/scorers/seamless_whisper_asr_bleu.py

@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from argparse import ArgumentParser, Namespace
+from typing import Dict, List
+
+from sacrebleu.metrics.bleu import BLEU
+from seamless_communication.cli.eval_utils import get_tokenizer, LANG3_LANG2
+from simuleval.evaluator.instance import LogInstance
+from simuleval.evaluator.scorers.quality_scorer import (
+    WhisperASRSacreBLEUScorer,
+    register_quality_scorer,
+)
+from whisper.normalizers import BasicTextNormalizer, EnglishTextNormalizer
+
+
+def normalize_text_whisper(sentences: List[str], lang: str) -> List[str]:
+    if lang in ["en", "eng"]:
+        normalizer = EnglishTextNormalizer()
+    else:
+        normalizer = BasicTextNormalizer()
+    normalized_sentences = []
+    for text in sentences:
+        normalized_sentences.append(normalizer(text))
+    return normalized_sentences
+
+
+@register_quality_scorer("SEAMLESS_WHISPER_ASR_BLEU")
+class SeamlessWhisperASRSacreBLEUScorer(WhisperASRSacreBLEUScorer):
+    def __init__(
+        self,
+        tokenizer: str = "13a",
+        target_lang: str = "en",
+        model_size: str = "base",
+        lowercase: bool = False,
+        remove_punctuations: bool = False,
+        normalize_asr_bleu_references: bool = False,
+    ) -> None:
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.target_lang = target_lang
+        self.model_size = model_size
+        self.lowercase = lowercase
+        self.remove_punctuations = remove_punctuations
+        self.normalize_asr_bleu_references = normalize_asr_bleu_references
+
+    def __call__(self, instances: Dict[int, LogInstance]) -> float:
+        transcripts = self.asr_transcribe(instances)
+        references = [[ins.reference for ins in instances.values()]]
+
+        if self.normalize_asr_bleu_references:
+            transcripts = normalize_text_whisper(transcripts, self.target_lang)
+            references = [normalize_text_whisper(references[0], self.target_lang)]
+
+        score = (
+            BLEU(tokenize=self.tokenizer).corpus_score(transcripts, references).score
+        )
+        return score  # type: ignore[no-any-return]
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        WhisperASRSacreBLEUScorer.add_args(parser)
+        parser.add_argument(
+            "--normalize-asr-bleu-references",
+            action="store_true",
+            help="Normalize asr transcript and reference",
+        )
+
+    @classmethod
+    def from_args(cls, args: Namespace) -> SeamlessWhisperASRSacreBLEUScorer:
+        sacrebleu_tokenizer = get_tokenizer(args.tgt_lang)
+        tgt_lang_2ltr = LANG3_LANG2[args.tgt_lang]
+        return cls(
+            sacrebleu_tokenizer,
+            tgt_lang_2ltr,
+            args.whisper_model_size,
+            args.transcript_lowercase,
+            args.transcript_non_punctuation,
+            args.normalize_asr_bleu_references,
+        )

+ 1 - 1
src/seamless_communication/streaming/agents/common.py

@@ -22,7 +22,7 @@ class EarlyStoppingMixin:
 
 
 class AgentStates(AgentStatesOrig):
-    def update_target(self, segment: Segment):
+    def update_target(self, segment: Segment) -> None:
         """An AgentStates impl which doesn't update states.target"""
         self.target_finished = segment.finished
 

+ 3 - 0
src/seamless_communication/streaming/agents/mma_m4t_s2st.py

@@ -16,6 +16,8 @@ from seamless_communication.streaming.agents.online_text_decoder import (
 from seamless_communication.streaming.agents.online_unit_decoder import (
     NARUnitYUnitDecoderAgent,
 )
+from seamless_communication.streaming.agents.online_vocoder import VocoderAgent
+
 from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
 from simuleval.utils import entrypoint
 
@@ -27,4 +29,5 @@ class MonotonicM4TS2STAgent(UnitYAgentPipeline):
         OfflineWav2VecBertEncoderAgent,
         UnitYMMATextDecoderAgent,
         NARUnitYUnitDecoderAgent,
+        VocoderAgent,
     ]

+ 15 - 22
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -16,9 +16,9 @@ from seamless_communication.models.monotonic_decoder import (
     MonotonicDecoderConfig,
     MonotonicDecoderModel,
 )
+from seamless_communication.streaming.agents.common import AgentStates
 from simuleval.agents import GenericAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
-from seamless_communication.streaming.agents.common import AgentStates
 from simuleval.data.segments import Segment, TextSegment
 from torch import Tensor
 
@@ -78,16 +78,8 @@ class OnlineTextDecoderAgent(GenericAgent):
         self.device = args.device
         self.dtype = args.dtype
         self.eos_idx = text_tokenizer.vocab_info.eos_idx
-        if (
-            hasattr(args, "tgt_lang")
-            and hasattr(args, "prefix_tgt_lang")
-            and args.tgt_lang is not None
-            and args.prefix_tgt_lang is not None
-        ):
-            assert args.tgt_lang == args.prefix_tgt_lang
-        tgt_lang = getattr(args, "tgt_lang", None) or getattr(
-            args, "prefix_tgt_lang", None
-        )
+
+        tgt_lang = getattr(args, "tgt_lang", None)
         assert tgt_lang is not None
         self.token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
         prefix_indices = self.token_encoder.prefix_indices
@@ -132,7 +124,7 @@ class OnlineTextDecoderAgent(GenericAgent):
             default=False,
         )
         parser.add_argument(
-            "--prefix-tgt-lang",
+            "--tgt-lang",
             type=str,
             default=None,
         )
@@ -265,7 +257,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             tgt_lang=states.tgt_lang,
         )
 
-    def get_blocked_ngrams(self, target_indices: List[int]):
+    def get_blocked_ngrams(self, target_indices: List[int]) -> Optional[Set[str]]:
         # TODO: make it configurable and use itertools
         if not self.block_ngrams:
             return None
@@ -285,25 +277,26 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         self,
         states: DecoderAgentStates,
         pred_indices: List[int],
-        blocked_ngrams: Set[int],
+        decoder_features_out: Tensor,
+        blocked_ngrams: Set[str],
         index: int,
-    ):
+    ) -> bool:
         """
         This check is used to force a READ decision when n-gram repeat
         happens before source_finished
         """
         if not self.block_ngrams or states.source_finished:
-            return False
+            return False, decoder_features_out
         all_indices = states.target_indices + pred_indices + [index]
         for n in [3, 2]:  # TODO: make it configurable
             if len(all_indices) >= n and states.ngram_block_count <= 4:
                 if str(all_indices[-n:]) in blocked_ngrams:
                     states.ngram_block_count += 1
                     pred_indices[:] = pred_indices[: -(n - 1)]
-                    # decoder_features_out = decoder_features_out[:, : -(n - 1)]
-                    return True
+                    decoder_features_out = decoder_features_out[:, : -(n - 1)]
+                    return True, decoder_features_out
                 blocked_ngrams.add(str(all_indices[-n:]))
-        return False
+        return False, decoder_features_out
 
     @torch.inference_mode()
     def policy(self, states: DecoderAgentStates) -> Action:
@@ -350,9 +343,9 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
                 if prob == 1.0:
                     pred_indices = []
                 break
-            block_ngram = self.maybe_block_ngrams(
-                states, pred_indices, blocked_ngrams, index
-            )  # TODO: add back decoder_features_out processing for unity2
+            block_ngram, decoder_features_out = self.maybe_block_ngrams(
+                states, pred_indices, decoder_features_out, blocked_ngrams, index
+            )
             if block_ngram:
                 break
             if (

+ 68 - 0
src/seamless_communication/streaming/agents/online_vocoder.py

@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from argparse import ArgumentParser, Namespace
+from typing import Any, Dict
+
+from seamless_communication.models.vocoder.vocoder import Vocoder
+from simuleval.agents import AgentStates, TextToSpeechAgent
+from simuleval.agents.actions import ReadAction, WriteAction
+from simuleval.data.segments import SpeechSegment
+
+
+class VocoderAgent(TextToSpeechAgent):
+    def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
+        super().__init__(args)
+        self.sample_rate = args.sample_rate
+        self.vocoder = vocoder
+        self.tgt_lang = args.tgt_lang
+        self.speaker_id = args.vocoder_speaker_id
+
+    def policy(self, states: AgentStates) -> WriteAction:
+        """
+        The policy is always write if there are units
+        """
+        units = states.source
+
+        if len(units) == 0 or len(units[0]) == 0:
+            if states.source_finished:
+                return WriteAction([], finished=True)
+            else:
+                return ReadAction()
+
+        tgt_lang = states.tgt_lang if states.tgt_lang else self.tgt_lang
+        u = units[0][0].tolist()
+        wav_samples = self.vocoder(u, tgt_lang, self.speaker_id, dur_prediction=False)[
+            0
+        ][0].tolist()
+        states.source = []
+
+        return WriteAction(
+            SpeechSegment(
+                content=wav_samples,
+                finished=states.source_finished,
+                sample_rate=self.sample_rate,
+                tgt_lang=tgt_lang,
+            ),
+            finished=states.source_finished,
+        )
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--vocoder-speaker-id",
+            type=int,
+            required=False,
+            default=-1,
+            help="Vocoder speaker id",
+        )
+
+    @classmethod
+    def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> VocoderAgent:
+        vocoder = kwargs.get("vocoder", None)
+        assert isinstance(vocoder, Vocoder)
+        return cls(vocoder, args)

+ 1 - 1
src/seamless_communication/streaming/agents/silero_vad.py

@@ -164,7 +164,7 @@ class SileroVADStates(EarlyStoppingMixin, AgentStates):
             self.source_finished = True
             self.debug_write_wav(np.empty(0, dtype=np.int16), finished=True)
 
-    def decay_silence_acc_ms(self):
+    def decay_silence_acc_ms(self) -> None:
         if self.consecutive_silence_decay_count <= 2:
             self.silence_acc_ms = self.silence_acc_ms // 2
             self.consecutive_silence_decay_count += 1

+ 15 - 0
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -22,6 +22,7 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
+from seamless_communication.models.vocoder.loader import load_vocoder_model
 from seamless_communication.streaming.agents.common import (
     AgentStates,
     EarlyStoppingMixin,
@@ -70,6 +71,12 @@ class UnitYPipelineMixin:
             help="Monotonic decoder model name.",
             default="seamless_streaming_monotonic_decoder",
         )
+        parser.add_argument(
+            "--vocoder-name",
+            type=str,
+            help="Vocoder name.",
+            default="vocoder_v2",
+        )
         parser.add_argument(
             "--sample-rate",
             default=16000,
@@ -135,6 +142,13 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
         )
         monotonic_decoder_model.eval()
 
+        self.vocoder = None
+        if args.vocoder_name is not None and output_modality == Modality.SPEECH:
+            self.vocoder = load_vocoder_model(
+                args.vocoder_name, device=args.device, dtype=args.dtype
+            )
+            self.vocoder.eval()
+
         module_list = []
         for p in self.pipeline:
             module_list.append(
@@ -146,6 +160,7 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
                     monotonic_decoder_config=monotonic_decoder_config,
                     text_tokenizer=text_tokenizer,
                     unit_tokenizer=unit_tokenizer,
+                    vocoder=self.vocoder,
                 )
             )
 

+ 0 - 1
src/seamless_communication/streaming/dataloaders/s2tt.py

@@ -49,7 +49,6 @@ class SimulEvalSpeechToTextDataloader(SpeechToTextDataloader, IterableDataloader
         self.data_pipeline = data_pipeline
         self.data_itr = iter(self.data_pipeline)
         self.cur_index = self.start_index - 1
-        self.item = None
 
     def __iter__(self) -> SimulEvalSpeechToTextDataloader:
         return self