@@ -7,7 +7,7 @@ from __future__ import annotations
 
 import logging
 from argparse import ArgumentParser, Namespace
-from typing import Any, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 from fairseq2.assets import asset_store
@@ -27,7 +27,7 @@ from seamless_communication.streaming.agents.common import (
     AgentStates,
     EarlyStoppingMixin,
 )
-from simuleval.agents import AgentPipeline
+from simuleval.agents import AgentPipeline, TreeAgentPipeline
 from simuleval.agents.agent import GenericAgent
 from simuleval.data.segments import Segment
 
@@ -88,11 +88,8 @@ class UnitYPipelineMixin:
             type=str,
         )
 
-
-class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
-    pipeline: List[GenericAgent] = []
-
-    def __init__(self, args: Namespace):
+    @classmethod
+    def load_model(cls, args: Namespace) -> Dict[str, Any]:
         if not torch.cuda.is_available() and "cuda" in args.device:
             raise ValueError("CUDA not available, use CPU.")
 
@@ -142,25 +139,36 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
         )
         monotonic_decoder_model.eval()
 
-        self.vocoder = None
+        vocoder = None
         if args.vocoder_name is not None and output_modality == Modality.SPEECH:
-            self.vocoder = load_vocoder_model(
+            vocoder = load_vocoder_model(
                 args.vocoder_name, device=args.device, dtype=args.dtype
             )
-            self.vocoder.eval()
+            vocoder.eval()
+
+        return {
+            "unity_model": unity_model,
+            "unity_config": unity_config,
+            "monotonic_decoder_model": monotonic_decoder_model,
+            "monotonic_decoder_config": monotonic_decoder_config,
+            "text_tokenizer": text_tokenizer,
+            "unit_tokenizer": unit_tokenizer,
+            "vocoder": vocoder,
+        }
+
+
+class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
+    pipeline: List[GenericAgent] = []
+
+    def __init__(self, args: Namespace):
+        models_and_configs = self.load_model(args)
 
         module_list = []
         for p in self.pipeline:
             module_list.append(
                 p.from_args(
                     args,
-                    unity_model=unity_model,
-                    unity_config=unity_config,
-                    monotonic_decoder_model=monotonic_decoder_model,
-                    monotonic_decoder_config=monotonic_decoder_config,
-                    text_tokenizer=text_tokenizer,
-                    unit_tokenizer=unit_tokenizer,
-                    vocoder=self.vocoder,
+                    **models_and_configs,
                 )
             )
 
@@ -187,5 +195,46 @@ class UnitYAgentPipeline(UnitYPipelineMixin, AgentPipeline):
         return output_segment
 
     @classmethod
-    def from_args(cls, args: Any) -> UnitYPipelineMixin:
+    def from_args(cls, args: Any) -> UnitYAgentPipeline:
+        return cls(args)
+
+
+class UnitYAgentTreePipeline(UnitYPipelineMixin, TreeAgentPipeline):
+    pipeline = {}
+
+    def __init__(self, args: Namespace):
+        models_and_configs = self.load_model(args)
+
+        assert len(self.pipeline) > 0
+        module_dict = {}
+        for module_class, children in self.pipeline.items():
+            module_dict[module_class.from_args(args, **models_and_configs)] = children
+
+    @classmethod
+    def from_args(cls, args: Any) -> UnitYAgentPipeline:
         return cls(args)
+
+    def pop(
+        self, states: Optional[List[Optional[AgentStates]]] = None
+    ) -> List[Segment]:
+        output_segment = super().pop(states)
+        if states is None:
+            # Not stateless
+            first_states = self.source_module.states
+        else:
+            assert len(states) == len(self.module_dict)
+            first_states = states[self.source_module]
+
+        if not first_states.source_finished and any(
+            segment.finished for segment in output_segment
+        ):
+            # An early stop.
+            # The temporary solution is to start over
+            if states is not None:
+                maybe_reset_states(states.values())
+            else:
+                self.reset()
+            for segment in output_segment:
+                segment.finished = False
+
+        return output_segment
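
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): a minimal, self-contained toy
# that mirrors the refactor above -- heavy resources are loaded once by a
# classmethod that returns a dict, and every pipeline stage receives the same
# resources via **kwargs, whether the pipeline is list- or tree-shaped.  All
# names below (ToyResourceMixin, ToyStage, ToyTreePipeline) are hypothetical
# stand-ins, not seamless_communication or simuleval APIs.
# ---------------------------------------------------------------------------
from __future__ import annotations

from argparse import Namespace
from typing import Any, Dict, List


class ToyResourceMixin:
    # Counterpart of UnitYPipelineMixin.load_model: build the shared models
    # and configs once and hand them to every stage as keyword arguments.
    @classmethod
    def load_model(cls, args: Namespace) -> Dict[str, Any]:
        return {"unity_model": f"unity-on-{args.device}", "vocoder": None}


class ToyStage:
    def __init__(self, **resources: Any) -> None:
        self.resources = resources

    @classmethod
    def from_args(cls, args: Namespace, **resources: Any) -> "ToyStage":
        return cls(**resources)


class ToyTreePipeline(ToyResourceMixin):
    # Counterpart of UnitYAgentTreePipeline.pipeline: {stage class: children}.
    pipeline: Dict[type, List[type]] = {ToyStage: []}

    def __init__(self, args: Namespace) -> None:
        models_and_configs = self.load_model(args)
        # Same pattern as the patch: splat the shared dict into every stage.
        self.module_dict = {
            stage.from_args(args, **models_and_configs): children
            for stage, children in self.pipeline.items()
        }


if __name__ == "__main__":
    pipe = ToyTreePipeline(Namespace(device="cpu"))
    for stage in pipe.module_dict:
        print(type(stage).__name__, stage.resources)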