|
@@ -13,8 +13,8 @@ from pathlib import Path
|
|
|
from typing import Callable, Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
import torch
|
|
|
-from torch.nn import Module
|
|
|
import torchaudio
|
|
|
+from fairseq2.assets import asset_store
|
|
|
from fairseq2.assets.card import AssetCard
|
|
|
from fairseq2.data import Collater, DataPipeline, FileMapper, SequenceData
|
|
|
from fairseq2.data.audio import (
|
|
@@ -24,13 +24,13 @@ from fairseq2.data.audio import (
|
|
|
)
|
|
|
from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
|
|
|
from fairseq2.generation import SequenceGeneratorOptions
|
|
|
-from fairseq2.typing import DataType, Device
|
|
|
from fairseq2.nn.padding import get_seqs_and_padding_mask
|
|
|
+from fairseq2.typing import DataType, Device
|
|
|
from sacrebleu.metrics import BLEU # type: ignore[attr-defined]
|
|
|
from torch import Tensor
|
|
|
+from torch.nn import Module
|
|
|
from tqdm import tqdm
|
|
|
|
|
|
-from seamless_communication.models.unity import UnitTokenizer
|
|
|
from seamless_communication.cli.m4t.evaluate.evaluate import (
|
|
|
adjust_output_for_corrupted_inputs,
|
|
|
count_lines,
|
|
@@ -40,12 +40,13 @@ from seamless_communication.cli.m4t.predict import (
|
|
|
set_generation_opts,
|
|
|
)
|
|
|
from seamless_communication.inference import BatchedSpeechOutput, Translator
|
|
|
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
|
|
|
from seamless_communication.models.unity import (
|
|
|
+ UnitTokenizer,
|
|
|
load_gcmvn_stats,
|
|
|
load_unity_text_tokenizer,
|
|
|
load_unity_unit_tokenizer,
|
|
|
)
|
|
|
-from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
|
|
|
|
|
|
logging.basicConfig(
|
|
|
level=logging.INFO,
|
|
@@ -58,7 +59,7 @@ logger = logging.getLogger(__name__)
|
|
|
class PretsselGenerator(Module):
|
|
|
def __init__(
|
|
|
self,
|
|
|
- pretssel_name_or_card: Union[str, AssetCard],
|
|
|
+ pretssel_name_or_card: str,
|
|
|
unit_tokenizer: UnitTokenizer,
|
|
|
device: Device,
|
|
|
dtype: DataType = torch.float16,
|
|
@@ -78,7 +79,7 @@ class PretsselGenerator(Module):
|
|
|
)
|
|
|
self.pretssel_model.eval()
|
|
|
|
|
|
- vocoder_model_card = asset_store.retrieve_card(vocoder_name_or_card)
|
|
|
+ vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
|
|
|
self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
|
|
|
|
|
|
self.unit_tokenizer = unit_tokenizer
|
|
@@ -115,7 +116,7 @@ class PretsselGenerator(Module):
|
|
|
|
|
|
duration *= 2
|
|
|
|
|
|
- prosody_input_seq = prosody_input_seqs[i][:prosody_input_lens[i]]
|
|
|
+ prosody_input_seq = prosody_input_seqs[i][: prosody_input_lens[i]]
|
|
|
|
|
|
audio_wav = self.pretssel_model(
|
|
|
unit,
|
|
@@ -193,7 +194,9 @@ def build_data_pipeline(
|
|
|
|
|
|
def main() -> None:
|
|
|
parser = argparse.ArgumentParser(description="Running PretsselModel inference")
|
|
|
- parser.add_argument("data_file", type=Path, help="Data file (.tsv) to be evaluated.")
|
|
|
+ parser.add_argument(
|
|
|
+ "data_file", type=Path, help="Data file (.tsv) to be evaluated."
|
|
|
+ )
|
|
|
|
|
|
parser = add_inference_arguments(parser)
|
|
|
parser.add_argument(
|