|
@@ -7,6 +7,7 @@ from __future__ import annotations
|
|
|
|
|
|
import logging
|
|
import logging
|
|
from argparse import ArgumentParser, Namespace
|
|
from argparse import ArgumentParser, Namespace
|
|
|
|
+from pathlib import Path
|
|
from typing import Any, Dict, List
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
import torch
|
|
import torch
|
|
@@ -14,6 +15,7 @@ from fairseq2.assets import asset_store
|
|
from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
|
|
from fairseq2.data.audio import WaveformToFbankConverter, WaveformToFbankInput
|
|
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
|
|
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
|
|
from seamless_communication.models.unity import load_gcmvn_stats
|
|
from seamless_communication.models.unity import load_gcmvn_stats
|
|
|
|
+from seamless_communication.store import add_gated_assets
|
|
from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
|
|
from seamless_communication.streaming.agents.common import NoUpdateTargetMixin
|
|
from simuleval.agents import AgentStates, TextToSpeechAgent
|
|
from simuleval.agents import AgentStates, TextToSpeechAgent
|
|
from simuleval.agents.actions import ReadAction, WriteAction
|
|
from simuleval.agents.actions import ReadAction, WriteAction
|
|
@@ -31,6 +33,9 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent): # type: ign
|
|
def __init__(self, args: Namespace) -> None:
|
|
def __init__(self, args: Namespace) -> None:
|
|
super().__init__(args)
|
|
super().__init__(args)
|
|
|
|
|
|
|
|
+ if args.gated_model_dir:
|
|
|
|
+ add_gated_assets(args.gated_model_dir)
|
|
|
|
+
|
|
logger.info(
|
|
logger.info(
|
|
f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
|
|
f"Loading the Vocoder model: {args.vocoder_name} on device={args.device}, dtype={args.dtype}"
|
|
)
|
|
)
|
|
@@ -129,6 +134,12 @@ class PretsselVocoderAgent(NoUpdateTargetMixin, TextToSpeechAgent): # type: ign
|
|
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def add_args(cls, parser: ArgumentParser) -> None:
|
|
def add_args(cls, parser: ArgumentParser) -> None:
|
|
|
|
+ param = parser.add_argument(
|
|
|
|
+ "--gated-model-dir",
|
|
|
|
+ type=Path,
|
|
|
|
+ required=False,
|
|
|
|
+ help="SeamlessExpressive model directory.",
|
|
|
|
+ )
|
|
parser.add_argument(
|
|
parser.add_argument(
|
|
"--vocoder-name",
|
|
"--vocoder-name",
|
|
type=str,
|
|
type=str,
|