Rename gcmvn_fbank to prosody_encoder_input (#105)

* rename gcmvn_fbank to prosody_encoder_input

* Isort + Black
commit 1a91d39931
Author: Yilin Yang
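
The central helper this commit pushes behind the generator APIs is fairseq2's `get_seqs_and_padding_mask`, which unpacks a collated `SequenceData` batch into a padded tensor plus an optional padding mask. A minimal standalone sketch, assuming `Collater` returns a dict-shaped `SequenceData` with a "seqs" entry, as the diffs below use it:

```python
import torch
from fairseq2.data import Collater
from fairseq2.nn.padding import get_seqs_and_padding_mask

# Two variable-length fbank-like features (time x mel); values are dummies.
feats = [torch.randn(7, 80), torch.randn(5, 80)]

# Collater pads to a common length and returns a SequenceData batch; the
# same pattern appears below as Collater(pad_value=0) for unit durations.
batch = Collater(pad_value=0)(feats)

# The unpacking step this commit relocates from evaluate.py into the
# generators that actually consume the batch.
seqs, padding_mask = get_seqs_and_padding_mask(batch)
print(seqs.shape)  # torch.Size([2, 7, 80])
```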

+ 4 - 7
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -41,8 +41,8 @@ from seamless_communication.cli.m4t.predict import (
 from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
 from seamless_communication.inference.pretssel_generator import PretsselGenerator
 from seamless_communication.models.unity import (
-    load_unity_text_tokenizer,
     load_gcmvn_stats,
+    load_unity_text_tokenizer,
 )
 
 logging.basicConfig(
@@ -232,9 +232,7 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                gcmvn_fbank, padding_mask = get_seqs_and_padding_mask(
-                    example["audio"]["data"]["gcmvn_fbank"]
-                )
+                prosody_encoder_input = example["audio"]["data"]["gcmvn_fbank"]
                 text_output, unit_output = translator.predict(
                     src,
                     ctx.task,
@@ -244,15 +242,14 @@ def run_eval(
                     unit_generation_opts=ctx.unit_generation_opts,
                     unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
                     duration_factor=ctx.duration_factor,
-                    gcmvn_fbank=gcmvn_fbank,
+                    prosody_encoder_input=prosody_encoder_input,
                 )
 
                 assert unit_output is not None
                 speech_output = pretssel_generator.predict(
                     unit_output.units,
                     tgt_lang=ctx.target_lang,
-                    padding_mask=padding_mask,
-                    gcmvn_fbank=gcmvn_fbank,
+                    prosody_encoder_input=prosody_encoder_input,
                 )
 
             else:
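
For context on the old name: the input is a log-mel filterbank normalized with global CMVN statistics, loaded via the `load_gcmvn_stats` import kept in the hunk above. A hedged sketch of that normalization, assuming the stats come back as per-bin mean/std arrays (the card name is a placeholder):

```python
import torch
from seamless_communication.models.unity import load_gcmvn_stats

# Placeholder card name; evaluate.py loads stats for its own vocoder card.
gcmvn_mean, gcmvn_std = load_gcmvn_stats("vocoder_pretssel")
mean, std = torch.tensor(gcmvn_mean), torch.tensor(gcmvn_std)

def gcmvn_normalize(fbank: torch.Tensor) -> torch.Tensor:
    # Global CMVN: shift/scale each mel bin by corpus-level statistics,
    # which is what the old argument name "gcmvn_fbank" referred to.
    return (fbank - mean) / std
```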

+ 13 - 3
src/seamless_communication/inference/generator.py

@@ -8,6 +8,7 @@ from dataclasses import dataclass
 from typing import List, Optional, Tuple
 
 import torch
+from fairseq2.data import SequenceData
 from fairseq2.data.text import TextTokenizer
 from fairseq2.generation import (
     Seq2SeqGenerator,
@@ -16,7 +17,11 @@ from fairseq2.generation import (
     SequenceToTextGenerator,
     SequenceToTextOutput,
 )
-from fairseq2.nn.padding import PaddingMask, apply_padding_mask
+from fairseq2.nn.padding import (
+    PaddingMask,
+    apply_padding_mask,
+    get_seqs_and_padding_mask,
+)
 from fairseq2.nn.utils.module import infer_device
 from torch import Tensor
 
@@ -154,7 +159,7 @@ class UnitYGenerator:
         output_modality: str = "speech",
         ngram_filtering: bool = False,
         duration_factor: float = 1.0,
-        gcmvn_fbank: Optional[Tensor] = None,
+        prosody_encoder_input: Optional[SequenceData] = None,
     ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
         """
         :param source_seqs:
@@ -219,8 +224,13 @@ class UnitYGenerator:
         unit_gen_output = None
         prosody_encoder_out = None
         if self.model.prosody_encoder_model is not None:
+            assert prosody_encoder_input is not None
+            prosody_input_seqs, prosody_padding_mask = get_seqs_and_padding_mask(
+                prosody_encoder_input
+            )
             prosody_encoder_out = self.model.prosody_encoder_model(
-                gcmvn_fbank, source_padding_mask
+                prosody_input_seqs,
+                prosody_padding_mask,
             ).unsqueeze(1)
 
         if isinstance(self.model.t2u_model, UnitYT2UModel):
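
The prosody branch now owns its own unpacking. Here is the same pattern as a standalone function, with the encoder stubbed as any module taking `(seqs, padding_mask)`, since only the assert-unpack-call shape from the hunk matters:

```python
from typing import Optional

import torch
from fairseq2.data import SequenceData
from fairseq2.nn.padding import get_seqs_and_padding_mask
from torch import Tensor

def encode_prosody(
    encoder: torch.nn.Module,
    prosody_encoder_input: Optional[SequenceData],
) -> Tensor:
    # Mirrors the hunk above: fail fast when the model has a prosody
    # encoder but the caller supplied no batch for it.
    assert prosody_encoder_input is not None
    seqs, padding_mask = get_seqs_and_padding_mask(prosody_encoder_input)
    # unsqueeze(1) as in the diff: one prosody embedding per sequence,
    # inserted as a length-1 axis.
    return encoder(seqs, padding_mask).unsqueeze(1)
```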

+ 6 - 4
src/seamless_communication/inference/pretssel_generator.py

@@ -77,8 +77,7 @@ class PretsselGenerator(nn.Module):
         self,
         units: List[List[int]],
         tgt_lang: str,
-        padding_mask: Optional[PaddingMask],
-        gcmvn_fbank: Tensor,
+        prosody_encoder_input: SequenceData,
         sample_rate: int = 16000,
     ) -> BatchedSpeechOutput:
         list_units, durations = [], []
@@ -106,12 +105,15 @@ class PretsselGenerator(nn.Module):
         durations = self.duration_collate(durations)["seqs"]
 
         units_tensor, unit_padding_mask = get_seqs_and_padding_mask(speech_units)
+        prosody_input_seqs, prosody_padding_mask = get_seqs_and_padding_mask(
+            prosody_encoder_input
+        )
 
         mel_output = self.pretssel_model(
             units_tensor,
             unit_padding_mask,
-            gcmvn_fbank,
-            padding_mask,
+            prosody_input_seqs,
+            prosody_padding_mask,
             tgt_lang=tgt_lang,
             durations=durations,
         )
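
With `padding_mask` dropped from the signature, a caller now passes unit IDs plus the raw batch. A call-shape sketch only (not runnable without a constructed generator; the unit IDs are dummies and `fbank_batch` is a collated `SequenceData` like the one built in the first sketch):

```python
# `pretssel_generator` is assumed to be an already-constructed
# PretsselGenerator; `fbank_batch` is a collated SequenceData.
speech_output = pretssel_generator.predict(
    units=[[12, 57, 57, 103]],          # dummy unit IDs, List[List[int]]
    tgt_lang="eng",
    prosody_encoder_input=fbank_batch,  # unpacked inside predict() now
)
```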

+ 5 - 6
src/seamless_communication/inference/translator.py

@@ -11,7 +11,6 @@ from typing import Callable, List, Optional, Tuple, Union, cast
 
 import torch
 import torch.nn as nn
-
 from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
@@ -20,7 +19,7 @@ from fairseq2.data.text import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextOutput
 from fairseq2.memory import MemoryBlock
-from fairseq2.nn.padding import get_seqs_and_padding_mask, PaddingMask
+from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 
@@ -169,7 +168,7 @@ class Translator(nn.Module):
         unit_generation_opts: Optional[SequenceGeneratorOptions],
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
-        gcmvn_fbank: Optional[Tensor] = None,
+        prosody_encoder_input: Optional[SequenceData] = None,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
         # We disregard unit generations opts for the NAR T2U decoder.
         if output_modality != Modality.SPEECH or isinstance(
@@ -193,7 +192,7 @@ class Translator(nn.Module):
             output_modality.value,
             ngram_filtering=unit_generation_ngram_filtering,
             duration_factor=duration_factor,
-            gcmvn_fbank=gcmvn_fbank,
+            prosody_encoder_input=prosody_encoder_input,
         )
 
     @staticmethod
@@ -230,7 +229,7 @@ class Translator(nn.Module):
         sample_rate: int = 16000,
         unit_generation_ngram_filtering: bool = False,
         duration_factor: float = 1.0,
-        gcmvn_fbank: Optional[Tensor] = None,
+        prosody_encoder_input: Optional[SequenceData] = None,
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -315,7 +314,7 @@ class Translator(nn.Module):
             unit_generation_opts,
             unit_generation_ngram_filtering=unit_generation_ngram_filtering,
             duration_factor=duration_factor,
-            gcmvn_fbank=gcmvn_fbank,
+            prosody_encoder_input=prosody_encoder_input,
         )
 
         if output_modality == Modality.TEXT:
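
End to end, only the keyword changes for users of `Translator.predict`. A hedged sketch: the constructor arguments, card names, and the `input`/`task_str` parameter names are assumptions rather than values verified against this repo, while the two-tuple return matches the signature above:

```python
import torch
from seamless_communication.inference import Translator

# Assumed card names and constructor arguments; adjust to real assets.
translator = Translator(
    "seamless_expressivity",
    "vocoder_pretssel",
    device=torch.device("cpu"),
)

text_out, speech_out = translator.predict(
    input="input.wav",                  # assumed: path-like audio input
    task_str="S2ST",
    tgt_lang="eng",
    prosody_encoder_input=fbank_batch,  # was gcmvn_fbank before this commit
)
```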