Browse Source

Introduce online unit decoder SimulEval agent. (#115)

* Initial commit for unit decoder agent.

* Cleanup and polish the unit decoder agent

* Remove spm detok agent.

* Changes to get the pipeline working (#129)

* Incrementing state_bag when we append a comma to the model output.

* Fix bug in getting the token index of comma.

* Clean up online unit decoder, just calling the model's forward().

* Fix bug in unity generator.

---------

Co-authored-by: ibanesh <3632454+ibanesh@users.noreply.github.com>
Kaushik Ram Sadagopan 1 year ago
parent
commit
6fadf9e320

+ 3 - 3
src/seamless_communication/inference/generator.py

@@ -252,7 +252,7 @@ class UnitYGenerator:
             )
             unit_seqs, _ = unit_gen_output.collate()
         else:
-            unit_decoder_output, decoder_padding_mask, _ = self.model.t2u_model(
+            t2u_model_output, decoder_padding_mask, _ = self.model.t2u_model(
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
                 text_seqs=text_seqs,
@@ -260,10 +260,10 @@ class UnitYGenerator:
                 film_cond_emb=prosody_encoder_out,
             )
             # (B, S_unit, V_unit)
-            unit_seqs = unit_decoder_output.logits.argmax(dim=2)
+            unit_seqs = t2u_model_output.logits.argmax(dim=2)
             # Apply the padding mask to the generated units.
             unit_seqs = apply_padding_mask(
-                unit_seqs, decoder_padding_mask, unit_decoder_output.vocab_info.pad_idx
+                unit_seqs, decoder_padding_mask, t2u_model_output.vocab_info.pad_idx
             )
 
         # Convert to speech units.

+ 3 - 0
src/seamless_communication/streaming/agents/__init__.py

@@ -4,6 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from seamless_communication.streaming.agents.mma_m4t_s2st import (
+    MonotonicM4TS2STAgent as MonotonicM4TS2STAgent,
+)
 from seamless_communication.streaming.agents.mma_m4t_s2t import (
     MonotonicM4TS2TAgent as MonotonicM4TS2TAgent,
 )

+ 30 - 0
src/seamless_communication/streaming/agents/mma_m4t_s2st.py

@@ -0,0 +1,30 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from seamless_communication.streaming.agents.offline_w2v_bert_encoder import (
+    OfflineWav2VecBertEncoderAgent,
+)
+from seamless_communication.streaming.agents.online_feature_extractor import (
+    OnlineFeatureExtractorAgent,
+)
+from seamless_communication.streaming.agents.online_text_decoder import (
+    UnitYMMATextDecoderAgent,
+)
+from seamless_communication.streaming.agents.online_unit_decoder import (
+    NARUnitYUnitDecoderAgent,
+)
+from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
+from simuleval.utils import entrypoint
+
+
+@entrypoint
+class MonotonicM4TS2STAgent(UnitYAgentPipeline):
+    pipeline = [
+        OnlineFeatureExtractorAgent,
+        OfflineWav2VecBertEncoderAgent,
+        UnitYMMATextDecoderAgent,
+        NARUnitYUnitDecoderAgent,
+    ]

+ 78 - 9
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -6,7 +6,8 @@
 from __future__ import annotations
 
 from argparse import ArgumentParser, Namespace
-from typing import Any, Dict, List, Set, Tuple
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 import torch
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
@@ -77,13 +78,19 @@ class OnlineTextDecoderAgent(GenericAgent):
         self.device = args.device
         self.dtype = args.dtype
         self.eos_idx = text_tokenizer.vocab_info.eos_idx
-        if getattr(args, "tgt_lang", None) and getattr(args, "prefix_tgt_lang", None):
+        if (
+            hasattr(args, "tgt_lang")
+            and hasattr(args, "prefix_tgt_lang")
+            and args.tgt_lang is not None
+            and args.prefix_tgt_lang is not None
+        ):
             assert args.tgt_lang == args.prefix_tgt_lang
         tgt_lang = getattr(args, "tgt_lang", None) or getattr(
             args, "prefix_tgt_lang", None
         )
-        token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
-        prefix_indices = token_encoder.prefix_indices
+        assert tgt_lang is not None
+        self.token_encoder = text_tokenizer.create_encoder(lang=tgt_lang, mode="target")
+        prefix_indices = self.token_encoder.prefix_indices
         assert prefix_indices is not None
         self.prefix_indices: List[int] = prefix_indices.tolist()
 
@@ -218,8 +225,6 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
                 pred_indices[-1:], device=self.device, dtype=torch.int64
             ).unsqueeze(0)
 
-        torch.cuda.empty_cache()
-
         encoder_output = states.source
         decoder_output, _, p_choose = self.model.decode(
             target_input, None, encoder_output, None, state_bag=self.state_bag
@@ -246,7 +251,11 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         return index, prob, decoder_output
 
     def postprocess(
-        self, states: DecoderAgentStates, pred_indices: Tensor, finished: bool
+        self,
+        states: DecoderAgentStates,
+        pred_indices: List[int],
+        finished: bool,
+        decoder_features_out: Optional[Tensor] = None,
     ) -> TextSegment:
         return TextSegment(
             content=" ".join(
@@ -319,12 +328,19 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         prob = None
         finished = False
         blocked_ngrams = self.get_blocked_ngrams(states.target_indices)
+        decoder_features_out = None
 
         while (
             len(states.target_indices + pred_indices) < self.max_len(states)
             and len(pred_indices) < self.max_consecutive_writes
         ):
-            index, prob, _ = self.run_decoder(states, pred_indices)
+            index, prob, decoder_features = self.run_decoder(states, pred_indices)
+
+            if decoder_features_out is None:
+                decoder_features_out = decoder_features.new(0)
+            decoder_features_out = torch.cat(
+                [decoder_features_out, decoder_features], dim=1
+            )
 
             if (
                 self.no_early_stop
@@ -366,7 +382,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             ) > self.max_len(states)
             states.ngram_block_count = 0
             return WriteAction(
-                self.postprocess(states, torch.tensor(pred_indices), finished),
+                self.postprocess(states, pred_indices, finished, decoder_features_out),
                 finished=finished,
             )
         else:
@@ -375,3 +391,56 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
 
 class MMASpeechToTextDecoderAgent(MMATextDecoderAgent):
     source_type = "speech"
+
+
+@dataclass
+class UnitYTextDecoderOutput:
+    decoder_features: Tensor
+    tokens: List[str]
+    target_indices: Optional[Tensor] = None
+
+
+class UnitYMMATextDecoderAgent(MMASpeechToTextDecoderAgent):
+    """
+    MMA UnitY text decoder agent which just prepares the decoder
+    output for the downstream agent.
+    """
+
+    def postprocess(
+        self,
+        states: DecoderAgentStates,
+        pred_indices: List[int],
+        finished: bool,
+        decoder_features_out: Optional[Tensor] = None,
+    ) -> TextSegment:
+        tokens: List[str] = [
+            self.text_tokenizer.model.index_to_token(idx) for idx in pred_indices
+        ]
+        assert decoder_features_out is not None
+        token_list = self.prefix_indices + states.target_indices
+        if (
+            len(pred_indices) > 0
+            and pred_indices[-1] != self.text_tokenizer.vocab_info.eos_idx
+        ):
+            # Append "," to make speech smooth
+            # TODO: a temporary solution.
+            ending_token_index = self.text_tokenizer.model.token_to_index(",")
+            token_list.append(ending_token_index)
+            self.state_bag.increment_step()
+
+            _, _, decoder_features = self.run_decoder(states, [ending_token_index])
+            decoder_features_out = torch.cat(
+                [decoder_features_out, decoder_features], dim=1
+            )
+
+        target_input = torch.tensor(
+            token_list,
+            device=self.device,
+            dtype=torch.int64,
+        ).unsqueeze(0)
+
+        return TextSegment(
+            content=UnitYTextDecoderOutput(decoder_features_out, tokens, target_input),
+            finished=finished,
+            tgt_lang=states.tgt_lang,
+        )

+ 156 - 0
src/seamless_communication/streaming/agents/online_unit_decoder.py

@@ -0,0 +1,156 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+from argparse import ArgumentParser, Namespace
+from typing import Any, List, Optional
+
+import torch
+from seamless_communication.models.unity.model import UnitYModel, UnitYNART2UModel
+from seamless_communication.models.unity.unit_tokenizer import UnitTokenizer
+from seamless_communication.streaming.agents.online_text_decoder import (
+    UnitYTextDecoderOutput,
+)
+from simuleval.agents import GenericAgent
+from simuleval.agents.actions import Action, ReadAction, WriteAction
+from simuleval.agents.states import AgentStates
+from simuleval.data.segments import Segment, TextSegment
+
+
+class NARUnitDecoderAgentStates(AgentStates):
+    def reset(self) -> None:
+        self.source_token_list: List[str] = []
+        self.source_indices: Optional[torch.Tensor] = None
+        self.duration_start_index: int = 0
+        self.tgt_lang = None
+        super().reset()
+
+    def update_source(self, segment: Segment) -> None:
+        """
+        Update states from input segment
+        Additionlly update incremental states
+
+        Args:
+            segment (~simuleval.agents.segments.Segment): input segment
+        """
+        self.source_finished = segment.finished
+        if self.tgt_lang is None and segment.tgt_lang is not None:
+            self.tgt_lang = segment.tgt_lang
+        if segment.is_empty:
+            if segment.finished:
+                self.target_finished = True
+            return
+        segment_content: UnitYTextDecoderOutput = segment.content
+        content = segment_content.decoder_features
+        token = segment_content.tokens
+        self.source_indices = segment_content.target_indices
+        self.source_token_list += token
+        self.source = content
+
+
+class NARUnitYUnitDecoderAgent(GenericAgent):
+    """Non-autoregressive unit decoder"""
+
+    source_type = "text"
+    target_type = "text"
+
+    def __init__(
+        self, model: UnitYNART2UModel, tokenizer: UnitTokenizer, args: Namespace
+    ) -> None:
+        self.model = model
+        self.tokenizer = tokenizer
+        self.min_unit_chunk_size = args.min_unit_chunk_size
+        self.d_factor = args.d_factor
+        self.device = args.device
+        self.dtype = args.dtype
+        self.token_decoder = self.tokenizer.create_decoder()
+        super().__init__(args)
+
+    def build_states(self) -> NARUnitDecoderAgentStates:
+        return NARUnitDecoderAgentStates()
+
+    @property
+    def max_len(self) -> int:
+        return 200
+
+    @staticmethod
+    def add_args(parser: ArgumentParser) -> None:
+        parser.add_argument(
+            "--min-unit-chunk-size",
+            type=int,
+            required=True,
+            help="Minimal units to produce every chunk",
+        )
+        parser.add_argument(
+            "--d-factor",
+            type=float,
+            default=1.0,
+            help="scaling factor for duration prediction",
+        )
+
+    @torch.inference_mode()
+    def policy(self, states: NARUnitDecoderAgentStates) -> Action:
+        if states.target_finished:
+            return WriteAction("", finished=True)
+
+        if len(states.source_token_list) < 2:
+            if not states.source_finished:
+                return ReadAction()
+            else:
+                return WriteAction("", finished=True)
+
+        model_output, _, durations = self.model(
+            text_decoder_output=states.source,
+            text_decoder_padding_mask=None,
+            text_seqs=states.source_indices,
+            duration_factor=self.d_factor,
+        )
+        durations = durations[0]
+
+        if states.source_finished and states.duration_start_index > 0:
+            # We have to consider one more word for EOS, because we append an EOS at the end.
+            if sum(durations[states.duration_start_index :]) == 0:
+                # If you reach here, it means that the last source token is a silence token (e.g. punctuations)
+                # In that case no need to consider one more token.
+                return WriteAction("", finished=True)
+            else:
+                states.duration_start_index = max(states.duration_start_index - 1, 0)
+
+        current_duration = sum(durations[states.duration_start_index :])
+
+        if current_duration < self.min_unit_chunk_size:
+            if not states.source_finished:
+                # if current untranslated source result less than self.min_unit_chunk_size units
+                return ReadAction()
+            else:
+                if current_duration == 0:
+                    return WriteAction("", finished=True)
+
+        unit_seqs = model_output.logits[0].argmax(dim=-1)
+        index_start_offset = sum(durations[: states.duration_start_index])
+        unit_seqs = unit_seqs[index_start_offset:].unsqueeze(0)
+        units = self.token_decoder(unit_seqs)
+
+        # minus one because we add a ending_token on each s2t output phrase
+        states.duration_start_index = len(durations) - 1
+
+        return WriteAction(
+            TextSegment(
+                content=units,
+                finished=states.source_finished,
+                tgt_lang=states.tgt_lang,
+            ),
+            finished=states.source_finished,
+        )
+
+    @classmethod
+    def from_args(cls, args: Namespace, **kwargs: Any) -> NARUnitYUnitDecoderAgent:
+        unity_model: UnitYModel = kwargs.get("unity_model", None)
+        unit_tokenizer: UnitTokenizer = kwargs.get("unit_tokenizer", None)
+        assert unity_model.t2u_model is not None and isinstance(
+            unity_model.t2u_model, UnitYNART2UModel
+        )
+        return cls(model=unity_model.t2u_model, tokenizer=unit_tokenizer, args=args)

+ 2 - 2
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -62,13 +62,13 @@ class UnitYPipelineMixin:
             "--unity-model-name",
             type=str,
             help="Unity model name.",
-            default="unity_sans_decoder",
+            default="seamless_streaming_unity",
         )
         parser.add_argument(
             "--monotonic-decoder-model-name",
             type=str,
             help="Monotonic decoder model name.",
-            default="monotonic_decoder",
+            default="seamless_streaming_monotonic_decoder",
         )
         parser.add_argument(
             "--sample-rate",