Browse Source

Offline w2v-bert encoder agent with parity. (#110)

Kaushik Ram Sadagopan 1 year ago
parent
commit
239a9440a9

+ 4 - 1
src/seamless_communication/streaming/agents/mma_m4t_s2t.py

@@ -8,9 +8,12 @@ from seamless_communication.streaming.agents.online_feature_extractor import (
     OnlineFeatureExtractorAgent,
     OnlineFeatureExtractorAgent,
 )
 )
 from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
 from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
+from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
+    OfflineWav2VecBertEncoderAgent,
+)
 from simuleval.utils import entrypoint
 from simuleval.utils import entrypoint
 
 
 
 
 @entrypoint
 @entrypoint
 class MonotonicM4TS2TSPMAgent(UnitYAgentPipeline):
 class MonotonicM4TS2TSPMAgent(UnitYAgentPipeline):
-    pipeline = [OnlineFeatureExtractorAgent]
+    pipeline = [OnlineFeatureExtractorAgent, OfflineWav2VecBertEncoderAgent]

+ 99 - 0
src/seamless_communication/streaming/agents/offline_w2v_bert_encoder.py

@@ -0,0 +1,99 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from argparse import ArgumentParser, Namespace
+from typing import Any, Dict
+
+import torch
+from fairseq2.data import SequenceData
+from fairseq2.data.data_pipeline import Collater
+from fairseq2.data.text import TextTokenizer
+from fairseq2.models.wav2vec2 import Wav2Vec2EncoderConfig
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from seamless_communication.models.unity.model import UnitYModel
+from simuleval.agents import AgentStates, SpeechToSpeechAgent
+from simuleval.agents.actions import Action, ReadAction, WriteAction
+from simuleval.data.segments import SpeechSegment
+
+
+class OfflineWav2VecBertEncoderAgent(SpeechToSpeechAgent):
+    """
+    Incremental encoding of a wav2vec encoder output.
+    It updates the whole encoder state every time there is a new incoming segment.
+    """
+
+    def __init__(
+        self,
+        unity_model: UnitYModel,
+        w2v2_encoder_config: Wav2Vec2EncoderConfig,
+        text_tokenizer: TextTokenizer,
+        args: Namespace,
+    ) -> None:
+        super().__init__(args)
+        self.model = unity_model
+        self.w2v2_encoder_config = w2v2_encoder_config
+        self.collate = Collater(
+            pad_value=text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
+        )
+        self.device = args.device
+        self.dtype = args.dtype
+        self.min_starting_wait = args.min_starting_wait_w2vbert
+
+    @property
+    def min_input_length(self) -> int:
+        return self.w2v2_encoder_config.fbank_stride
+
+    @staticmethod
+    def add_args(parser: ArgumentParser):
+        parser.add_argument(
+            "--min-starting-wait-w2vbert",
+            default=None,
+            type=int,
+            help="Min starting wait in w2vbert",
+        )
+
+    @torch.inference_mode()
+    def policy(self, states: AgentStates) -> Action:
+        """
+        The encoder policy is to always write;
+        it reads only if the input is too short and the source is not finished.
+        """
+        if len(states.source) < self.min_input_length or (
+            self.min_starting_wait is not None
+            and len(states.source) < self.min_starting_wait
+        ):
+            if states.source_finished:
+                return WriteAction({}, finished=states.source_finished)
+            else:
+                return ReadAction()
+
+        inputs = torch.stack(states.source).to(device=self.device, dtype=self.dtype)
+        src: SequenceData = self.collate(inputs)
+
+        seqs, padding_mask = get_seqs_and_padding_mask(src)
+        encoder_output, _ = self.model.encode_speech(
+            seqs,
+            padding_mask,
+        )
+
+        return WriteAction(
+            SpeechSegment(
+                content=encoder_output,
+                tgt_lang=states.tgt_lang,
+                finished=states.source_finished,
+            ),
+            finished=states.source_finished,
+        )
+
+    @classmethod
+    def from_args(
+        cls, args: Namespace, **kwargs: Dict[str, Any]
+    ) -> OfflineWav2VecBertEncoderAgent:
+        unity_model = kwargs.get("unity_model", None)
+        unity_config = kwargs.get("unity_config", None)
+        text_tokenizer = kwargs.get("text_tokenizer", None)
+        return cls(unity_model, unity_config.w2v2_encoder_config, text_tokenizer, args)