Quellcode durchsuchen

Enable joint s2t + s2s output for demo (#146)

* Enable joint s2t + s2s output for demo

* mypy
Anna Sun vor 1 Jahr
Ursprung
Commit
8373db2ee5

+ 23 - 0
src/seamless_communication/streaming/agents/detokenizer.py

@@ -14,6 +14,10 @@ from seamless_communication.streaming.agents.common import (
     AgentStates,
     NoUpdateTargetMixin,
 )
+from seamless_communication.streaming.agents.online_text_decoder import (
+    UnitYTextDecoderOutput,
+)
+from simuleval.data.segments import Segment, EmptySegment
 
 
 class DetokenizerAgent(NoUpdateTargetMixin, TextToTextAgent):  # type: ignore
@@ -54,3 +58,22 @@ class DetokenizerAgent(NoUpdateTargetMixin, TextToTextAgent):  # type: ignore
 
     def decode(self, x: str) -> str:
         return x.replace(" ", "").replace("\u2581", " ").strip()  # drop spaces, map U+2581 to a space, trim both ends
+
+
+class UnitYDetokenizerAgentStates(AgentStates):
+    def update_source(self, segment: Segment) -> None:
+        """
+        Record `segment.finished`, then append the UnitYTextDecoderOutput tokens of a non-empty segment to the source.
+        """
+        self.source_finished = segment.finished
+        if isinstance(segment, EmptySegment):
+            return
+        # Otherwise a TextSegment; its content is annotated as UnitYTextDecoderOutput
+        segment_content: UnitYTextDecoderOutput = segment.content
+        token = segment_content.tokens
+        self.source += token
+
+
+class UnitYDetokenizerAgent(DetokenizerAgent):  # DetokenizerAgent whose states are UnitYDetokenizerAgentStates
+    def build_states(self) -> UnitYDetokenizerAgentStates:
+        return UnitYDetokenizerAgentStates()

+ 29 - 1
src/seamless_communication/streaming/agents/mma_m4t_s2st.py

@@ -16,9 +16,14 @@ from seamless_communication.streaming.agents.online_text_decoder import (
 from seamless_communication.streaming.agents.online_unit_decoder import (
     NARUnitYUnitDecoderAgent,
 )
+from seamless_communication.streaming.agents.silero_vad import SileroVADAgent
 from seamless_communication.streaming.agents.online_vocoder import VocoderAgent
 
-from seamless_communication.streaming.agents.unity_pipeline import UnitYAgentPipeline
+from seamless_communication.streaming.agents.detokenizer import UnitYDetokenizerAgent
+from seamless_communication.streaming.agents.unity_pipeline import (
+    UnitYAgentPipeline,
+    UnitYAgentTreePipeline,
+)
 from simuleval.utils import entrypoint
 
 
@@ -31,3 +36,26 @@ class MonotonicM4TS2STAgent(UnitYAgentPipeline):
         NARUnitYUnitDecoderAgent,
         VocoderAgent,
     ]
+
+
+class MonotonicM4TS2STVADAgent(UnitYAgentPipeline):
+    pipeline = [  # ordered list of agent classes; SileroVADAgent is the first entry
+        SileroVADAgent,
+        OnlineFeatureExtractorAgent,
+        OfflineWav2VecBertEncoderAgent,
+        UnitYMMATextDecoderAgent,
+        NARUnitYUnitDecoderAgent,
+        VocoderAgent,
+    ]
+
+
+class MonotonicM4TS2STJointVADAgent(UnitYAgentTreePipeline):
+    pipeline = {  # maps each agent class to the list of its child agent classes
+        SileroVADAgent: [OnlineFeatureExtractorAgent],
+        OnlineFeatureExtractorAgent: [OfflineWav2VecBertEncoderAgent],
+        OfflineWav2VecBertEncoderAgent: [UnitYMMATextDecoderAgent],
+        UnitYMMATextDecoderAgent: [UnitYDetokenizerAgent, NARUnitYUnitDecoderAgent],  # two children
+        UnitYDetokenizerAgent: [],  # no children
+        NARUnitYUnitDecoderAgent: [VocoderAgent],
+        VocoderAgent: [],  # no children
+    }

+ 14 - 6
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -208,7 +208,9 @@ class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline):  # type: ig
         assert len(self.pipeline) > 0
         module_dict = {}
         for module_class, children in self.pipeline.items():
-            module_dict[module_class.from_args(args, *models_and_configs)] = children
+            module_dict[module_class.from_args(args, **models_and_configs)] = children
+
+        super().__init__(module_dict, args)
 
     @classmethod
     def from_args(cls, args: Any) -> UnitYAgentPipeline:
@@ -225,16 +227,22 @@ class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline):  # type: ig
             assert len(states) == len(self.module_dict)
             first_states = states[self.source_module]
 
-        if not first_states.source_finished and any(
-            segment.finished for segment in output_segment
-        ):
+        if isinstance(output_segment, list):
+            finished = any(segment.finished for segment in output_segment)
+        else:
+            # case when output_index is used
+            finished = output_segment.finished
+        if not first_states.source_finished and finished:
             # An early stop.
             # The temporary solution is to start over
             if states is not None:
                 maybe_reset_states(states)
             else:
                 self.reset()
-            for segment in output_segment:
-                segment.finished = False
+            if isinstance(output_segment, list):
+                for segment in output_segment:
+                    segment.finished = False
+            else:
+                output_segment.finished = False
 
         return output_segment  # type: ignore[no-any-return]