|
@@ -6,20 +6,20 @@
|
|
from __future__ import annotations
|
|
from __future__ import annotations
|
|
|
|
|
|
from argparse import ArgumentParser, Namespace
|
|
from argparse import ArgumentParser, Namespace
|
|
-import torch
|
|
|
|
-from typing import Any, Dict
|
|
|
|
|
|
+from typing import Any, Dict, List
|
|
|
|
|
|
-from fairseq2.data.audio import WaveformToFbankConverter
|
|
|
|
|
|
+import torch
|
|
|
|
+from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
|
|
|
|
+from seamless_communication.models.generator.vocoder import PretsselVocoder
|
|
from seamless_communication.models.unity import load_gcmvn_stats
|
|
from seamless_communication.models.unity import load_gcmvn_stats
|
|
from seamless_communication.models.vocoder.vocoder import Vocoder
|
|
from seamless_communication.models.vocoder.vocoder import Vocoder
|
|
-from seamless_communication.models.generator.vocoder import PretsselVocoder
|
|
|
|
from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
|
|
from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
|
|
from simuleval.agents import AgentStates, TextToSpeechAgent
|
|
from simuleval.agents import AgentStates, TextToSpeechAgent
|
|
from simuleval.agents.actions import ReadAction, WriteAction
|
|
from simuleval.agents.actions import ReadAction, WriteAction
|
|
from simuleval.data.segments import SpeechSegment
|
|
from simuleval.data.segments import SpeechSegment
|
|
|
|
|
|
|
|
|
|
-class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
|
|
|
|
+class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent): # type: ignore
|
|
def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
|
|
def __init__(self, vocoder: Vocoder, args: Namespace) -> None:
|
|
super().__init__(args)
|
|
super().__init__(args)
|
|
self.vocoder = vocoder
|
|
self.vocoder = vocoder
|
|
@@ -36,13 +36,15 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
dtype=args.dtype,
|
|
dtype=args.dtype,
|
|
)
|
|
)
|
|
|
|
|
|
-
|
|
|
|
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
|
|
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
|
|
- self.gcmvn_mean = torch.tensor(_gcmvn_mean, device=args.device, dtype=args.dtype)
|
|
|
|
|
|
+ self.gcmvn_mean = torch.tensor(
|
|
|
|
+ _gcmvn_mean, device=args.device, dtype=args.dtype
|
|
|
|
+ )
|
|
self.gcmvn_std = torch.tensor(_gcmvn_std, device=args.device, dtype=args.dtype)
|
|
self.gcmvn_std = torch.tensor(_gcmvn_std, device=args.device, dtype=args.dtype)
|
|
|
|
|
|
def gcmvn_normalize(self, seqs: torch.Tensor) -> torch.Tensor:
|
|
def gcmvn_normalize(self, seqs: torch.Tensor) -> torch.Tensor:
|
|
- return seqs.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
|
|
|
|
|
|
+ result: torch.Tensor = seqs.subtract(self.gcmvn_mean).divide(self.gcmvn_std)
|
|
|
|
+ return result
|
|
|
|
|
|
@torch.inference_mode()
|
|
@torch.inference_mode()
|
|
def policy(self, states: AgentStates) -> WriteAction:
|
|
def policy(self, states: AgentStates) -> WriteAction:
|
|
@@ -66,15 +68,18 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
|
|
|
|
duration *= 2
|
|
duration *= 2
|
|
|
|
|
|
- if type(states.upstream_states[self.upstream_idx].source) == list:
|
|
|
|
- source = sum(states.upstream_states[self.upstream_idx].source, [])
|
|
|
|
|
|
+ if isinstance(states.upstream_states[self.upstream_idx].source, list):
|
|
|
|
+ source: List[float] = sum(
|
|
|
|
+ states.upstream_states[self.upstream_idx].source, []
|
|
|
|
+ )
|
|
else:
|
|
else:
|
|
source = states.upstream_states[self.upstream_idx].source
|
|
source = states.upstream_states[self.upstream_idx].source
|
|
|
|
|
|
- audio_dict = {
|
|
|
|
- "waveform": torch.tensor(source, dtype=torch.float32, device=self.device).unsqueeze(1),
|
|
|
|
|
|
+ audio_dict: WaveformToFbankInput = {
|
|
|
|
+ "waveform": torch.tensor(
|
|
|
|
+ source, dtype=torch.float32, device=self.device
|
|
|
|
+ ).unsqueeze(1),
|
|
"sample_rate": self.sample_rate,
|
|
"sample_rate": self.sample_rate,
|
|
- "format": -1,
|
|
|
|
}
|
|
}
|
|
|
|
|
|
feats = self.convert_to_fbank(audio_dict)["fbank"]
|
|
feats = self.convert_to_fbank(audio_dict)["fbank"]
|
|
@@ -115,11 +120,13 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent):
|
|
"--vocoder-sample-rate",
|
|
"--vocoder-sample-rate",
|
|
type=int,
|
|
type=int,
|
|
default=16000,
|
|
default=16000,
|
|
- help="sample rate out of the vocoder"
|
|
|
|
|
|
+ help="sample rate out of the vocoder",
|
|
)
|
|
)
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
- def from_args(cls, args: Namespace, **kwargs: Dict[str, Any]) -> PretsselVocoderAgent:
|
|
|
|
|
|
+ def from_args(
|
|
|
|
+ cls, args: Namespace, **kwargs: Dict[str, Any]
|
|
|
|
+ ) -> PretsselVocoderAgent:
|
|
vocoder = kwargs.get("vocoder", None)
|
|
vocoder = kwargs.get("vocoder", None)
|
|
assert isinstance(vocoder, PretsselVocoder)
|
|
assert isinstance(vocoder, PretsselVocoder)
|
|
return cls(vocoder, args)
|
|
return cls(vocoder, args)
|