File overview

small afterwards fix due to conda env break yesterday (#142)

Yilin Yang 1 year ago
parent
commit
8d49dd0450
1 changed file with 11 additions and 8 deletions
  1. src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py  +11 −8

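The substantive fix in this commit replaces a reference to the removed vocoder_name_or_card variable with the constructor's pretssel_name_or_card argument (now typed as a plain str); the remaining hunks only reorder imports and reflow formatting. Below is a minimal sketch of the corrected asset-card lookup, using the same fairseq2 calls that appear in the diff; the card name "vocoder_pretssel" is an assumption used only for illustration.

from fairseq2.assets import asset_store

# Minimal sketch of the corrected lookup: the card is resolved from the
# PRETSSEL model name rather than the removed vocoder_name_or_card variable.
# "vocoder_pretssel" is an assumed card name, not taken from this diff.
pretssel_card = asset_store.retrieve_card("vocoder_pretssel")
output_sample_rate = pretssel_card.field("sample_rate").as_(int)

Narrowing pretssel_name_or_card to str is consistent with this lookup, since asset_store.retrieve_card expects a card name string rather than an AssetCard instance.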
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py  +11 −8

@@ -13,8 +13,8 @@ from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from torch.nn import Module
 import torchaudio
+from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, DataPipeline, FileMapper, SequenceData
 from fairseq2.data.audio import (
@@ -24,13 +24,13 @@ from fairseq2.data.audio import (
 )
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
 from fairseq2.generation import SequenceGeneratorOptions
-from fairseq2.typing import DataType, Device
 from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
+from torch.nn import Module
 from tqdm import tqdm
 
-from seamless_communication.models.unity import UnitTokenizer
 from seamless_communication.cli.m4t.evaluate.evaluate import (
     adjust_output_for_corrupted_inputs,
     count_lines,
@@ -40,12 +40,13 @@ from seamless_communication.cli.m4t.predict import (
     set_generation_opts,
 )
 from seamless_communication.inference import BatchedSpeechOutput, Translator
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 from seamless_communication.models.unity import (
+    UnitTokenizer,
     load_gcmvn_stats,
     load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
-from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 
 logging.basicConfig(
     level=logging.INFO,
@@ -58,7 +59,7 @@ logger = logging.getLogger(__name__)
 class PretsselGenerator(Module):
     def __init__(
         self,
-        pretssel_name_or_card: Union[str, AssetCard],
+        pretssel_name_or_card: str,
         unit_tokenizer: UnitTokenizer,
         device: Device,
         dtype: DataType = torch.float16,
@@ -78,7 +79,7 @@ class PretsselGenerator(Module):
         )
         self.pretssel_model.eval()
 
-        vocoder_model_card = asset_store.retrieve_card(vocoder_name_or_card)
+        vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
         self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
 
         self.unit_tokenizer = unit_tokenizer
@@ -115,7 +116,7 @@ class PretsselGenerator(Module):
 
             duration *= 2
 
-            prosody_input_seq = prosody_input_seqs[i][:prosody_input_lens[i]]
+            prosody_input_seq = prosody_input_seqs[i][: prosody_input_lens[i]]
 
             audio_wav = self.pretssel_model(
                 unit,
@@ -193,7 +194,9 @@ def build_data_pipeline(
 
 def main() -> None:
     parser = argparse.ArgumentParser(description="Running PretsselModel inference")
-    parser.add_argument("data_file", type=Path, help="Data file (.tsv) to be evaluated.")
+    parser.add_argument(
+        "data_file", type=Path, help="Data file (.tsv) to be evaluated."
+    )
 
     parser = add_inference_arguments(parser)
     parser.add_argument(