
Introduce MMASpeechToTextDecoderAgent and related agents for online_text_decoder. (#113)

* Initial commit declaring MMASpeechToTextDecoderAgent and related agents.

* Cleanup and polish the text decoder agent (#114)

* Cleanup and polish the text decoder agent

* Additional cleanup and fixes

* Using the correct dim for source length

* override add_args in MMATextDecoderAgent

* Fix mypy issues and other cosmetic changes.

* Fix text content of post_process.

* Setting the right defaults for max_len_a and max_len_b.

* SPM Detokenizer Agent

* Clean up spm_detokenizer, address mypy issues.

* Set default args to be compatible with the M4T S2T agent.

---------

Co-authored-by: Abinesh Ramakrishnan <3632454+ibanesh@users.noreply.github.com>
Kaushik Ram Sadagopan 1 year ago
parent
commit
5dd9722b8d

+ 3 - 1
src/seamless_communication/models/monotonic_decoder/__init__.py

@@ -3,13 +3,15 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-
 from seamless_communication.models.monotonic_decoder.builder import (
     MonotonicDecoderBuilder as MonotonicDecoderBuilder,
 )
 from seamless_communication.models.monotonic_decoder.builder import (
     MonotonicDecoderConfig as MonotonicDecoderConfig,
 )
+from seamless_communication.models.monotonic_decoder.model import (
+    MonotonicDecoderModel as MonotonicDecoderModel,
+)
 from seamless_communication.models.monotonic_decoder.builder import (
     create_monotonic_decoder_model as create_monotonic_decoder_model,
 )
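
With MonotonicDecoderModel now re-exported at the package level, downstream code such as the new decoder agents can import it directly for isinstance checks. A minimal illustration, assuming seamless_communication and its dependencies are installed:

```python
# The new agents rely on this package-level re-export when they
# type-narrow the models handed to them via **kwargs.
from seamless_communication.models.monotonic_decoder import (
    MonotonicDecoderModel,
)

print(MonotonicDecoderModel.__name__)  # -> MonotonicDecoderModel
```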

+ 1 - 1
src/seamless_communication/streaming/agents/__init__.py

@@ -5,5 +5,5 @@
 # LICENSE file in the root directory of this source tree.
 
 from seamless_communication.streaming.agents.mma_m4t_s2t import (
-    MonotonicM4TS2TSPMAgent as MonotonicM4TS2TSPMAgent,
+    MonotonicM4TS2TAgent as MonotonicM4TS2TAgent,
 )

+ 54 - 0
src/seamless_communication/streaming/agents/detokenizer.py

@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from argparse import ArgumentParser, Namespace
+from typing import Any, Dict
+
+from simuleval.agents import TextToTextAgent
+from simuleval.agents.actions import Action, ReadAction, WriteAction
+from simuleval.agents.states import AgentStates
+
+
+class DetokenizerAgent(TextToTextAgent):
+    def __init__(self, args: Namespace):
+        super().__init__(args)
+        self.detokenize_only = args.detokenize_only
+
+    @classmethod
+    def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> DetokenizerAgent:
+        return cls(args)
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--detokenize-only",
+            action="store_true",
+            default=True,
+            help="Run detokenization without waiting for a new token.",
+        )
+
+    def policy(self, states: AgentStates) -> Action:
+        possible_full_words = self.decode(" ".join(states.source))
+
+        if self.detokenize_only and len(states.source) > 0:
+            states.source = []
+            if len(possible_full_words) == 0 and not states.source_finished:
+                return ReadAction()
+            else:
+                return WriteAction(possible_full_words, states.source_finished)
+
+        if states.source_finished:
+            return WriteAction(possible_full_words, True)
+        elif len(possible_full_words.split()) > 1:
+            full_word = possible_full_words.split()[0]
+            states.source = states.source[-1:]
+            return WriteAction(full_word, finished=False)
+        else:
+            return ReadAction()
+
+    def decode(self, x: str) -> str:
+        return x.replace(" ", "").replace("\u2581", " ").strip()
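
The decode method undoes SentencePiece's whitespace encoding: the spaces that join pieces are dropped, and the "▁" (U+2581) word-boundary marker becomes a real space. A self-contained sketch of that transformation; the sample pieces are hypothetical:

```python
# SPM pieces as the upstream decoder agent would emit them; "\u2581"
# marks the start of a word. The sample string is illustrative only.
pieces = "\u2581Hello \u2581there \u2581friend"

# Mirrors DetokenizerAgent.decode: drop the joining spaces, then turn
# word-boundary markers back into spaces.
detokenized = pieces.replace(" ", "").replace("\u2581", " ").strip()
assert detokenized == "Hello there friend"
```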

+ 23 - 3
src/seamless_communication/streaming/agents/mma_m4t_s2t.py

@@ -4,16 +4,36 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from seamless_communication.streaming.agents.detokenizer import (
+    DetokenizerAgent,
+)
 from seamless_communication.streaming.agents.online_feature_extractor import (
     OnlineFeatureExtractorAgent,
 )
-from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
 from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
     OfflineWav2VecBertEncoderAgent,
 )
+from seamless_communication.streaming.agents.online_text_decoder import (
+    MMASpeechToTextDecoderAgent,
+)
+from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
 from simuleval.utils import entrypoint
 
 
 @entrypoint
-class MonotonicM4TS2TSPMAgent(UnitYAgentPipeline):
-    pipeline = [OnlineFeatureExtractorAgent, OfflineWav2VecBertEncoderAgent]
+class MonotonicM4TS2TDetokAgent(UnitYAgentPipeline):
+    pipeline = [
+        OnlineFeatureExtractorAgent,
+        OfflineWav2VecBertEncoderAgent,
+        MMASpeechToTextDecoderAgent,
+        DetokenizerAgent,
+    ]
+
+
+@entrypoint
+class MonotonicM4TS2TAgent(UnitYAgentPipeline):
+    pipeline = [
+        OnlineFeatureExtractorAgent,
+        OfflineWav2VecBertEncoderAgent,
+        MMASpeechToTextDecoderAgent,
+    ]
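
The two entrypoints differ only in the final stage: MonotonicM4TS2TAgent stops at the decoder and emits raw SPM pieces, while MonotonicM4TS2TDetokAgent appends DetokenizerAgent to produce plain text. A minimal sketch of that relationship, assuming the package imports cleanly:

```python
from seamless_communication.streaming.agents.detokenizer import DetokenizerAgent
from seamless_communication.streaming.agents.mma_m4t_s2t import (
    MonotonicM4TS2TAgent,
    MonotonicM4TS2TDetokAgent,
)

# The detokenizing pipeline is the plain S2T pipeline plus one stage.
assert MonotonicM4TS2TDetokAgent.pipeline[:-1] == MonotonicM4TS2TAgent.pipeline
assert MonotonicM4TS2TDetokAgent.pipeline[-1] is DetokenizerAgent
```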

+ 4 - 1
src/seamless_communication/streaming/agents/offline_w2v_bert_encoder.py

@@ -48,7 +48,7 @@ class OfflineWav2VecBertEncoderAgent(SpeechToSpeechAgent):
         return self.w2v2_encoder_config.fbank_stride
 
     @staticmethod
-    def add_args(parser: ArgumentParser):
+    def add_args(parser: ArgumentParser) -> None:
         parser.add_argument(
             "--min-starting-wait-w2vbert",
             default=None,
@@ -94,6 +94,9 @@ class OfflineWav2VecBertEncoderAgent(SpeechToSpeechAgent):
         cls, args: Namespace, **kwargs: Dict[str, Any]
     ) -> OfflineWav2VecBertEncoderAgent:
         unity_model = kwargs.get("unity_model", None)
+        assert isinstance(unity_model, UnitYModel)
         unity_config = kwargs.get("unity_config", None)
+        assert unity_config is not None
         text_tokenizer = kwargs.get("text_tokenizer", None)
+        assert isinstance(text_tokenizer, TextTokenizer)
         return cls(unity_model, unity_config.w2v2_encoder_config, text_tokenizer, args)
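
The added assertions do double duty: they fail fast if the pipeline wiring omits a kwarg, and they narrow the Optional results of kwargs.get(...) so mypy accepts the constructor call. The same pattern in isolation, using a hypothetical helper:

```python
from typing import Any


def get_model_name(**kwargs: Any) -> str:
    # Hypothetical helper showing assert-based type narrowing.
    name = kwargs.get("model_name", None)
    # The assert validates the wiring at runtime and tells mypy that
    # `name` is a str from here on, so the .upper() call type-checks.
    assert isinstance(name, str)
    return name.upper()


print(get_model_name(model_name="seamless"))  # -> SEAMLESS
```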

+ 322 - 0
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -0,0 +1,322 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import torch
+
+from argparse import ArgumentParser, Namespace
+from torch import Tensor
+from typing import Any, Dict, List, Tuple
+
+from fairseq2.models.nllb.tokenizer import NllbTokenizer
+from fairseq2.nn.incremental_state import IncrementalStateBag
+from seamless_communication.models.monotonic_decoder import (
+    MonotonicDecoderConfig,
+    MonotonicDecoderModel,
+)
+
+from simuleval.agents import GenericAgent
+from simuleval.agents.actions import Action, ReadAction, WriteAction
+from simuleval.agents.states import AgentStates
+from simuleval.data.segments import Segment, TextSegment
+
+
+class DecoderAgentStates(AgentStates):
+    def reset(self) -> None:
+        self.source_steps = 0
+        self.target_indices: List[int] = []
+        self.tgt_lang = None
+        super().reset()
+
+    def update_source(self, segment: Segment) -> None:
+        """
+        Update states from the input segment.
+        Additionally, update the number of source steps seen so far.
+
+        Args:
+            segment (~simuleval.data.segments.Segment): input segment
+        """
+        self.source_finished = segment.finished
+        if self.tgt_lang is None and segment.tgt_lang is not None:
+            self.tgt_lang = segment.tgt_lang
+        if not segment.is_empty:
+            self.source = segment.content
+            if len(self.source) == 0 and segment.finished:
+                self.target_finished = True
+                return
+            self.source_steps = self.source.size(1)
+
+
+class OnlineTextDecoderAgent(GenericAgent):
+    """
+    Online text decoder
+    """
+
+    target_type = "text"
+
+    def __init__(
+        self,
+        model: MonotonicDecoderModel,
+        config: MonotonicDecoderConfig,
+        text_tokenizer: NllbTokenizer,
+        args: Namespace,
+    ) -> None:
+        super().__init__(args)
+        self.model = model
+        self.config = config
+        self.text_tokenizer = text_tokenizer
+
+        self.max_len_a: int = args.max_len_a
+        self.max_len_b: int = args.max_len_b
+        self.max_consecutive_writes = self.args.max_consecutive_write
+        self.min_starting_wait = args.min_starting_wait
+        self.min_starting_wait_reset = args.min_starting_wait_reset
+        self.no_early_stop = args.no_early_stop
+
+        self.device = args.device
+        self.dtype = args.dtype
+        self.eos_idx = text_tokenizer.vocab_info.eos_idx
+        token_encoder = text_tokenizer.create_encoder(lang=args.tgt_lang, mode="target")
+        prefix_tokens = token_encoder.prefix_indices
+        assert prefix_tokens is not None
+        self.prefix_tokens: List[int] = prefix_tokens.tolist()
+
+    def build_states(self) -> DecoderAgentStates:
+        return DecoderAgentStates()
+
+    def max_len(self, states: DecoderAgentStates) -> int:
+        return self.max_len_a * int(states.source.size(1)) + self.max_len_b
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--max-len-a",
+            type=int,
+            default=1,
+            help="Max length of predictions, a in ax + b",
+        )
+        parser.add_argument(
+            "--max-len-b",
+            type=int,
+            default=200,
+            help="Max length of predictions, b in ax + b",
+        )
+        parser.add_argument(
+            "--max-consecutive-write",
+            type=int,
+            default=50,
+            help="Max consecutive writes.",
+        )
+        parser.add_argument(
+            "--min-starting-wait",
+            type=int,
+            default=12,
+            help="Minimal starting waiting source steps",
+        )
+        parser.add_argument(
+            "--min-starting-wait-reset",
+            type=int,
+            default=0,
+            help="Minimal starting waiting source steps",
+        )
+        parser.add_argument(
+            "--no-early-stop",
+            action="store_true",
+            default=True,
+        )
+
+    def policy(self, states: DecoderAgentStates) -> Action:
+        raise NotImplementedError
+
+
+class MMATextDecoderAgent(OnlineTextDecoderAgent):
+    def __init__(
+        self,
+        model: MonotonicDecoderModel,
+        config: MonotonicDecoderConfig,
+        text_tokenizer: NllbTokenizer,
+        args: Namespace,
+    ) -> None:
+        super().__init__(model, config, text_tokenizer, args=args)
+
+        self.num_decoder_layers = self.config.num_decoder_layers
+
+        self.decision_threshold = args.decision_threshold
+        self.decision_method = args.decision_method
+        self.p_choose_start_layer = args.p_choose_start_layer
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        OnlineTextDecoderAgent.add_args(parser)
+        parser.add_argument(
+            "--decision-threshold",
+            type=float,
+            default=0.5,
+            help="The threshold to write an output, from 0 to 1. Small values give low latency.",
+        )
+        parser.add_argument(
+            "--decision-method",
+            type=str,
+            default="min",
+            choices=["mean", "min", "median"],
+            help="The method to determine the decision. Either average all attention heads, or just pick the smallest one",
+        )
+        parser.add_argument(
+            "--p-choose-start-layer",
+            type=int,
+            default=0,
+            help="Encoder layer from which p_choose should be considered for selection.",
+        )
+
+    @classmethod
+    def from_args(
+        cls, args: Namespace, **kwargs: Dict[str, Any]
+    ) -> MMATextDecoderAgent:
+        model = kwargs.get("monotonic_decoder_model", None)
+        config = kwargs.get("monotonic_decoder_config", None)
+        text_tokenizer = kwargs.get("text_tokenizer", None)
+
+        assert isinstance(model, MonotonicDecoderModel)
+        assert isinstance(config, MonotonicDecoderConfig)
+        assert isinstance(text_tokenizer, NllbTokenizer)
+
+        return cls(
+            model=model,
+            config=config,
+            text_tokenizer=text_tokenizer,
+            args=args,
+        )
+
+    def run_decoder(
+        self, states: DecoderAgentStates, pred_indices: List[int]
+    ) -> Tuple[int, float, Tensor]:
+        if len(pred_indices) == 0:
+            target_input = torch.tensor(
+                self.prefix_tokens + states.target_indices,
+                device=self.device,
+                dtype=torch.int64,
+            ).unsqueeze(0)
+        else:
+            target_input = torch.tensor(
+                pred_indices[-1:], device=self.device, dtype=torch.int64
+            ).unsqueeze(0)
+
+        states.source_steps = states.source.size(1)
+        torch.cuda.empty_cache()
+
+        encoder_output = states.source
+        decoder_output, _, p_choose = self.model.decode(
+            target_input, None, encoder_output, None, state_bag=self.state_bag
+        )
+
+        logits = self.model.project(decoder_output)
+
+        index = int(logits[0, -1].argmax().item())
+        _, tgt_len, src_len = p_choose.size()
+
+        p_choose = p_choose.view(self.num_decoder_layers, -1, tgt_len, src_len)
+
+        if self.decision_method == "min":
+            prob = p_choose[self.p_choose_start_layer :, :, -1, -1].min().item()
+        elif self.decision_method == "mean":
+            prob = p_choose[self.p_choose_start_layer :, :, -1, -1].mean().item()
+        else:
+            prob = p_choose[self.p_choose_start_layer :, :, -1, -1].median().item()
+
+        return index, prob, decoder_output
+
+    def postprocess(
+        self, states: DecoderAgentStates, pred_indices: Tensor, finished: bool
+    ) -> TextSegment:
+        return TextSegment(
+            content=" ".join(
+                [self.text_tokenizer.model.index_to_token(int(idx)) for idx in pred_indices]
+            ),
+            finished=finished,
+            tgt_lang=states.tgt_lang,
+        )
+
+    @torch.inference_mode()
+    def policy(self, states: DecoderAgentStates) -> Action:
+        if len(states.source) == 0:
+            return ReadAction()
+
+        if states.source_steps < self.min_starting_wait and not states.source_finished:
+            return ReadAction()
+
+        if states.target_finished:
+            return WriteAction("", finished=True)
+
+        self.state_bag = IncrementalStateBag(4096)
+
+        pred_indices: List[int] = []
+        index = None
+        prob = None
+        finished = False
+
+        while (
+            len(states.target_indices + pred_indices) < self.max_len(states)
+            and len(pred_indices) < self.max_consecutive_writes
+        ):
+            index, prob, _ = self.run_decoder(states, pred_indices)
+
+            if (
+                self.no_early_stop
+                and prob < self.decision_threshold
+                and not states.source_finished
+            ):
+                break
+            if (
+                self.no_early_stop
+                and index == self.eos_idx
+                and not states.source_finished
+            ):
+                if prob == 1.0:
+                    pred_indices = []
+                if states.source_steps < self.min_starting_wait_reset:
+                    pred_indices = []
+                    if len(states.target_indices) < 3:
+                        states.target_indices = []
+                break
+            if (
+                finished
+                or index == self.eos_idx
+                or len(states.target_indices + pred_indices) > self.max_len(states)
+            ):
+                finished = True
+                break
+
+            if (
+                not self.no_early_stop
+                and prob < self.decision_threshold
+                and not states.source_finished
+            ):
+                break
+
+            pred_indices.append(index)
+            if self.state_bag.step == 0:
+                self.state_bag.increment_step(
+                    len(self.prefix_tokens + states.target_indices)
+                )
+            else:
+                self.state_bag.increment_step()
+
+        states.target_indices += pred_indices
+
+        if len(pred_indices) > 0 or finished:
+            finished = finished or len(
+                states.target_indices + pred_indices
+            ) > self.max_len(states)
+            return WriteAction(
+                self.postprocess(states, torch.tensor(pred_indices), finished),
+                finished=finished,
+            )
+        else:
+            return ReadAction()
+
+
+class MMASpeechToTextDecoderAgent(MMATextDecoderAgent):
+    source_type = "speech"

+ 11 - 3
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -21,7 +21,10 @@ from seamless_communication.models.unity import (
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.models.monotonic_decoder import load_monotonic_decoder_model
+from seamless_communication.models.monotonic_decoder import (
+    load_monotonic_decoder_model,
+    load_monotonic_decoder_config,
+)
 
 from simuleval.agents import AgentPipeline, AgentStates
 from simuleval.data.segments import Segment
@@ -36,6 +39,7 @@ logger = logging.getLogger(__name__)
 
 
 def maybe_reset_states(states: Optional[List[Optional[AgentStates]]]) -> None:
+    assert states is not None
     for s in states:
         if s is not None:
             if isinstance(s, EarlyStoppingMixin):
@@ -46,7 +50,7 @@ def maybe_reset_states(states: Optional[List[Optional[AgentStates]]]) -> None:
 
 class UnitYPipelineMixin:
     """
-    Mixin for fairseq pipeline which works with both AgentPipeline
+    Mixin for UnitY pipeline which works with both AgentPipeline
     and TreeAgentPipeline
     """
 
@@ -79,7 +83,7 @@ class UnitYPipelineMixin:
 
     @classmethod
     def from_args(cls, args: Any) -> UnitYPipelineMixin:
-        return cls(args)
+        return cls()
 
 
 class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
@@ -125,6 +129,9 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
         unity_model = load_unity_model(asset_card, device=args.device, dtype=args.dtype)
         unity_model.eval()
 
+        monotonic_decoder_config = load_monotonic_decoder_config(
+            args.monotonic_decoder_model_name
+        )
         logger.info(
             f"Loading the Monotonic Decoder model: {args.monotonic_decoder_model_name} on device={args.device}, dtype={args.dtype}"
         )
@@ -141,6 +148,7 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
                     unity_model=unity_model,
                     unity_config=unity_config,
                     monotonic_decoder_model=monotonic_decoder_model,
+                    monotonic_decoder_config=monotonic_decoder_config,
                     text_tokenizer=text_tokenizer,
                     unit_tokenizer=unit_tokenizer,
                 )

+ 1 - 1
tests/integration/models/test_pretssel_vocoder.py

@@ -30,7 +30,7 @@ def test_pretssel_vocoder(example_rate16k_audio: AudioDecoderOutput) -> None:
     vocoder = load_mel_vocoder_model("vocoder_mel", device=device, dtype=torch.float32)
     vocoder.eval()
 
-    with torch.no_grad():
+    with torch.inference_mode():
         wav_hat = vocoder(feat).view(1, -1)
 
     audio_hat = {"sample_rate": sample_rate, "waveform": wav_hat}
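
torch.inference_mode() is a stricter, often slightly faster variant of torch.no_grad(): it additionally disables view and version-counter tracking, and tensors created inside can never be used by autograd afterwards. A small demonstration:

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.inference_mode():
    y = x * 2  # no autograd graph is recorded for y

assert not y.requires_grad
# Unlike no_grad(), y is an "inference tensor": feeding it back into an
# autograd computation later raises a RuntimeError.
```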