@@ -7,6 +7,7 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
 from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
@@ -53,6 +54,9 @@ class PretsselGenerator(nn.Module):
         )
         self.pretssel_model.eval()
 
+        vocoder_model_card = asset_store.retrieve_card(vocoder_name_or_card)
+        self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+
         self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
         self.unit_collate = Collater(pad_value=self.unit_tokenizer.vocab_info.pad_idx)
         self.duration_collate = Collater(pad_value=0)
@@ -78,7 +82,6 @@ class PretsselGenerator(nn.Module):
         units: List[List[int]],
         tgt_lang: str,
         prosody_encoder_input: SequenceData,
-        sample_rate: int = 16000,
     ) -> BatchedSpeechOutput:
         list_units, durations = [], []
         unit_eos_token = torch.tensor(
@@ -130,5 +133,5 @@ class PretsselGenerator(nn.Module):
         return BatchedSpeechOutput(
             units=units,
             audio_wavs=audio_wavs,
-            sample_rate=sample_rate,
+            sample_rate=self.output_sample_rate,
         )
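With this change, the output sample rate is read from the vocoder's asset card instead of being passed (and defaulted to 16000) by callers of `predict()`. A minimal sketch of the lookup this diff relies on, assuming a hypothetical vocoder card name `"vocoder_pretssel"` that defines a `sample_rate` field:

```python
from fairseq2.assets import asset_store

# Hypothetical card name for illustration; any vocoder asset card that
# declares a "sample_rate" field would resolve the same way.
card = asset_store.retrieve_card("vocoder_pretssel")

# Same pattern as in the diff: read the field and coerce it to int.
sample_rate = card.field("sample_rate").as_(int)
print(sample_rate)  # whatever rate the card declares for that vocoder
```

Callers no longer pass `sample_rate` to `PretsselGenerator.predict()`; `BatchedSpeechOutput.sample_rate` now always reflects the value declared on the vocoder's asset card.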