|
@@ -11,7 +11,6 @@ from typing import Callable, List, Optional, Tuple, Union, cast
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
-
|
|
|
from fairseq2.assets import asset_store
|
|
|
from fairseq2.assets.card import AssetCard
|
|
|
from fairseq2.data import Collater, SequenceData
|
|
@@ -20,7 +19,7 @@ from fairseq2.data.text import TextTokenizer
|
|
|
from fairseq2.data.typing import StringLike
|
|
|
from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextOutput
|
|
|
from fairseq2.memory import MemoryBlock
|
|
|
-from fairseq2.nn.padding import get_seqs_and_padding_mask, PaddingMask
|
|
|
+from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
|
|
|
from fairseq2.typing import DataType, Device
|
|
|
from torch import Tensor
|
|
|
|
|
@@ -169,7 +168,7 @@ class Translator(nn.Module):
|
|
|
unit_generation_opts: Optional[SequenceGeneratorOptions],
|
|
|
unit_generation_ngram_filtering: bool = False,
|
|
|
duration_factor: float = 1.0,
|
|
|
- gcmvn_fbank: Optional[Tensor] = None,
|
|
|
+ prosody_encoder_input: Optional[SequenceData] = None,
|
|
|
) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
|
|
|
# We disregard unit generations opts for the NAR T2U decoder.
|
|
|
if output_modality != Modality.SPEECH or isinstance(
|
|
@@ -193,7 +192,7 @@ class Translator(nn.Module):
|
|
|
output_modality.value,
|
|
|
ngram_filtering=unit_generation_ngram_filtering,
|
|
|
duration_factor=duration_factor,
|
|
|
- gcmvn_fbank=gcmvn_fbank,
|
|
|
+ prosody_encoder_input=prosody_encoder_input,
|
|
|
)
|
|
|
|
|
|
@staticmethod
|
|
@@ -230,7 +229,7 @@ class Translator(nn.Module):
|
|
|
sample_rate: int = 16000,
|
|
|
unit_generation_ngram_filtering: bool = False,
|
|
|
duration_factor: float = 1.0,
|
|
|
- gcmvn_fbank: Optional[Tensor] = None,
|
|
|
+ prosody_encoder_input: Optional[SequenceData] = None,
|
|
|
) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
|
|
|
"""
|
|
|
The main method used to perform inference on all tasks.
|
|
@@ -315,7 +314,7 @@ class Translator(nn.Module):
|
|
|
unit_generation_opts,
|
|
|
unit_generation_ngram_filtering=unit_generation_ngram_filtering,
|
|
|
duration_factor=duration_factor,
|
|
|
- gcmvn_fbank=gcmvn_fbank,
|
|
|
+ prosody_encoder_input=prosody_encoder_input,
|
|
|
)
|
|
|
|
|
|
if output_modality == Modality.TEXT:
|