
Ability to change tgt_lang dynamically during streaming inference. (#121)

* Text decoder agent improvements

* Fixing overlooked issues

* Revert while-loop simplification attempt
Abinesh Ramakrishnan, 1 year ago
commit bc88690d56

+ 20 - 14
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -5,28 +5,26 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
 
-import torch
-
 from argparse import ArgumentParser, Namespace
-from torch import Tensor
 from typing import Any, Dict, List, Tuple
 
+import torch
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from seamless_communication.models.monotonic_decoder import (
     MonotonicDecoderConfig,
     MonotonicDecoderModel,
 )
-
 from simuleval.agents import GenericAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
 from simuleval.agents.states import AgentStates
 from simuleval.data.segments import Segment, TextSegment
+from torch import Tensor
 
 
 class DecoderAgentStates(AgentStates):
     def reset(self) -> None:
-        self.source_steps = 0
+        self.source_len = 0
         self.target_indices: List[int] = []
         self.tgt_lang = None
         super().reset()
@@ -47,7 +45,7 @@ class DecoderAgentStates(AgentStates):
             if len(self.source) == 0 and segment.finished:
                 self.target_finished = True
                 return
-            self.source_steps = self.source.size(1)
+            self.source_len = self.source.size(1)
 
 
 class OnlineTextDecoderAgent(GenericAgent):
@@ -80,9 +78,9 @@ class OnlineTextDecoderAgent(GenericAgent):
         self.dtype = args.dtype
         self.eos_idx = text_tokenizer.vocab_info.eos_idx
         token_encoder = text_tokenizer.create_encoder(lang=args.tgt_lang, mode="target")
-        prefix_tokens = token_encoder.prefix_indices
-        assert prefix_tokens is not None
-        self.prefix_tokens: List[int] = prefix_tokens.tolist()
+        prefix_indices = token_encoder.prefix_indices
+        assert prefix_indices is not None
+        self.prefix_indices: List[int] = prefix_indices.tolist()
 
     def build_states(self) -> DecoderAgentStates:
         return DecoderAgentStates()
@@ -131,6 +129,12 @@ class OnlineTextDecoderAgent(GenericAgent):
     def policy(self, states: DecoderAgentStates) -> Action:
         raise NotImplementedError
 
+    def enforce_tgt_lang_in_prefix(self, states: DecoderAgentStates) -> None:
+        if states.tgt_lang:
+            tgt_lang_tag = f"__{states.tgt_lang}__"
+            tgt_lang_tag_idx = self.text_tokenizer.model.token_to_index(tgt_lang_tag)
+            self.prefix_indices[-1] = tgt_lang_tag_idx
+
 
 class MMATextDecoderAgent(OnlineTextDecoderAgent):
     def __init__(
@@ -194,8 +198,9 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         self, states: DecoderAgentStates, pred_indices: List[int]
     ) -> Tuple[int, float, Tensor]:
         if len(pred_indices) == 0:
+            self.enforce_tgt_lang_in_prefix(states)
             target_input = torch.tensor(
-                self.prefix_tokens + states.target_indices,
+                self.prefix_indices + states.target_indices,
                 device=self.device,
                 dtype=torch.int64,
             ).unsqueeze(0)
@@ -204,7 +209,6 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
                 pred_indices[-1:], device=self.device, dtype=torch.int64
             ).unsqueeze(0)
 
-        states.source_steps = states.source.size(1)
         torch.cuda.empty_cache()
 
         encoder_output = states.source
@@ -244,7 +248,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         if len(states.source) == 0:
             return ReadAction()
 
-        if states.source_steps < self.min_starting_wait and not states.source_finished:
+        if states.source_len < self.min_starting_wait and not states.source_finished:
             return ReadAction()
 
         if states.target_finished:
@@ -255,6 +259,8 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
 
         self.state_bag = IncrementalStateBag(4096)
 
+        states.source_len = states.source.size(1)
+
         pred_indices: List[int] = []
         index = None
         prob = None
@@ -279,7 +285,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             ):
                 if prob == 1.0:
                     pred_indices = []
-                if states.source_steps < self.min_starting_wait_reset:
+                if states.source_len < self.min_starting_wait_reset:
                     pred_indices = []
                     if len(states.target_indices) < 3:
                         states.target_indices = []
@@ -302,7 +308,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             pred_indices.append(index)
             if self.state_bag.step == 0:
                 self.state_bag.increment_step(
-                    len(self.prefix_tokens + states.target_indices)
+                    len(self.prefix_indices + states.target_indices)
                 )
             else:
                 self.state_bag.increment_step()

+ 11 - 15
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -4,32 +4,29 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
-from simuleval.agents.agent import GenericAgent
 
 import logging
-import torch
-
 from argparse import ArgumentParser, Namespace
 from typing import Any, List, Optional
 
+import torch
 from fairseq2.assets import asset_store
-from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
 from seamless_communication.inference.translator import Modality, Translator
+from seamless_communication.models.monotonic_decoder import (
+    load_monotonic_decoder_config,
+    load_monotonic_decoder_model,
+)
 from seamless_communication.models.unity import (
     load_unity_config,
     load_unity_model,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.models.monotonic_decoder import (
-    load_monotonic_decoder_model,
-    load_monotonic_decoder_config,
-)
-
+from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
 from simuleval.agents import AgentPipeline, AgentStates
+from simuleval.agents.agent import GenericAgent
 from simuleval.data.segments import Segment
 
-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
@@ -81,16 +78,11 @@ class UnitYPipelineMixin:
             type=str,
         )
 
-    @classmethod
-    def from_args(cls, args: Any) -> UnitYPipelineMixin:
-        return cls()
-
 
 class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
     pipeline: List[GenericAgent] = []
 
     def __init__(self, args: Namespace):
-
         if not torch.cuda.is_available() and "cuda" in args.device:
             raise ValueError("CUDA not available, use CPU.")
 
@@ -175,3 +167,7 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
             output_segment.finished = False
 
         return output_segment
+
+    @classmethod
+    def from_args(cls, args: Any) -> UnitYPipelineMixin:
+        return cls(args)
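
A short sketch of why from_args moved: the mixin's classmethod returned cls() and silently discarded args, so the pipeline's __init__(args) could never receive its configuration; defining from_args on the concrete pipeline and forwarding args restores configurable construction. The class below is an illustrative stand-in, not the real UnitYAgentPipeline.

from argparse import Namespace

class PipelineSketch:
    def __init__(self, args: Namespace) -> None:
        # The real __init__ does device checks and model loading;
        # a single field is enough to show the flow here.
        self.device = args.device

    @classmethod
    def from_args(cls, args: Namespace) -> "PipelineSketch":
        return cls(args)  # forward args instead of calling cls() bare

pipeline = PipelineSketch.from_args(Namespace(device="cpu"))
print(pipeline.device)  # cpu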