
Tree Pipeline (#140)

* Tree Pipeline

* Fix parent class
Abinesh Ramakrishnan 1 year ago
parent commit b5b98699c6

+ 67 - 18
src/seamless_communication/streaming/agents/unity_pipeline.py

@@ -7,7 +7,7 @@ from __future__ import annotations
 
 import logging
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 from fairseq2.assets import asset_store
@@ -27,7 +27,7 @@ from seamless_communication.streaming.agents.common import (
     AgentStates,
     EarlyStoppingMixin,
 )
-from simuleval.agents import AgentPipeline
+from simuleval.agents import AgentPipeline, TreeAgentPipeline
 from simuleval.agents.agent import GenericAgent
 from simuleval.data.segments import Segment
 
@@ -88,11 +88,8 @@ class UnitYPipelineMixin:
             type=str,
         )
 
-
-class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
-    pipeline: List[GenericAgent] = []
-
-    def __init__(self, args: Namespace):
+    @classmethod
+    def load_model(cls, args: Namespace) -> Dict[str, Any]:
         if not torch.cuda.is_available() and "cuda" in args.device:
             raise ValueError("CUDA not available, use CPU.")
 
@@ -142,25 +139,36 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
         )
         monotonic_decoder_model.eval()
 
-        self.vocoder = None
+        vocoder = None
         if args.vocoder_name is not None and output_modality == Modality.SPEECH:
-            self.vocoder = load_vocoder_model(
+            vocoder = load_vocoder_model(
                 args.vocoder_name, device=args.device, dtype=args.dtype
             )
-            self.vocoder.eval()
+            vocoder.eval()
+
+        return {
+            "unity_model": unity_model,
+            "unity_config": unity_config,
+            "monotonic_decoder_model": monotonic_decoder_model,
+            "monotonic_decoder_config": monotonic_decoder_config,
+            "text_tokenizer": text_tokenizer,
+            "unit_tokenizer": unit_tokenizer,
+            "vocoder": vocoder,
+        }
+
+
+class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
+    pipeline: List[GenericAgent] = []
+
+    def __init__(self, args: Namespace):
+        models_and_configs = self.load_model(args)
 
         module_list = []
         for p in self.pipeline:
             module_list.append(
                 p.from_args(
                     args,
-                    unity_model=unity_model,
-                    unity_config=unity_config,
-                    monotonic_decoder_model=monotonic_decoder_model,
-                    monotonic_decoder_config=monotonic_decoder_config,
-                    text_tokenizer=text_tokenizer,
-                    unit_tokenizer=unit_tokenizer,
-                    vocoder=self.vocoder,
+                    **models_and_configs,
                 )
             )
 
@@ -187,5 +195,46 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
         return output_segment
 
     @classmethod
-    def from_args(cls, args: Any) -> UnitYPipelineMixin:
+    def from_args(cls, args: Any) -> UnitYAgentPipeline:
+        return cls(args)
+
+
+class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline):
+    pipeline = {}
+
+    def __init__(self, args: Namespace):
+        models_and_configs = self.load_model(args)
+
+        assert len(self.pipeline) > 0
+        module_dict = {}
+        for module_class, children in self.pipeline.items():
+            module_dict[module_class.from_args(args, **models_and_configs)] = children
+
+        # the assembled mapping is what pop() below indexes via self.module_dict
+        self.module_dict = module_dict
+
+    @classmethod
+    def from_args(cls, args: Any) -> UnitYAgentTreePipeline:
         return cls(args)
+
+    def pop(
+        self, states: Optional[Dict[GenericAgent, Optional[AgentStates]]] = None
+    ) -> List[Segment]:
+        output_segment = super().pop(states)
+        if states is None:
+            # Not stateless
+            first_states = self.source_module.states
+        else:
+            assert len(states) == len(self.module_dict)
+            first_states = states[self.source_module]
+
+        if not first_states.source_finished and any(
+            segment.finished for segment in output_segment
+        ):
+            # An early stop.
+            # The temporary solution is to start over
+            if states is not None:
+                maybe_reset_states(states.values())
+            else:
+                self.reset()
+            for segment in output_segment:
+                segment.finished = False
+
+        return output_segment
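
For context when reading the diff, here is a minimal standalone sketch of the pattern this commit introduces: load_model collects the shared models and configs into one dict, the tree pipeline is declared as a dict mapping each agent class to its children, and every agent is built by keyword-unpacking that dict into from_args. The ToyAgent and ToyTreePipeline classes below are illustrative stand-ins only, not the real seamless_communication or simuleval agents.

# Standalone sketch (not part of the commit): toy stand-ins that mirror the
# load_model() / dict-shaped pipeline / **-unpacking pattern shown above.
from argparse import Namespace
from typing import Any, Dict, List


class ToyAgent:
    def __init__(self, args: Namespace, **resources: Any) -> None:
        self.args = args
        self.resources = resources  # shared models/configs handed down by the pipeline

    @classmethod
    def from_args(cls, args: Namespace, **resources: Any) -> "ToyAgent":
        return cls(args, **resources)


class ToyTreePipeline:
    # Dict-shaped pipeline, mirroring UnitYAgentTreePipeline.pipeline:
    # each agent class maps to the list of its child agent classes.
    pipeline: Dict[type, List[type]] = {ToyAgent: []}

    @classmethod
    def load_model(cls, args: Namespace) -> Dict[str, Any]:
        # Stand-in for UnitYPipelineMixin.load_model, which bundles the UnitY model,
        # configs, tokenizers and vocoder into a single dict.
        return {"unity_model": object(), "vocoder": None}

    def __init__(self, args: Namespace) -> None:
        models_and_configs = self.load_model(args)
        # ** forwards the shared resources as keyword arguments;
        # * would only pass the dict's keys.
        self.module_dict = {
            module_class.from_args(args, **models_and_configs): children
            for module_class, children in self.pipeline.items()
        }


if __name__ == "__main__":
    tree = ToyTreePipeline(Namespace(device="cpu"))
    print({type(agent).__name__: children for agent, children in tree.module_dict.items()})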