Ability to change tgt_lang dynamically during streaming inference. (#121)

* Text decoder agent improvements

* Fixing overlooked issues

* revert while loop simplification attempt
Abinesh Ramakrishnan 1 year ago
parent
commit
bc88690d56

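For orientation: when states.tgt_lang is set, the text decoder swaps the language tag at the end of its decoder prefix before generating, so the target language can change between utterances of the same streaming session without rebuilding the agent. The sketch below imitates that prefix rewrite in isolation; the toy vocabulary, token ids, and prefix layout are hypothetical, while the real agent resolves the tag via text_tokenizer.model.token_to_index as shown in the diff.

from typing import Dict, List, Optional

# Hypothetical toy vocabulary; real indices come from the NLLB tokenizer.
TOY_VOCAB: Dict[str, int] = {"</s>": 2, "__eng_Latn__": 10, "__spa_Latn__": 11}


def rewrite_prefix_lang(prefix_indices: List[int], tgt_lang: Optional[str]) -> List[int]:
    """Replace the trailing language tag of the decoder prefix, mirroring
    OnlineTextDecoderAgent.enforce_tgt_lang_in_prefix added in this commit."""
    if tgt_lang:
        prefix_indices = list(prefix_indices)
        prefix_indices[-1] = TOY_VOCAB[f"__{tgt_lang}__"]
    return prefix_indices


# Prefix built at agent construction time for the default tgt_lang (English here).
prefix = [TOY_VOCAB["</s>"], TOY_VOCAB["__eng_Latn__"]]
# Later, a state carrying tgt_lang="spa_Latn" switches decoding to Spanish.
print(rewrite_prefix_lang(prefix, "spa_Latn"))  # [2, 11]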
+ 20 - 14
src/seamless_communication/streaming/agents/online_text_decoder.py

@@ -5,28 +5,26 @@
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations

-import torch
-
 from argparse import ArgumentParser, Namespace
-from torch import Tensor
 from typing import Any, Dict, List, Tuple

+import torch
 from fairseq2.models.nllb.tokenizer import NllbTokenizer
 from fairseq2.nn.incremental_state import IncrementalStateBag
 from seamless_communication.models.monotonic_decoder import (
     MonotonicDecoderConfig,
     MonotonicDecoderModel,
 )
-
 from simuleval.agents import GenericAgent
 from simuleval.agents.actions import Action, ReadAction, WriteAction
 from simuleval.agents.states import AgentStates
 from simuleval.data.segments import Segment, TextSegment
+from torch import Tensor


 class DecoderAgentStates(AgentStates):
     def reset(self) -> None:
-        self.source_steps = 0
+        self.source_len = 0
         self.target_indices: List[int] = []
         self.tgt_lang = None
         super().reset()
@@ -47,7 +45,7 @@ class DecoderAgentStates(AgentStates):
             if len(self.source) == 0 and segment.finished:
                 self.target_finished = True
                 return
-            self.source_steps = self.source.size(1)
+            self.source_len = self.source.size(1)


 class OnlineTextDecoderAgent(GenericAgent):
@@ -80,9 +78,9 @@ class OnlineTextDecoderAgent(GenericAgent):
         self.dtype = args.dtype
         self.eos_idx = text_tokenizer.vocab_info.eos_idx
         token_encoder = text_tokenizer.create_encoder(lang=args.tgt_lang, mode="target")
-        prefix_tokens = token_encoder.prefix_indices
-        assert prefix_tokens is not None
-        self.prefix_tokens: List[int] = prefix_tokens.tolist()
+        prefix_indices = token_encoder.prefix_indices
+        assert prefix_indices is not None
+        self.prefix_indices: List[int] = prefix_indices.tolist()

     def build_states(self) -> DecoderAgentStates:
         return DecoderAgentStates()
@@ -131,6 +129,12 @@ class OnlineTextDecoderAgent(GenericAgent):
     def policy(self, states: DecoderAgentStates) -> Action:
         raise NotImplementedError

+    def enforce_tgt_lang_in_prefix(self, states: DecoderAgentStates) -> None:
+        if states.tgt_lang:
+            tgt_lang_tag = f"__{states.tgt_lang}__"
+            tgt_lang_tag_idx = self.text_tokenizer.model.token_to_index(tgt_lang_tag)
+            self.prefix_indices[-1] = tgt_lang_tag_idx
+

 class MMATextDecoderAgent(OnlineTextDecoderAgent):
     def __init__(
@@ -194,8 +198,9 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         self, states: DecoderAgentStates, pred_indices: List[int]
     ) -> Tuple[int, float, Tensor]:
         if len(pred_indices) == 0:
+            self.enforce_tgt_lang_in_prefix(states)
             target_input = torch.tensor(
-                self.prefix_tokens + states.target_indices,
+                self.prefix_indices + states.target_indices,
                 device=self.device,
                 dtype=torch.int64,
             ).unsqueeze(0)
@@ -204,7 +209,6 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
                 pred_indices[-1:], device=self.device, dtype=torch.int64
             ).unsqueeze(0)

-        states.source_steps = states.source.size(1)
         torch.cuda.empty_cache()

         encoder_output = states.source
@@ -244,7 +248,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
         if len(states.source) == 0:
             return ReadAction()

-        if states.source_steps < self.min_starting_wait and not states.source_finished:
+        if states.source_len < self.min_starting_wait and not states.source_finished:
             return ReadAction()

         if states.target_finished:
@@ -255,6 +259,8 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):

         self.state_bag = IncrementalStateBag(4096)

+        states.source_len = states.source.size(1)
+
         pred_indices: List[int] = []
         index = None
         prob = None
@@ -279,7 +285,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             ):
                 if prob == 1.0:
                     pred_indices = []
-                if states.source_steps < self.min_starting_wait_reset:
+                if states.source_len < self.min_starting_wait_reset:
                     pred_indices = []
                     if len(states.target_indices) < 3:
                         states.target_indices = []
@@ -302,7 +308,7 @@ class MMATextDecoderAgent(OnlineTextDecoderAgent):
             pred_indices.append(index)
             if self.state_bag.step == 0:
                 self.state_bag.increment_step(
-                    len(self.prefix_tokens + states.target_indices)
+                    len(self.prefix_indices + states.target_indices)
                 )
             else:
                 self.state_bag.increment_step()

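Beyond the rename from source_steps to source_len, the assignment now happens once at the top of policy() rather than inside the decode step, and the value gates decoding against min_starting_wait. A minimal standalone sketch of that gate, with hypothetical frame counts (the real agent reads states.source.size(1)):

from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyStates:
    source: List[float] = field(default_factory=list)  # stand-in for encoder output frames
    source_len: int = 0
    source_finished: bool = False


def ready_to_decode(states: ToyStates, min_starting_wait: int) -> bool:
    """Mirror of the policy() gate: refresh source_len, then read more or decode."""
    states.source_len = len(states.source)  # the agent uses states.source.size(1)
    return states.source_finished or states.source_len >= min_starting_wait


states = ToyStates(source=[0.0] * 8)
print(ready_to_decode(states, min_starting_wait=12))  # False -> ReadAction
states.source += [0.0] * 8
print(ready_to_decode(states, min_starting_wait=12))  # True  -> run the decoder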
+ 11 - 15
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -4,32 +4,29 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 from __future__ import annotations
-from simuleval.agents.agent import GenericAgent

 import logging
-import torch
-
 from argparse import ArgumentParser, Namespace
 from typing import Any, List, Optional

+import torch
 from fairseq2.assets import asset_store
-from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
 from seamless_communication.inference.translator import Modality, Translator
+from seamless_communication.models.monotonic_decoder import (
+    load_monotonic_decoder_config,
+    load_monotonic_decoder_model,
+)
 from seamless_communication.models.unity import (
     load_unity_config,
     load_unity_model,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.models.monotonic_decoder import (
-    load_monotonic_decoder_model,
-    load_monotonic_decoder_config,
-)
-
+from seamless_communication.streaming.agents.mixins import EarlyStoppingMixin
 from simuleval.agents import AgentPipeline, AgentStates
+from simuleval.agents.agent import GenericAgent
 from simuleval.data.segments import Segment

-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
@@ -81,16 +78,11 @@ class UnitYPipelineMixin:
             type=str,
         )

-    @classmethod
-    def from_args(cls, args: Any) -> UnitYPipelineMixin:
-        return cls()
-

 class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
     pipeline: List[GenericAgent] = []

     def __init__(self, args: Namespace):
-
         if not torch.cuda.is_available() and "cuda" in args.device:
             raise ValueError("CUDA not available, use CPU.")
 
 
@@ -175,3 +167,7 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
             output_segment.finished = False

         return output_segment
+
+    @classmethod
+    def from_args(cls, args: Any) -> UnitYPipelineMixin:
+        return cls(args)
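
The from_args constructor moves from UnitYPipelineMixin onto UnitYAgentPipeline and, unlike the removed version, forwards the parsed arguments to the constructor instead of calling cls() with none. A standalone sketch of the pattern (a hypothetical Pipeline class, not the real one):

from argparse import Namespace


class Pipeline:
    def __init__(self, args: Namespace) -> None:
        self.device = args.device

    @classmethod
    def from_args(cls, args: Namespace) -> "Pipeline":
        # The removed mixin version returned cls() and dropped the parsed
        # arguments; forwarding them keeps the constructor contract intact.
        return cls(args)


print(Pipeline.from_args(Namespace(device="cpu")).device)  # prints: cpu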