
Adding watermarked PretsselVocoderAgent (#149)

* adding watermarked PretsselVocoderAgent

* cleanup

---------

Co-authored-by: Anna Sun <13106449+annasun28@users.noreply.github.com>
Yilin Yang, 1 year ago
Parent commit: 64c0e73ac0

+ 1 - 1
src/seamless_communication/cli/streaming/evaluate.py

@@ -7,7 +7,7 @@
 from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
     SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
 )
-from seamless_communication.streaming.agents import MonotonicM4TS2STAgent
+from seamless_communication.streaming.agents.mma_m4t_s2st import MonotonicM4TS2STAgent
 from simuleval.cli import evaluate
 
 

+ 49 - 0
src/seamless_communication/cli/streaming/evaluate_pretssel_vocoder.py

@@ -0,0 +1,49 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.cli.streaming.scorers.seamless_whisper_asr_bleu import (
+    SeamlessWhisperASRSacreBLEUScorer as SeamlessWhisperASRSacreBLEUScorer,
+)
+from seamless_communication.streaming.agents.mma_m4t_s2st import SeamlessS2STAgent
+from simuleval.cli import evaluate
+
+
+if __name__ == "__main__":
+    tgt_lang = "eng"
+
+    data_configs = dict(
+        dataloader="fairseq2_s2tt",
+        dataloader_class="seamless_communication.streaming.dataloaders.s2tt.SimulEvalSpeechToTextDataloader",
+        data_file="/large_experiments/seamless/ust/annaysun/datasets/s2ut_pt/x2t_v2/dev_fleurs_spa-eng.tsv",
+        tgt_lang=tgt_lang,
+        audio_root_dir="/large_experiments/seamless/ust/data/audio_zips",
+        end_index=10,
+    )
+
+    model_configs = dict(
+        vocoder_name="vocoder_pretssel_16khz",
+        agent_class="seamless_communication.streaming.agents.mma_m4t_s2st.SeamlessS2STAgent",
+        source_segment_size=320,
+        task="s2st",
+        device="cuda:0",
+        dtype="fp16",
+        min_starting_wait_w2vbert=192,
+        decision_threshold=0.5,
+        min_unit_chunk_size=50,
+        no_early_stop=True,
+        max_len_a=0,
+        max_len_b=100,
+    )
+
+    eval_configs = dict(
+        output="SeamlessS2STAgent_spa-eng_debug",
+        quality_metrics="SEAMLESS_WHISPER_ASR_BLEU",
+        latency_metrics="StartOffset EndOffset",
+        whisper_model_size="large-v2",
+        normalize_asr_bleu_references=True,
+    )
+
+    evaluate(SeamlessS2STAgent, {**data_configs, **model_configs, **eval_configs})

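The three config dicts are merged into a single flat mapping before being handed to simuleval's evaluate; later dicts win on duplicate keys. A minimal sketch of the merge semantics, using values from the script above (note that the data_file and audio_root_dir paths are internal and must be replaced for a local run):

    # Later dicts override earlier ones on key collisions.
    merged = {**data_configs, **model_configs, **eval_configs}
    assert merged["tgt_lang"] == "eng" and merged["task"] == "s2st"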
+ 0 - 6
src/seamless_communication/streaming/agents/__init__.py

@@ -4,9 +4,3 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from seamless_communication.streaming.agents.mma_m4t_s2st import (
-    MonotonicM4TS2STAgent as MonotonicM4TS2STAgent,
-)
-from seamless_communication.streaming.agents.mma_m4t_s2t import (
-    MonotonicM4TS2TAgent as MonotonicM4TS2TAgent,
-)

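With these package-level re-exports removed, agents must be imported from their defining modules, which is why evaluate.py above switches to the full module path. The new import style, for the two agents previously re-exported here:

    from seamless_communication.streaming.agents.mma_m4t_s2st import MonotonicM4TS2STAgent
    from seamless_communication.streaming.agents.mma_m4t_s2t import MonotonicM4TS2TAgent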
+ 11 - 1
src/seamless_communication/streaming/agents/mma_m4t_s2st.py

@@ -18,6 +18,7 @@ from seamless_communication.streaming.agents.online_unit_decoder import (
 )
 from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.online_vocoder import VocoderAgent
+from seamless_communication.streaming.agents.pretssel_vocoder import PretsselVocoderAgent
 
 from seamless_communication.streaming.agents.detokenizer import UnitYDetokenizerAgent
 from seamless_communication.streaming.agents.unity_pipeline import (
@@ -27,7 +28,6 @@ from seamless_communication.streaming.agents.unity_pipeline import (
 from simuleval.utils import entrypoint
 
 
-@entrypoint
 class MonotonicM4TS2STAgent(UnitYAgentPipeline):
     pipeline = [
         OnlineFeatureExtractorAgent,
@@ -38,6 +38,16 @@ class MonotonicM4TS2STAgent(UnitYAgentPipeline):
     ]
 
 
+class SeamlessS2STAgent(UnitYAgentPipeline):
+    pipeline = [
+        OnlineFeatureExtractorAgent,
+        OfflineWav2VecBertEncoderAgent,
+        UnitYMMATextDecoderAgent,
+        NARUnitYUnitDecoderAgent,
+        PretsselVocoderAgent,
+    ]
+
+
 class MonotonicM4TS2STVADAgent(UnitYAgentPipeline):
     pipeline = [
         SileroVADAgent,

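SeamlessS2STAgent differs from MonotonicM4TS2STAgent only in its final stage: PretsselVocoderAgent replaces VocoderAgent, so units are rendered by the watermarked PretsselVocoder named in the commit title instead of the standard vocoder. The commit adds no VAD-fronted variant of the new pipeline, but the pattern generalizes; a hypothetical sketch following MonotonicM4TS2STVADAgent (SeamlessS2STVADAgent is not part of this commit):

    # Hypothetical, for illustration only.
    class SeamlessS2STVADAgent(UnitYAgentPipeline):
        pipeline = [
            SileroVADAgent,
            OnlineFeatureExtractorAgent,
            OfflineWav2VecBertEncoderAgent,
            UnitYMMATextDecoderAgent,
            NARUnitYUnitDecoderAgent,
            PretsselVocoderAgent,
        ]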
+ 2 - 0
src/seamless_communication/streaming/agents/online_vocoder.py

@@ -7,6 +7,7 @@ from __future__ import annotations
 
 from argparse import ArgumentParser, Namespace
 from typing import Any, Dict
+import torch
 
 from seamless_communication.models.vocoder.vocoder import Vocoder
 from simuleval.agents import AgentStates, TextToSpeechAgent
@@ -22,6 +23,7 @@ class VocoderAgent(TextToSpeechAgent):  # type: ignore
         self.tgt_lang = args.tgt_lang
         self.speaker_id = args.vocoder_speaker_id
 
+    @torch.inference_mode()
     def policy(self, states: AgentStates) -> WriteAction:
         """
         The policy is always write if there are units

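Applied as a decorator, torch.inference_mode() runs the whole policy call with autograd disabled, avoiding graph bookkeeping during synthesis. A minimal sketch of the decorator's effect:

    import torch

    @torch.inference_mode()
    def synthesize(x: torch.Tensor) -> torch.Tensor:
        # No autograd graph is recorded inside this call.
        return x * 2.0

    y = synthesize(torch.ones(3, requires_grad=True))
    assert not y.requires_grad  # outputs are inference tensors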
+ 118 - 0
src/seamless_communication/streaming/agents/pretssel_vocoder.py

@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from argparse import ArgumentParser, Namespace
+import torch
+from typing import Any, Dict
+
+from fairseq2.data.audio import WaveformToFbankConverter
+from seamless_communication.models.unity import load_gcmvn_stats
+from seamless_communication.models.vocoder.vocoder import Vocoder
+from seamless_communication.models.generator.vocoder import PretsselVocoder
+from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
+from simuleval.agents import AgentStates, TextToSpeechAgent
+from simuleval.agents.actions import ReadAction, WriteAction
+from simuleval.data.segments import SpeechSegment
+
+
+class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
+    def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
+        super().__init__(args)
+        self.vocoder = vocoder
+        self.upstream_idx = args.upstream_idx
+        self.sample_rate = args.sample_rate
+        self.tgt_lang = args.tgt_lang
+        self.convert_to_fbank = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15,
+            channel_last=True,
+            standardize=False,
+            device=args.device,
+            dtype=args.dtype,
+        )
+
+        _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
+        self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=args.device, dtype=args.dtype)
+        self.gcmvn_std = torch.tensor(_gcmvn_std, device=args.device, dtype=args.dtype)
+
+    def gcmvn_normalize(self, seqs: torch.Tensor) -> torch.Tensor:
+        return seqs.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
+
+    @torch.inference_mode()
+    def policy(self, states: AgentStates) -> WriteAction:
+        """
+        The policy is always write if there is a waveform
+        """
+        units = states.source
+
+        if len(units) == 0 or len(units[0]) == 0:
+            if states.source_finished:
+                return WriteAction(content=[], finished=True)
+            else:
+                return ReadAction()
+
+        unit = units[0][0]
+
+        # adjust the control symbols for the embedding
+        unit += 4
+
+        unit, duration = torch.unique_consecutive(unit, return_counts=True)
+
+        duration *= 2
+
+        if isinstance(states.upstream_states[self.upstream_idx].source, list):
+            source = sum(states.upstream_states[self.upstream_idx].source, [])
+        else:
+            source = states.upstream_states[self.upstream_idx].source
+
+        audio_dict = {
+            "waveform": torch.tensor(source, dtype=torch.float32, device=self.device).unsqueeze(1),
+            "sample_rate": 16000, # input audio is fixed to 16kHZ
+            "format": -1,
+        }
+
+        feats = self.convert_to_fbank(audio_dict)["fbank"]
+
+        feats = self.gcmvn_normalize(feats)
+
+        tgt_lang = states.tgt_lang if states.tgt_lang else self.tgt_lang
+
+        wav = self.vocoder(
+            unit,
+            tgt_lang=tgt_lang,
+            prosody_input_seqs=feats,
+            durations=duration.unsqueeze(0),
+            normalize_before=True,
+        )
+
+        states.source = []
+
+        return WriteAction(
+            SpeechSegment(
+                content=wav[0][0].tolist(),
+                finished=states.source_finished,
+                sample_rate=self.sample_rate,
+                tgt_lang=tgt_lang,
+            ),
+            finished=states.source_finished,
+        )
+
+    @classmethod
+    def add_args(cls, parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--upstream-idx",
+            type=int,
+            default=0,
+            help="index of encoder states where states.source contains input audio",
+        )
+
+    @classmethod
+    def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> PretsselVocoderAgent:
+        vocoder = kwargs.get("vocoder", None)
+        assert isinstance(vocoder, PretsselVocoder)
+        return cls(vocoder, args)

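Two details of the policy above are easy to miss: the unit tensor is run-length encoded with torch.unique_consecutive before synthesis, and the fbank features are standardized with corpus-level (gcmvn) statistics before conditioning the vocoder. A minimal numeric sketch of both:

    import torch

    # Run-length encoding of the unit sequence, as in policy().
    units = torch.tensor([7, 7, 7, 3, 3, 9])
    unit, duration = torch.unique_consecutive(units, return_counts=True)
    # unit -> tensor([7, 3, 9]); duration -> tensor([3, 2, 1])

    # Global CMVN, as in gcmvn_normalize(): per-dimension standardization.
    mean = torch.tensor([1.0, 2.0])
    std = torch.tensor([0.5, 4.0])
    feats = torch.tensor([[2.0, 10.0]])
    normalized = feats.subtract(mean).divide(std)  # tensor([[2., 2.]])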
+ 10 - 3
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -23,6 +23,7 @@ from seamless_communication.models.unity import (
     load_unity_unit_tokenizer,
 )
 from seamless_communication.models.vocoder.loader import load_vocoder_model
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 from seamless_communication.streaming.agents.common import (
     AgentStates,
     EarlyStoppingMixin,
@@ -141,9 +142,15 @@ class UnitYPipelineMixin:
 
         vocoder = None
         if args.vocoder_name is not None and output_modality == Modality.SPEECH:
-            vocoder = load_vocoder_model(
-                args.vocoder_name, device=args.device, dtype=args.dtype
-            )
+            if "pretssel" in args.vocoder_name:
+                vocoder = load_pretssel_vocoder_model(
+                    args.vocoder_name, device=args.device, dtype=args.dtype
+                )
+            else:
+                vocoder = load_vocoder_model(
+                    args.vocoder_name, device=args.device, dtype=args.dtype
+                )
+
             vocoder.eval()
 
         return {
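The dispatch keys on a substring of the checkpoint name: any vocoder_name containing "pretssel" (e.g. vocoder_pretssel_16khz from the evaluation script above) goes through load_pretssel_vocoder_model, while everything else falls back to load_vocoder_model. The same branch, condensed into a helper (hypothetical name, assuming the loader imports above):

    def load_any_vocoder(name: str, device, dtype):
        if "pretssel" in name:
            return load_pretssel_vocoder_model(name, device=device, dtype=dtype)
        return load_vocoder_model(name, device=device, dtype=dtype)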