
Implement PretsselModel & its inference (#89)

* Implement PretsselModel & its inference

* add mel_vocoder

* refactor inference

* minor fix

* minor fix

* minor fix

* mypy pytest isort black formatting

* change padding to 'same'

* minor renaming

* address PR comments

* slight refactor

* minor fix: address PR comment

* make padding_mask optional

---------

Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Yilin Yang committed 1 year ago
commit ed18e69190
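A minimal sketch of the inference flow this commit introduces, condensed from the evaluate.py changes below. The expressive UnitY card name, target language, task string, and the `src` / `gcmvn_src` batches are placeholders; only the `pretssel_v1` and `vocoder_mel` cards come from this commit.

    import torch
    from fairseq2.nn.padding import get_seqs_and_padding_mask

    from seamless_communication.inference import Translator
    from seamless_communication.inference.pretssel_generator import PretsselGenerator
    from seamless_communication.models.unity import load_gcmvn_stats

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    # gcmvn statistics now come from the pretssel_v1 model card instead of an .npz file.
    gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
    gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)
    gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)

    model_name = "seamless_expressivity"  # placeholder: the expressive UnitY model card

    # The Translator is built without a vocoder; mel and waveform synthesis move
    # into PretsselGenerator.
    translator = Translator(model_name, vocoder_name_or_card=None,
                            device=device, dtype=dtype)
    generator = PretsselGenerator(model_name, "vocoder_mel", "pretssel_v1",
                                  device, gcmvn_mean, gcmvn_std, dtype)

    # src is a collated fbank batch and gcmvn_src its globally normalized copy,
    # both produced by the evaluation data pipeline (placeholders here).
    gcmvn_fbank, padding_mask = get_seqs_and_padding_mask(gcmvn_src)
    text_output, unit_output = translator.predict(
        src, "s2st", tgt_lang="spa",
        duration_factor=1.1,
        gcmvn_fbank=gcmvn_fbank,
    )
    speech_output = generator.predict(
        unit_output.units, tgt_lang="spa",
        padding_mask=padding_mask, gcmvn_fbank=gcmvn_fbank,
    )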

+ 181 - 0
src/seamless_communication/cards/pretssel_v1.yaml

@@ -0,0 +1,181 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: pretssel_v1
+model_type: pretssel
+model_arch: base
+checkpoint: "file://checkpoint/mjhwang/experiments/230930-noiseaug_p2v-mls_multilingual_6lang/231005-noiseaug_p2v-mls_multilingual_6lang-alignfix.config_v2.langemb1.vuv_logit1.denoise.ngpu16/checkpoint_best.pt"
+num_units: 10000
+languages:
+  - cmn
+  - deu
+  - eng
+  - fra
+  - ita
+  - spa
+gcmvn_stats:
+  mean:
+    - 9.023406257490224
+    - 9.406622923058864
+    - 10.554165334059368
+    - 11.475190058682356
+    - 12.179117104099705
+    - 12.603782921407062
+    - 12.769632747861747
+    - 12.714276772934083
+    - 12.747612172560233
+    - 12.750373688097946
+    - 12.948050207790237
+    - 13.121829398704277
+    - 13.40130828476734
+    - 13.58028050886195
+    - 13.601835409305883
+    - 13.608734047373218
+    - 13.538274892335826
+    - 13.391518457210937
+    - 13.382843811359622
+    - 13.0524299456858
+    - 12.785193828396269
+    - 12.876608812372632
+    - 12.59571918874957
+    - 12.674484745567813
+    - 12.57325195345546
+    - 12.651938120109422
+    - 12.556821722150424
+    - 12.639338348530158
+    - 12.610449431411217
+    - 12.639992872912376
+    - 12.697503827987052
+    - 12.754788270377214
+    - 12.837605043617405
+    - 12.964379088501497
+    - 13.11997048142582
+    - 13.267395589173432
+    - 13.384668687260483
+    - 13.495000208959356
+    - 13.606835320307384
+    - 13.578073476073252
+    - 13.689796531497368
+    - 13.643079802391588
+    - 13.7340755472615
+    - 13.735199777666043
+    - 13.79347692248429
+    - 13.875183654243305
+    - 13.967272256671393
+    - 14.058507936754117
+    - 14.114704594203507
+    - 14.156211337193277
+    - 14.14747081594401
+    - 14.173917097974343
+    - 14.22330474758318
+    - 14.251272943225572
+    - 14.230904505178053
+    - 14.226937644205396
+    - 14.222223350670225
+    - 14.211638354996317
+    - 14.208930098405544
+    - 14.19476983404041
+    - 14.2195925729048
+    - 14.16490878238837
+    - 14.115436751205117
+    - 14.039442767347872
+    - 13.976934063901625
+    - 13.917068116556464
+    - 13.856293662219073
+    - 13.773769842100085
+    - 13.706245521082796
+    - 13.685052933361192
+    - 13.68570131643094
+    - 13.714811890011152
+    - 13.751451253935347
+    - 13.772212258132148
+    - 13.76013448427468
+    - 13.702368406557508
+    - 13.600406368803617
+    - 13.369574889658164
+    - 12.998399608309988
+    - 12.443732902848723
+  std:
+    - 3.729248515707457
+    - 4.001623098079929
+    - 4.570009061358065
+    - 4.811572361201577
+    - 5.010239923828185
+    - 5.152145212706857
+    - 5.223885876119451
+    - 5.224443623432338
+    - 5.161790275239061
+    - 5.098988232815804
+    - 5.090890035509122
+    - 5.130345212529546
+    - 5.165849688173366
+    - 5.164761699263693
+    - 5.131177988219367
+    - 5.085522051815558
+    - 5.035829108165894
+    - 4.987478975310455
+    - 4.932652442855969
+    - 4.8650037198748075
+    - 4.799238163232527
+    - 4.727086345775988
+    - 4.646858066575789
+    - 4.5733249959652715
+    - 4.51685060334288
+    - 4.467449073425149
+    - 4.4296881304192075
+    - 4.4028775449713775
+    - 4.397905653025904
+    - 4.3862594566308015
+    - 4.366485847923521
+    - 4.344483498393771
+    - 4.324692736391383
+    - 4.310481738978154
+    - 4.3053492473916
+    - 4.3035205126659655
+    - 4.2987898577000605
+    - 4.287403454800855
+    - 4.27087296372773
+    - 4.25387490294079
+    - 4.233513102251301
+    - 4.212047255068752
+    - 4.1810370158214445
+    - 4.186014591107853
+    - 4.194806047136222
+    - 4.2183377208747075
+    - 4.249293562464735
+    - 4.268847210561774
+    - 4.270455756367186
+    - 4.25811368227528
+    - 4.245975115347766
+    - 4.23058010369271
+    - 4.203075111087773
+    - 4.20123812057283
+    - 4.187143614375688
+    - 4.172633823274146
+    - 4.162541203161947
+    - 4.156022884601996
+    - 4.1618428838805706
+    - 4.157259439238067
+    - 4.139859013016601
+    - 4.150685014911159
+    - 4.152025499126372
+    - 4.165010788120131
+    - 4.15179422331336
+    - 4.137041631098819
+    - 4.10861757770052
+    - 4.119916019361405
+    - 4.131749366642117
+    - 4.119438578634397
+    - 4.100095269698108
+    - 4.073900009963118
+    - 4.0580796715728855
+    - 4.050916705279105
+    - 4.037976834115189
+    - 4.023757063156459
+    - 3.9987849927993353
+    - 3.989251079820668
+    - 3.9464430977885256
+    - 3.8673932921278995
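The mean/std vectors above are global mean-variance statistics over the 80 fbank bins. A small sketch of how they are consumed, mirroring the normalization in evaluate.py and the denormalization in PretsselGenerator; the random fbank tensor is only for illustration:

    import torch
    from seamless_communication.models.unity import load_gcmvn_stats

    gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
    mean, std = torch.tensor(gcmvn_mean), torch.tensor(gcmvn_std)

    fbank = torch.randn(100, 80)                    # dummy (T, 80) fbank features
    gcmvn_fbank = fbank.subtract(mean).divide(std)  # fed to the prosody encoder
    restored = gcmvn_fbank * std + mean             # PretsselGenerator.gcmvn_denormalize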

+ 1 - 1
src/seamless_communication/cards/vocoder_pretssel.yaml → src/seamless_communication/cards/vocoder_mel.yaml

@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-name: vocoder_pretssel
+name: vocoder_mel
 model_type: vocoder_mel_hifigan
 model_arch: base_mel
 checkpoint: "file://large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"

+ 66 - 81
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -12,9 +12,9 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-import numpy as np
 import torch
 import torchaudio
+from fairseq2.assets import asset_store
 from fairseq2.data import Collater, CString, DataPipeline, FileMapper
 from fairseq2.data.audio import (
     AudioDecoder,
@@ -24,17 +24,26 @@ from fairseq2.data.audio import (
 from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
 from fairseq2.data.typing import PathLike, StringLike
 from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
 from tqdm import tqdm
 
+from seamless_communication.cli.m4t.evaluate.evaluate import (
+    adjust_output_for_corrupted_inputs,
+    count_lines,
+)
 from seamless_communication.cli.m4t.predict import (
     add_inference_arguments,
     set_generation_opts,
 )
 from seamless_communication.inference import BatchedSpeechOutput, Modality, Translator
-from seamless_communication.models.unity import load_unity_text_tokenizer
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.models.unity import (
+    load_unity_text_tokenizer,
+    load_gcmvn_stats,
+)
 
 logging.basicConfig(
     level=logging.INFO,
@@ -97,13 +106,20 @@ class EvalContext:
     """If True, removes consecutive repeating ngrams
     from the decoded unit output."""
 
-    gcmvn_stats: Optional[PathLike] = None
-    """the stats for gcmvn, used by Prosody Encoder"""
+    pretssel_model: str
+    """The name of the PretsselModel"""
+
+    vocoder_name: str
+    """The name of the Vocoder Model"""
+
+    gcmvn_mean: Optional[Tensor]
+    """The mean stats for global-normalized fbank"""
 
+    gcmvn_std: Optional[Tensor]
+    """The std stats for global-normalized fbank"""
 
-def count_lines(filename: Path) -> int:
-    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
-    return int(result.stdout.decode().split()[0])
+    duration_factor: float = 1.1
+    """The duration factor for NAR T2U model. The Expressivity model uses 1.1"""
 
 
 def build_data_pipeline(
@@ -118,15 +134,6 @@ def build_data_pipeline(
 
     split_tsv = StrSplitter(names=header)
 
-    if ctx.gcmvn_stats is not None:
-        if isinstance(ctx.gcmvn_stats, CString):
-            ctx.gcmvn_stats = str(ctx.gcmvn_stats)
-        gcmvn_stats: Dict[str, np.ndarray] = np.load(ctx.gcmvn_stats)  # type: ignore[type-arg]
-        gcmvn_mean = torch.tensor(
-            gcmvn_stats["mean"], device=ctx.device, dtype=ctx.dtype
-        )
-        gcmvn_std = torch.tensor(gcmvn_stats["std"], device=ctx.device, dtype=ctx.dtype)
-
     pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
 
     assert ctx.audio_root_dir is not None
@@ -150,8 +157,8 @@ def build_data_pipeline(
         fbank = data["fbank"]
         std, mean = torch.std_mean(fbank, dim=0)
         data["fbank"] = fbank.subtract(mean).divide(std)
-        if ctx.gcmvn_stats is not None:
-            data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+        if ctx.gcmvn_mean is not None and ctx.gcmvn_std is not None:
+            data["gcmvn_fbank"] = fbank.subtract(ctx.gcmvn_mean).divide(ctx.gcmvn_std)
         return data
 
     pipeline_builder.map(
@@ -171,52 +178,19 @@ def build_data_pipeline(
     return pipeline_builder.and_return()
 
 
-def adjust_output_for_corrupted_inputs(
-    valid_sequences: Tensor,
-    text_output: List[StringLike],
-    speech_output: Optional[BatchedSpeechOutput],
-) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
-    adjusted_text_output: List[StringLike] = []
-    adjusted_speech_output: Optional[BatchedSpeechOutput] = None
-
-    if speech_output is not None:
-        assert (
-            len(text_output)
-            == len(speech_output.units)
-            == len(speech_output.audio_wavs)
-        )
-        adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
-
-    batch_counter = 0
-    for is_valid in valid_sequences:
-        if is_valid:
-            adjusted_text_output.append(text_output[batch_counter])
-            if speech_output is not None:
-                assert adjusted_speech_output is not None
-                adjusted_speech_output.units.append(speech_output.units[batch_counter])
-                adjusted_speech_output.audio_wavs.append(
-                    speech_output.audio_wavs[batch_counter]
-                )
-            batch_counter += 1
-        else:
-            # For the corrupted inputs, we save the following dummy outputs:
-            # empty string for text, empty list for units, 1 second of silence for audio.
-            adjusted_text_output.append("")
-            if adjusted_speech_output is not None:
-                sample_rate = adjusted_speech_output.sample_rate
-                adjusted_speech_output.units.append([])
-                adjusted_speech_output.audio_wavs.append(
-                    torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
-                )
-    return (
-        adjusted_text_output,
-        adjusted_speech_output,
-    )
-
-
 def run_eval(
     translator: Translator, text_tokenizer: TextTokenizer, ctx: EvalContext
 ) -> None:
+    pretssel_generator = PretsselGenerator(
+        ctx.model_name,
+        ctx.vocoder_name,
+        ctx.pretssel_model,
+        ctx.device,
+        ctx.gcmvn_mean,
+        ctx.gcmvn_std,
+        ctx.dtype,
+    )
+
     pipeline = build_data_pipeline(ctx, text_tokenizer)
 
     total_steps = count_lines(ctx.data_file) - 1
@@ -226,7 +200,7 @@ def run_eval(
     output_path.mkdir(parents=True, exist_ok=True)
 
     if ctx.output_modality == Modality.SPEECH:
-        waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
+        waveforms_dir = output_path / f"waveform"
         waveforms_dir.mkdir(parents=True, exist_ok=True)
 
     hyps = []
@@ -258,10 +232,10 @@ def run_eval(
 
             # Skip performing inference when the input is entirely corrupted.
             if src["seqs"].numel() > 0:
-                (
-                    text_output,
-                    speech_output,
-                ) = translator.predict(
+                gcmvn_fbank, padding_mask = get_seqs_and_padding_mask(
+                    example["audio"]["data"]["gcmvn_fbank"]
+                )
+                text_output, unit_output = translator.predict(
                     src,
                     ctx.task,
                     ctx.target_lang,
@@ -269,8 +243,18 @@ def run_eval(
                     text_generation_opts=ctx.text_generation_opts,
                     unit_generation_opts=ctx.unit_generation_opts,
                     unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
-                    gcmvn_fbank=example["audio"]["data"].get("gcmvn_fbank", None),
+                    duration_factor=ctx.duration_factor,
+                    gcmvn_fbank=gcmvn_fbank,
+                )
+
+                assert unit_output is not None
+                speech_output = pretssel_generator.predict(
+                    unit_output.units,
+                    tgt_lang=ctx.target_lang,
+                    padding_mask=padding_mask,
+                    gcmvn_fbank=gcmvn_fbank,
                 )
+
             else:
                 text_output = []
                 if ctx.output_modality == Modality.SPEECH:
@@ -279,10 +263,7 @@ def run_eval(
                     speech_output = None
 
             if valid_sequences is not None and not valid_sequences.all():
-                (
-                    text_output,
-                    speech_output,
-                ) = adjust_output_for_corrupted_inputs(
+                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
                     valid_sequences,
                     text_output,
                     speech_output,
@@ -293,6 +274,7 @@ def run_eval(
 
             for i in range(len(text_output)):
                 t = text_output[i]
+                idx = str(example["id"][i])
                 hyp_file.write(f"{t}\n")
 
                 if ctx.output_modality == Modality.SPEECH:
@@ -301,8 +283,8 @@ def run_eval(
                     str_units = [str(i) for i in u]
                     unit_file.write(" ".join(str_units) + "\n")
                     torchaudio.save(
-                        waveforms_dir / f"{sample_id}_pred.wav",
-                        speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                        waveforms_dir / f"{idx}_pred.wav",
+                        speech_output.audio_wavs[i].to(torch.float32).cpu(),
                         sample_rate=speech_output.sample_rate,
                     )
 
@@ -353,9 +335,9 @@ def main() -> None:
         default="tgt_text",
     )
     parser.add_argument(
-        "--gcmvn_stats",
+        "--pretssel_model",
         type=str,
-        help="The path to gcmvn fbank stats, if provided, the DataPipeline'd have another copy of gcmvn fbank features (for P2V enc)",
+        help="Model card name for PretsselModel",
         default=None,
     )
     args = parser.parse_args()
@@ -369,19 +351,19 @@ def main() -> None:
 
     if torch.cuda.is_available():
         device = torch.device("cuda:0")
-        dtype = torch.float32
+        dtype = torch.float16
     else:
         device = torch.device("cpu")
         dtype = torch.float32
 
     text_tokenizer = load_unity_text_tokenizer(args.model_name)
 
-    # TODO: Avoid loading the T2U model, vocoder when the output
-    # modality is text.
+    gcmvn_mean, gcmvn_std = load_gcmvn_stats(args.pretssel_model)
+
     translator = Translator(
         args.model_name,
-        args.vocoder_name,
-        device,
+        vocoder_name_or_card=None,
+        device=device,
         text_tokenizer=text_tokenizer,
         dtype=dtype,
     )
@@ -411,7 +393,10 @@ def main() -> None:
         unit_generation_opts=unit_generation_opts,
         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
         output_path=Path(args.output_path),
-        gcmvn_stats=args.gcmvn_stats,
+        gcmvn_mean=torch.tensor(gcmvn_mean, device=device, dtype=dtype),
+        gcmvn_std=torch.tensor(gcmvn_std, device=device, dtype=dtype),
+        pretssel_model=args.pretssel_model,
+        vocoder_name=args.vocoder_name,
     )
     # fmt: on
     logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")

+ 4 - 2
src/seamless_communication/inference/generator.py

@@ -153,7 +153,8 @@ class UnitYGenerator:
         input_modality: str = "speech",
         output_modality: str = "speech",
         ngram_filtering: bool = False,
-        gcmvn_seqs: Optional[Tensor] = None,
+        duration_factor: float = 1.0,
+        gcmvn_fbank: Optional[Tensor] = None,
     ) -> Tuple[SequenceToTextOutput, Optional["SequenceToUnitOutput"]]:
         """
         :param source_seqs:
@@ -219,7 +220,7 @@ class UnitYGenerator:
         prosody_encoder_out = None
         if self.model.prosody_encoder_model is not None:
             prosody_encoder_out = self.model.prosody_encoder_model(
-                gcmvn_seqs, source_padding_mask
+                gcmvn_fbank, source_padding_mask
             ).unsqueeze(1)
 
         if isinstance(self.model.t2u_model, UnitYT2UModel):
@@ -238,6 +239,7 @@ class UnitYGenerator:
                 text_decoder_output=decoder_output,
                 text_decoder_padding_mask=decoder_padding_mask,
                 text_seqs=text_seqs,
+                duration_factor=duration_factor,
                 film_cond_emb=prosody_encoder_out,
             )
             # (B, S_unit, V_unit)

+ 132 - 0
src/seamless_communication/inference/pretssel_generator.py

@@ -0,0 +1,132 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import Collater, SequenceData
+from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+
+from seamless_communication.inference.translator import BatchedSpeechOutput
+from seamless_communication.models.pretssel import load_pretssel_model
+from seamless_communication.models.unity import load_unity_unit_tokenizer
+from seamless_communication.models.vocoder import load_mel_vocoder_model
+
+
+class PretsselGenerator(nn.Module):
+    def __init__(
+        self,
+        model_name_or_card: Union[str, AssetCard],
+        vocoder_name_or_card: Union[str, AssetCard],
+        pretssel_name_or_card: Union[str, AssetCard],
+        device: Device,
+        gcmvn_mean: Optional[Tensor] = None,
+        gcmvn_std: Optional[Tensor] = None,
+        dtype: DataType = torch.float16,
+    ):
+        super().__init__()
+        # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
+
+        self.device = device
+        self.dtype = dtype
+
+        # Load the vocoder.
+        self.vocoder = load_mel_vocoder_model(
+            vocoder_name_or_card,
+            device=device,
+            dtype=dtype,
+        )
+        self.vocoder.eval()
+
+        self.pretssel_model = load_pretssel_model(
+            pretssel_name_or_card,
+            device=device,
+            dtype=dtype,
+        )
+        self.pretssel_model.eval()
+
+        self.unit_tokenizer = load_unity_unit_tokenizer(model_name_or_card)
+        self.unit_collate = Collater(pad_value=self.unit_tokenizer.vocab_info.pad_idx)
+        self.duration_collate = Collater(pad_value=0)
+
+        self.gcmvn_mean = gcmvn_mean
+        self.gcmvn_std = gcmvn_std
+
+    def gcmvn_denormalize(self, x: Tensor) -> Tensor:
+        if self.gcmvn_mean is None or self.gcmvn_std is None:
+            return x
+
+        # x: B x T x C
+        assert (
+            x.ndim == 3
+            and x.shape[2] == self.gcmvn_mean.shape[0] == self.gcmvn_std.shape[0]
+        )
+        x = x * self.gcmvn_std.view(1, 1, -1).expand_as(x)
+        return x + self.gcmvn_mean.view(1, 1, -1).expand_as(x)
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        units: List[List[int]],
+        tgt_lang: str,
+        padding_mask: Optional[PaddingMask],
+        gcmvn_fbank: Tensor,
+        sample_rate: int = 16000,
+    ) -> BatchedSpeechOutput:
+        list_units, durations = [], []
+        unit_eos_token = torch.tensor(
+            [self.unit_tokenizer.vocab_info.eos_idx],
+            device=self.device,
+        )
+
+        for i, u in enumerate(units):
+            unit = torch.tensor(u).to(unit_eos_token)
+
+            # adjust the control symbols for the embedding
+            unit += 4
+            unit = torch.cat([unit, unit_eos_token], dim=0)
+
+            unit, duration = torch.unique_consecutive(unit, return_counts=True)
+
+            # adjust for the last eos token
+            duration[-1] = 0
+
+            list_units.append(unit)
+            durations.append(duration * 2)
+
+        speech_units = self.unit_collate(list_units)
+        durations = self.duration_collate(durations)["seqs"]
+
+        units_tensor, unit_padding_mask = get_seqs_and_padding_mask(speech_units)
+
+        mel_output = self.pretssel_model(
+            units_tensor,
+            unit_padding_mask,
+            gcmvn_fbank,
+            padding_mask,
+            tgt_lang=tgt_lang,
+            durations=durations,
+        )
+
+        mel_output = self.gcmvn_denormalize(mel_output)
+
+        audio_wavs = []
+        for i, mel_out in enumerate(mel_output):
+            # TODO: Implement batched inference for vocoder.
+            mel_out = mel_out[: durations[i].sum()]
+            translated_audio_wav = self.vocoder(mel_out, normalize_before=True)
+            audio_wavs.append(translated_audio_wav.view(1, -1))
+
+        return BatchedSpeechOutput(
+            units=units,
+            audio_wavs=audio_wavs,
+            sample_rate=sample_rate,
+        )
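For reference, the unit preprocessing inside predict() can be reproduced in isolation. eos_idx=2 matches the PRETSSEL vocab config in builder.py but is otherwise illustrative, and the doubling of durations simply mirrors what predict() does:

    import torch

    units = torch.tensor([12, 12, 12, 7, 7, 3])  # raw units from the T2U decoder
    eos_idx = 2                                  # unit_tokenizer.vocab_info.eos_idx

    unit = torch.cat([units + 4, torch.tensor([eos_idx])])  # shift past control symbols, append EOS
    unit, duration = torch.unique_consecutive(unit, return_counts=True)
    duration[-1] = 0                             # EOS contributes no frames
    duration = duration * 2                      # doubled, as in predict()

    # unit     -> tensor([16, 11,  7,  2])
    # duration -> tensor([6, 4, 2, 0])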

+ 25 - 20
src/seamless_communication/inference/translator.py

@@ -20,7 +20,7 @@ from fairseq2.data.text import TextTokenizer
 from fairseq2.data.typing import StringLike
 from fairseq2.generation import SequenceGeneratorOptions, SequenceToTextOutput
 from fairseq2.memory import MemoryBlock
-from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.nn.padding import get_seqs_and_padding_mask, PaddingMask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 
@@ -77,7 +77,7 @@ class Translator(nn.Module):
     def __init__(
         self,
         model_name_or_card: Union[str, AssetCard],
-        vocoder_name_or_card: Union[str, AssetCard],
+        vocoder_name_or_card: Union[str, AssetCard, None],
         device: Device,
         text_tokenizer: Optional[TextTokenizer] = None,
         dtype: DataType = torch.float16,
@@ -136,8 +136,9 @@ class Translator(nn.Module):
             pad_value=self.text_tokenizer.vocab_info.pad_idx, pad_to_multiple=2
         )
         self.vocoder = None
-        if output_modality is None or output_modality == Modality.SPEECH:
-            # Load the vocoder.
+        if vocoder_name_or_card is not None and (
+            output_modality is None or output_modality == Modality.SPEECH
+        ):
             self.vocoder = self.load_model_for_inference(
                 load_vocoder_model, vocoder_name_or_card, device, torch.float32
             )
@@ -159,14 +160,16 @@ class Translator(nn.Module):
         model: UnitYModel,
         text_tokenizer: TextTokenizer,
         unit_tokenizer: Optional[UnitTokenizer],
-        src: SequenceData,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
         input_modality: Modality,
         output_modality: Modality,
         tgt_lang: str,
         text_generation_opts: SequenceGeneratorOptions,
         unit_generation_opts: Optional[SequenceGeneratorOptions],
         unit_generation_ngram_filtering: bool = False,
-        gcmvn_fbank: Optional[SequenceData] = None,
+        duration_factor: float = 1.0,
+        gcmvn_fbank: Optional[Tensor] = None,
     ) -> Tuple[SequenceToTextOutput, Optional[SequenceToUnitOutput]]:
         # We disregard unit generations opts for the NAR T2U decoder.
         if output_modality != Modality.SPEECH or isinstance(
@@ -182,11 +185,6 @@ class Translator(nn.Module):
             text_opts=text_generation_opts,
             unit_opts=unit_generation_opts,
         )
-        seqs, padding_mask = get_seqs_and_padding_mask(src)
-        if gcmvn_fbank is not None:
-            gcmvn_seqs = gcmvn_fbank["seqs"]
-        else:
-            gcmvn_seqs = None
 
         return generator(
             seqs,
@@ -194,7 +192,8 @@ class Translator(nn.Module):
             input_modality.value,
             output_modality.value,
             ngram_filtering=unit_generation_ngram_filtering,
-            gcmvn_seqs=gcmvn_seqs,
+            duration_factor=duration_factor,
+            gcmvn_fbank=gcmvn_fbank,
         )
 
     @staticmethod
@@ -230,7 +229,8 @@ class Translator(nn.Module):
         spkr: Optional[int] = -1,
         sample_rate: int = 16000,
         unit_generation_ngram_filtering: bool = False,
-        gcmvn_fbank: Optional[SequenceData] = None,
+        duration_factor: float = 1.0,
+        gcmvn_fbank: Optional[Tensor] = None,
     ) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
         """
         The main method used to perform inference on all tasks.
@@ -299,17 +299,22 @@ class Translator(nn.Module):
             src = self.collate(self.token_encoder(text))
 
         assert isinstance(self.model, UnitYModel)
+
+        seqs, padding_mask = get_seqs_and_padding_mask(src)
+
         text_output, unit_output = self.get_prediction(
             self.model,
             self.text_tokenizer,
             self.unit_tokenizer,
-            src,
+            seqs,
+            padding_mask,
             input_modality,
             output_modality,
             tgt_lang,
             text_generation_opts,
             unit_generation_opts,
             unit_generation_ngram_filtering=unit_generation_ngram_filtering,
+            duration_factor=duration_factor,
             gcmvn_fbank=gcmvn_fbank,
         )
 
@@ -317,7 +322,6 @@ class Translator(nn.Module):
             return text_output.sentences, None
         else:
             assert unit_output is not None
-            assert self.vocoder is not None
 
             if isinstance(self.model.t2u_model, UnitYT2UModel):
                 # Remove the lang token for AR UnitY since the vocoder doesn't need it
@@ -339,11 +343,12 @@ class Translator(nn.Module):
                 )
                 u = u[:index_of_first_one]
                 speech_units.append(u)
-                # TODO: Implement batched inference for vocoder.
-                translated_audio_wav = self.vocoder(
-                    u, tgt_lang, spkr, dur_prediction=duration_prediction
-                )
-                audio_wavs.append(translated_audio_wav)
+                if self.vocoder is not None:
+                    # TODO: Implement batched inference for vocoder.
+                    translated_audio_wav = self.vocoder(
+                        u, tgt_lang, spkr, dur_prediction=duration_prediction
+                    )
+                    audio_wavs.append(translated_audio_wav)
 
             return (
                 text_output.sentences,

+ 3 - 0
src/seamless_communication/models/pretssel/__init__.py

@@ -14,3 +14,6 @@ from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
 from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
     ecapa_tdnn_archs as ecapa_tdnn_archs,
 )
+from seamless_communication.models.pretssel.loader import (
+    load_pretssel_model as load_pretssel_model,
+)

+ 419 - 0
src/seamless_communication/models/pretssel/builder.py

@@ -0,0 +1,419 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import List, Literal, Optional, Union
+
+from fairseq2.assets import asset_store
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import Embedding, StandardEmbedding, init_scaled_embedding
+from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
+from fairseq2.nn.projection import Linear
+from fairseq2.nn.transformer import (
+    MultiheadAttention,
+    StandardMultiheadAttention,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+from torch.nn import Conv1d
+
+from seamless_communication.models.pretssel.ecapa_tdnn_builder import (
+    EcapaTDNNBuilder,
+    EcapaTDNNConfig,
+    ecapa_tdnn_archs,
+)
+from seamless_communication.models.pretssel.pretssel_model import (
+    PostNet,
+    PretsselDecoderFrontend,
+    PretsselEncoderFrontend,
+    PretsselModel,
+)
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+from seamless_communication.models.unity.fft_decoder_layer import (
+    Conv1dBlock,
+    FeedForwardTransformerLayer,
+)
+from seamless_communication.models.unity.length_regulator import (
+    HardUpsampling,
+    VarianceAdaptor,
+    VariancePredictor,
+)
+from seamless_communication.models.unity.t2u_builder import VariancePredictorConfig
+
+
+@dataclass
+class PretsselEncoderFrontendConfig:
+    prosody_encoder_config: EcapaTDNNConfig
+    dropout: float
+    lang_embed_dim: Optional[int] = None
+
+
+@dataclass
+class FFTLayerConfig:
+    attention_heads: int
+    hidden_dim: int
+    kernel_size: int
+    dropout: float
+    conv1d_dropout: float
+    film_cond_dim: int
+    use_film: bool = False
+
+
+@dataclass
+class PretsselDecoderFrontendConfig:
+    upsampling_type: Literal["gaussian", "hard"]
+    variance_predictor_config: VariancePredictorConfig
+    add_variance_parallel: bool
+
+
+@dataclass
+class PostnetConfig:
+    dropout: float
+    layers: int
+    conv_dim: int
+    conv_kernel_size: int
+
+
+@dataclass
+class PretsselConfig:
+    model_name_or_card: str
+    encoder_frontend_config: PretsselEncoderFrontendConfig
+    fft_layer_config: FFTLayerConfig
+    decoder_frontend_config: PretsselDecoderFrontendConfig
+    post_net_config: PostnetConfig
+    vocab_info: VocabularyInfo
+    model_dim: int
+    max_seq_len: int
+    encoder_layers: int
+    decoder_layers: int
+    output_dim: int
+
+
+pretssel_archs = ArchitectureRegistry[PretsselConfig]("pretssel")
+
+pretssel_arch = pretssel_archs.marker
+
+
+@pretssel_arch("base")
+def _base_pretssel() -> PretsselConfig:
+    prosody_encoder_config = ecapa_tdnn_archs.get_config("base")
+
+    encoder_frontend_config = PretsselEncoderFrontendConfig(
+        prosody_encoder_config=prosody_encoder_config,
+        dropout=0.2,
+        lang_embed_dim=64,
+    )
+
+    fft_layer_config = FFTLayerConfig(
+        attention_heads=2,
+        hidden_dim=1024,
+        kernel_size=9,
+        dropout=0.0,
+        conv1d_dropout=0.2,
+        use_film=True,
+        film_cond_dim=576,
+    )
+
+    variance_predictor_config = VariancePredictorConfig(
+        var_pred_hidden_dim=512,
+        var_pred_kernel_size=5,
+        var_pred_dropout=0.5,
+        use_film=True,
+        film_cond_dim=576,
+    )
+
+    decoder_frontend_config = PretsselDecoderFrontendConfig(
+        upsampling_type="gaussian",
+        variance_predictor_config=variance_predictor_config,
+        add_variance_parallel=True,
+    )
+
+    post_net_config = PostnetConfig(
+        dropout=0.5,
+        layers=5,
+        conv_dim=512,
+        conv_kernel_size=5,
+    )
+
+    return PretsselConfig(
+        "pretssel_v1",
+        encoder_frontend_config,
+        fft_layer_config,
+        decoder_frontend_config,
+        post_net_config,
+        vocab_info=VocabularyInfo(
+            size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        model_dim=256,
+        max_seq_len=10000,
+        encoder_layers=4,
+        decoder_layers=4,
+        output_dim=80,
+    )
+
+
+class PretsselBuilder:
+    """
+    Builder module for PRETSSEL model
+    """
+
+    config: PretsselConfig
+    prosody_encoder_builder: EcapaTDNNBuilder
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: PretsselConfig,
+        prosody_encoder_builder: EcapaTDNNBuilder,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.prosody_encoder_builder = prosody_encoder_builder
+
+        self.device, self.dtype = device, dtype
+
+    def build_embed_tokens(self) -> StandardEmbedding:
+        """Build a unit embedding table."""
+
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_fft(self, num_layers: int) -> FeedForwardTransformer:
+        """Build a Transformer encoder."""
+
+        layers = [self.build_fft_layer() for _ in range(num_layers)]
+
+        return FeedForwardTransformer(
+            layers,
+            norm_order=TransformerNormOrder.POST,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_fft_layer(self) -> FeedForwardTransformerLayer:
+        """Build a Transformer decoder layer."""
+
+        self_attn = self.build_attention(self.config.fft_layer_config.attention_heads)
+
+        conv1d = Conv1dBlock(
+            self.config.model_dim,
+            self.config.fft_layer_config.hidden_dim,
+            self.config.fft_layer_config.kernel_size,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return FeedForwardTransformerLayer(
+            self_attn,
+            conv1d,
+            dropout_p=0.0,  # fairseq1 doesn't have this
+            conv1d_dropout_p=self.config.fft_layer_config.conv1d_dropout,
+            use_film=self.config.fft_layer_config.use_film,
+            film_cond_dim=self.config.fft_layer_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.fft_layer_config.dropout)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_variance_adaptor(
+        self,
+        decoder_frontend_config: PretsselDecoderFrontendConfig,
+    ) -> VarianceAdaptor:
+        """Build a variance adaptor module."""
+
+        variance_predictor_config = decoder_frontend_config.variance_predictor_config
+
+        pitch_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_pitch = Conv1d(
+            1,
+            self.config.model_dim,
+            kernel_size=1,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        vuv_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        energy_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_energy = Conv1d(
+            1,
+            self.config.model_dim,
+            kernel_size=1,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        variance_adaptor = VarianceAdaptor(
+            duration_predictor=None,
+            pitch_predictor=pitch_predictor,
+            embed_pitch=embed_pitch,
+            vuv_predictor=vuv_predictor,
+            energy_predictor=energy_predictor,
+            embed_energy=embed_energy,
+            add_variance_parallel=decoder_frontend_config.add_variance_parallel,
+            upsampling_type=decoder_frontend_config.upsampling_type,
+        )
+
+        return variance_adaptor
+
+    def build_model(self) -> PretsselModel:
+        """Build a model."""
+        prosody_encoder = self.prosody_encoder_builder.build_model()
+
+        embed_tokens = self.build_embed_tokens()
+
+        embed_positions = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.max_seq_len,
+            _legacy_pad_idx=self.config.vocab_info.pad_idx,
+            device=self.device,
+        )
+
+        model_card = asset_store.retrieve_card(self.config.model_name_or_card)
+        langs = model_card.field("languages").as_list(str)
+        lang_to_index = {l: i for i, l in enumerate(langs)}
+
+        encoder_frontend = PretsselEncoderFrontend(
+            prosody_encoder,
+            embed_tokens,
+            embed_positions,
+            lang_to_index,
+            lang_embed_dim=self.config.encoder_frontend_config.lang_embed_dim,
+            dropout_p=self.config.encoder_frontend_config.dropout,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        encoder = self.build_fft(self.config.encoder_layers)
+
+        variance_adaptor = self.build_variance_adaptor(
+            self.config.decoder_frontend_config
+        )
+
+        decoder_frontend = PretsselDecoderFrontend(
+            variance_adaptor,
+            embed_positions,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        decoder = self.build_fft(self.config.decoder_layers)
+
+        final_proj = Linear(
+            self.config.model_dim,
+            self.config.output_dim,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        postnet = PostNet(
+            self.config.output_dim,
+            self.config.post_net_config.conv_dim,
+            self.config.post_net_config.conv_kernel_size,
+            self.config.post_net_config.layers,
+            self.config.post_net_config.dropout,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return PretsselModel(
+            encoder_frontend,
+            encoder,
+            decoder_frontend,
+            decoder,
+            final_proj,
+            postnet,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+
+def create_pretssel_model(
+    config: PretsselConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> PretsselModel:
+    """Create a PretsselModel.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+
+    prosody_encoder_builder = EcapaTDNNBuilder(
+        config.encoder_frontend_config.prosody_encoder_config,
+        device=device,
+        dtype=dtype,
+    )
+    return PretsselBuilder(
+        config, prosody_encoder_builder, device=device, dtype=dtype
+    ).build_model()
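A hedged sketch of building the model directly from the registered architecture. Note that build_model() resolves the pretssel_v1 asset card to obtain the language list, so the card must be discoverable by the asset store:

    import torch
    from seamless_communication.models.pretssel.builder import (
        create_pretssel_model,
        pretssel_archs,
    )

    config = pretssel_archs.get_config("base")
    model = create_pretssel_model(config, device=torch.device("cpu"), dtype=torch.float32)
    print(sum(p.numel() for p in model.parameters()))  # parameter count of the base arch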

+ 114 - 0
src/seamless_communication/models/pretssel/loader.py

@@ -0,0 +1,114 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Mapping, final
+
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils.checkpoint_loader import upgrade_fairseq_checkpoint
+from fairseq2.models.utils.model_loader import ModelLoader
+from overrides import override as finaloverride
+
+from seamless_communication.models.pretssel.builder import (
+    PretsselConfig,
+    create_pretssel_model,
+    pretssel_archs,
+)
+from seamless_communication.models.pretssel.pretssel_model import PretsselModel
+
+
+@final
+class PretsselLoader(ModelLoader[PretsselModel, PretsselConfig]):
+    """Load PretsselModel."""
+
+    @finaloverride
+    def _convert_checkpoint(
+        self, checkpoint: Mapping[str, Any], config: PretsselConfig
+    ) -> Mapping[str, Any]:
+        state_dict = checkpoint["model"]
+
+        # Check if we have a fairseq2 checkpoint.
+        if "decoder_frontend.embed.weight" in state_dict:
+            return checkpoint
+
+        key_map = self._fairseq_key_map(config)
+
+        checkpoint = upgrade_fairseq_checkpoint(checkpoint, key_map)
+
+        state_dict = checkpoint["model"]
+
+        keys_to_delete = []
+
+        keys_to_delete.extend(
+            [
+                "encoder.embed_positions._float_tensor",
+                "decoder.embed_positions._float_tensor",
+                "enc_emb_proj.weight",
+                "enc_emb_proj.bias",
+            ]
+        )
+
+        keys_to_delete.extend(
+            [
+                key
+                for key in state_dict
+                if key.startswith("decoder.var_adaptor.duration_predictor")
+            ]
+        )
+
+        for key in keys_to_delete:
+            if key in state_dict:
+                del state_dict[key]
+
+        return checkpoint
+
+    @staticmethod
+    def _fairseq_key_map(config: PretsselConfig) -> Dict[str, str]:
+        key_map = {
+            # fmt: off
+            # encoder frontend
+            r"^prosody_encoder\.":                                        r"encoder_frontend.prosody_encoder.",
+            r"^encoder\.embed_tokens\.":                                  r"encoder_frontend.embed_tokens.",
+            r"^embed_lang\.":                                             r"encoder_frontend.embed_lang.",
+            r"^encoder\.pos_emb_alpha":                                   r"encoder_frontend.pos_emb_alpha",
+
+            # encoder
+            r"^encoder\.fft_layers\.([0-9]+)\.self_attn\.out_proj\.":     r"encoder.layers.\1.self_attn.output_proj.",
+            r"^encoder\.fft_layers\.([0-9]+)\.self_attn\.":               r"encoder.layers.\1.self_attn.",
+            r"^encoder\.fft_layers\.([0-9]+)\.layer_norm\.":              r"encoder.layers.\1.self_attn_layer_norm.",
+            r"^encoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"encoder.layers.\1.conv1d.conv1.",
+            r"^encoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"encoder.layers.\1.conv1d.conv2.",
+            r"^encoder\.fft_layers\.([0-9]+)\.ffn\.layer_norm\.":         r"encoder.layers.\1.conv1d_layer_norm.",
+            r"^encoder\.fft_layers\.([0-9]+)\.film\.":                    r"encoder.layers.\1.film.",
+
+            # decoder frontend
+            r"^decoder\.var_adaptor\.":                                   r"decoder_frontend.variance_adaptor.",
+            r"^decoder\.pos_emb_alpha":                                   r"decoder_frontend.pos_emb_alpha",
+
+            # decoder
+            r"^decoder\.fft_layers\.([0-9]+)\.self_attn\.out_proj\.":     r"decoder.layers.\1.self_attn.output_proj.",
+            r"^decoder\.fft_layers\.([0-9]+)\.self_attn\.":               r"decoder.layers.\1.self_attn.",
+            r"^decoder\.fft_layers\.([0-9]+)\.layer_norm\.":              r"decoder.layers.\1.self_attn_layer_norm.",
+            r"^decoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.0\.":             r"decoder.layers.\1.conv1d.conv1.",
+            r"^decoder\.fft_layers\.([0-9]+)\.ffn\.ffn\.2\.":             r"decoder.layers.\1.conv1d.conv2.",
+            r"^decoder\.fft_layers\.([0-9]+)\.ffn\.layer_norm\.":         r"decoder.layers.\1.conv1d_layer_norm.",
+            r"^decoder\.fft_layers\.([0-9]+)\.film\.":                    r"decoder.layers.\1.film.",
+
+            # final_proj & postnet
+            r"^decoder\.out_proj\.":                                      r"final_proj.",
+            r"^decoder\.postnet\.":                                       r"postnet.",
+            # fmt: on
+        }
+
+        return key_map
+
+
+load_pretssel_model = PretsselLoader(
+    asset_store,
+    download_manager,
+    create_pretssel_model,
+    pretssel_archs,
+    restrict_checkpoints=False,
+)
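Loading the converted checkpoint through the new loader is the same call PretsselGenerator makes; the checkpoint URI in the card points to an internal path, so this is only a sketch:

    import torch
    from seamless_communication.models.pretssel import load_pretssel_model

    pretssel_model = load_pretssel_model(
        "pretssel_v1", device=torch.device("cuda:0"), dtype=torch.float16
    )
    pretssel_model.eval()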

+ 276 - 0
src/seamless_communication/models/pretssel/pretssel_model.py

@@ -0,0 +1,276 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+import torch
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import Embedding, StandardEmbedding
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.position_encoder import PositionEncoder
+from fairseq2.nn.projection import Projection
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from torch.nn import (
+    BatchNorm1d,
+    Conv1d,
+    Dropout,
+    Module,
+    ModuleList,
+    Parameter,
+    Sequential,
+    Tanh,
+    init,
+)
+
+from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+from seamless_communication.models.unity.length_regulator import VarianceAdaptor
+
+
+class PretsselEncoderFrontend(Module):
+    """Represent Encoder frontend, including speaker & language embedding"""
+
+    prosody_encoder: ECAPA_TDNN
+    embed_tokens: Embedding
+    embed_positions: PositionEncoder
+    pos_emb_alpha: Parameter
+    embed_lang: Embedding
+    dropout: Dropout
+
+    def __init__(
+        self,
+        prosody_encoder: ECAPA_TDNN,
+        embed_tokens: Embedding,
+        embed_positions: PositionEncoder,
+        lang_to_index: Dict[str, int],
+        lang_embed_dim: Optional[int],
+        dropout_p: float,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.prosody_encoder = prosody_encoder
+
+        self.embed_tokens = embed_tokens
+
+        self.embed_positions = embed_positions
+        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+        self.lang_to_index = lang_to_index
+
+        if lang_embed_dim is not None:
+            self.embed_lang = StandardEmbedding(
+                len(lang_to_index), lang_embed_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("embed_lang", None)
+
+        self.dropout = Dropout(dropout_p)
+
+        self.device = device
+        self.dtype = dtype
+
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        prosody_input_seqs: Tensor,
+        prosody_padding_mask: Optional[PaddingMask],
+        tgt_lang: str,
+    ) -> Tuple[Tensor, Tensor]:
+        prosody_embs = self.prosody_encoder(
+            prosody_input_seqs,
+            prosody_padding_mask,
+        ).unsqueeze(1)
+
+        if self.embed_lang is not None:
+            lang_index = self.lang_to_index[tgt_lang]
+            lang_index_tensor = (
+                torch.tensor([lang_index]).to(seqs).repeat(seqs.size(0), 1)
+            )
+            lang_embeds = self.embed_lang(lang_index_tensor)
+            prosody_embs = torch.cat([prosody_embs, lang_embeds], dim=-1)
+
+        seqs = self.embed_tokens(seqs)
+        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+        seqs = self.dropout(seqs)
+
+        return seqs, prosody_embs
+
+
+class PretsselDecoderFrontend(Module):
+    """Represent Decoder frontend, including VarianceAdaptor & Positional embedding"""
+
+    variance_adaptor: VarianceAdaptor
+    embed_positions: PositionEncoder
+    pos_emb_alpha: Parameter
+
+    def __init__(
+        self,
+        variance_adaptor: VarianceAdaptor,
+        embed_positions: PositionEncoder,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.variance_adaptor = variance_adaptor
+        self.embed_positions = embed_positions
+        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+        self.device = device
+        self.dtype = dtype
+
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        durations: Optional[Tensor] = None,
+        duration_factor: float = 1.0,
+        min_duration: int = 0,
+        film_cond_emb: Optional[Tensor] = None,
+    ) -> Tuple[Tensor, Optional[PaddingMask]]:
+        seqs, padding_mask = self.variance_adaptor(
+            seqs, padding_mask, durations, duration_factor, min_duration, film_cond_emb
+        )
+
+        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+
+        return seqs, padding_mask
+
+
+class PostNet(Module):
+    """Represent a PostNet"""
+
+    def __init__(
+        self,
+        in_dim: int,
+        n_channels: int,
+        kernel_size: int,
+        n_layers: int,
+        dropout: float,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.convolutions = ModuleList()
+        assert kernel_size % 2 == 1
+        for i in range(n_layers):
+            cur_layers = (
+                [
+                    Conv1d(
+                        in_dim if i == 0 else n_channels,
+                        n_channels if i < n_layers - 1 else in_dim,
+                        kernel_size=kernel_size,
+                        padding="same",
+                        device=device,
+                        dtype=dtype,
+                    ),
+                    BatchNorm1d(
+                        n_channels if i < n_layers - 1 else in_dim,
+                        device=device,
+                        dtype=dtype,
+                    ),
+                ]
+                + ([Tanh()] if i < n_layers - 1 else [])
+                + [Dropout(dropout)]
+            )
+            self.convolutions.append(Sequential(*cur_layers))
+
+        self.device = device
+        self.dtype = dtype
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters and buffers of the module."""
+        for i, layer in enumerate(self.convolutions):
+            init.xavier_uniform_(
+                layer[0].weight,
+                init.calculate_gain(
+                    "tanh" if i < len(self.convolutions) - 1 else "linear"
+                ),
+            )
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = x.transpose(1, 2)  # B x T x C -> B x C x T
+        for layer in self.convolutions:
+            x = layer(x)
+
+        return x.transpose(1, 2)
+
+
+class PretsselModel(Module):
+    """Represent the PretsselModel"""
+
+    encoder_frontend: PretsselEncoderFrontend
+    encoder: FeedForwardTransformer
+    decoder_frontend: PretsselDecoderFrontend
+    decoder: FeedForwardTransformer
+    final_proj: Projection
+    postnet: PostNet
+
+    def __init__(
+        self,
+        encoder_frontend: PretsselEncoderFrontend,
+        encoder: FeedForwardTransformer,
+        decoder_frontend: PretsselDecoderFrontend,
+        decoder: FeedForwardTransformer,
+        final_proj: Projection,
+        postnet: PostNet,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.encoder_frontend = encoder_frontend
+        self.encoder = encoder
+        self.decoder_frontend = decoder_frontend
+        self.decoder = decoder
+        self.final_proj = final_proj
+        self.postnet = postnet
+
+        self.device = device
+        self.dtype = dtype
+
+    def forward(
+        self,
+        seqs: Tensor,
+        padding_mask: Optional[PaddingMask],
+        prosody_input_seqs: Tensor,
+        prosody_padding_mask: Optional[PaddingMask],
+        tgt_lang: str,
+        durations: Optional[Tensor] = None,
+        duration_factor: float = 1.0,
+        min_duration: int = 0,
+    ) -> Tensor:
+        # (N, S) -> (N, S, M)
+        seqs, cond_embs = self.encoder_frontend(
+            seqs,
+            padding_mask,
+            prosody_input_seqs,
+            prosody_padding_mask,
+            tgt_lang,
+        )
+
+        seqs, padding_mask = self.encoder(seqs, padding_mask, cond_embs)
+
+        # (N, S, M) -> (N, X, M), inflated units
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, padding_mask, durations, duration_factor, min_duration, cond_embs
+        )
+
+        seqs, padding_mask = self.decoder(seqs, padding_mask, cond_embs)
+
+        # (N, X, M) -> (N, X, n_mels)
+        seqs = self.final_proj(seqs)
+
+        seqs = seqs + self.postnet(seqs)
+
+        return seqs
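A rough shape check of the forward pass under the base config, assuming the pretssel_v1 checkpoint is loadable; the dummy inputs and the expected output length (the sum of durations) are assumptions, not part of this commit:

    import torch
    from seamless_communication.models.pretssel import load_pretssel_model

    model = load_pretssel_model("pretssel_v1", device=torch.device("cpu"), dtype=torch.float32)
    model.eval()

    units = torch.randint(4, 10004, (1, 20))  # (N, S) unit ids
    durations = torch.full((1, 20), 2)        # frames per unit
    prosody_fbank = torch.randn(1, 150, 80)   # gcmvn-normalized reference fbank

    with torch.inference_mode():
        mel = model(units, None, prosody_fbank, None, tgt_lang="eng", durations=durations)

    # Expected roughly (1, durations.sum(), 80), i.e. (1, 40, 80) mel frames here.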

+ 9 - 6
src/seamless_communication/models/unity/__init__.py

@@ -20,6 +20,12 @@ from seamless_communication.models.unity.char_tokenizer import (
 from seamless_communication.models.unity.char_tokenizer import (
     load_unity_char_tokenizer as load_unity_char_tokenizer,
 )
+from seamless_communication.models.unity.fft_decoder import (
+    FeedForwardTransformer as FeedForwardTransformer,
+)
+from seamless_communication.models.unity.fft_decoder_layer import (
+    FeedForwardTransformerLayer as FeedForwardTransformerLayer,
+)
 from seamless_communication.models.unity.film import FiLM
 from seamless_communication.models.unity.length_regulator import (
     HardUpsampling as HardUpsampling,
@@ -31,6 +37,9 @@ from seamless_communication.models.unity.length_regulator import (
     VariancePredictor as VariancePredictor,
 )
 from seamless_communication.models.unity.loader import UnitYLoader as UnitYLoader
+from seamless_communication.models.unity.loader import (
+    load_gcmvn_stats as load_gcmvn_stats,
+)
 from seamless_communication.models.unity.loader import (
     load_unity_model as load_unity_model,
 )
@@ -47,15 +56,9 @@ from seamless_communication.models.unity.model import (
 from seamless_communication.models.unity.model import UnitYOutput as UnitYOutput
 from seamless_communication.models.unity.model import UnitYT2UModel as UnitYT2UModel
 from seamless_communication.models.unity.model import UnitYX2TModel as UnitYX2TModel
-from seamless_communication.models.unity.nar_decoder import (
-    NARTransformerDecoder as NARTransformerDecoder,
-)
 from seamless_communication.models.unity.nar_decoder_frontend import (
     NARDecoderFrontend as NARDecoderFrontend,
 )
-from seamless_communication.models.unity.nar_decoder_layer import (
-    NARTransformerDecoderLayer as NARTransformerDecoderLayer,
-)
 from seamless_communication.models.unity.t2u_builder import (
     UnitYNART2UBuilder as UnitYNART2UBuilder,
 )

+ 1 - 1
src/seamless_communication/models/unity/builder.py

@@ -202,7 +202,7 @@ def _expressivity_v2() -> UnitYConfig:
 
     mt_model_config.vocab_info.pad_idx = 1
 
-    mt_model_config.max_seq_len = 4000
+    mt_model_config.max_seq_len = 10000
 
     t2u_config = unity_t2u_archs.get_config("expressivity_nar")
 

+ 5 - 5
src/seamless_communication/models/unity/nar_decoder.py → src/seamless_communication/models/unity/fft_decoder.py

@@ -14,14 +14,14 @@ from fairseq2.typing import DataType, Device, finaloverride
 from torch import Tensor
 from torch.nn import Module
 
-from seamless_communication.models.unity.nar_decoder_layer import (
-    NARTransformerDecoderLayer,
+from seamless_communication.models.unity.fft_decoder_layer import (
+    FeedForwardTransformerLayer,
 )
 
 
 @final
-class NARTransformerDecoder(Module):
-    """Represents a non-autoregressive Transformer decoder."""
+class FeedForwardTransformer(Module):
+    """Represents a Feedforward Transformer decoder."""
 
     model_dim: int
     layer_norm: Optional[LayerNorm]
@@ -29,7 +29,7 @@ class NARTransformerDecoder(Module):
 
     def __init__(
         self,
-        layers: Iterable[NARTransformerDecoderLayer],
+        layers: Iterable[FeedForwardTransformerLayer],
         *,
         norm_order: TransformerNormOrder = TransformerNormOrder.POST,
         device: Optional[Device] = None,

+ 1 - 1
src/seamless_communication/models/unity/nar_decoder_layer.py → src/seamless_communication/models/unity/fft_decoder_layer.py

@@ -102,7 +102,7 @@ class Conv1dBlock(Module):
 
 
 @final
-class NARTransformerDecoderLayer(Module):
+class FeedForwardTransformerLayer(Module):
     """Represents the FFT Block as described in
     :cite:t:`https://arxiv.org/pdf/1905.09263.pdf`."""
 

+ 130 - 22
src/seamless_communication/models/unity/length_regulator.py

@@ -3,11 +3,12 @@
 #
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-from typing import Optional, Tuple
+from typing import Literal, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from fairseq2.nn.normalization import LayerNorm
-from fairseq2.nn.padding import PaddingMask, apply_padding_mask
+from fairseq2.nn.padding import PaddingMask, apply_padding_mask, to_padding_mask
 from fairseq2.nn.projection import Linear
 from fairseq2.nn.transformer import create_standard_layer_norm
 from fairseq2.typing import DataType, Device
@@ -38,6 +39,63 @@ class HardUpsampling(Module):
         return upsampled_seqs, upsampled_seq_lens
 
 
+class GaussianUpsampling(Module):
+    """Gaussian upsampling with fixed temperature as in:
+    https://arxiv.org/abs/2010.04301
+    """
+
+    def __init__(self, delta: float = 0.1):
+        super().__init__()
+        self.delta = delta
+
+    def forward(
+        self,
+        x: Tensor,
+        durations: Tensor,
+        padding_mask: Optional[PaddingMask] = None,
+    ) -> Tuple[Tensor, Tensor]:
+        """Upsample hidden states according to durations.
+        Args:
+            x (Tensor): Batched hidden state to be expanded (B, T_text, C).
+            durations (Tensor): Batched token duration (B, T_text).
+            padding_mask (Optional[PaddingMask]): Padding mask of x (B, T_text).
+        Returns:
+            Tensor: Expanded hidden state (B, T_feat, C).
+            Tensor: Output lengths (B,).
+        """
+        out_lens = durations.sum(dim=1)
+        y_mask = to_padding_mask(out_lens, max(out_lens))
+
+        B = durations.size(0)
+        if durations.sum() == 0:
+            # NOTE(kan-bayashi): This case must not happen in teacher forcing.
+            #   It can happen at inference with a bad duration predictor,
+            #   so we do not need to handle the padded-sequence case here.
+            durations[durations.sum(dim=1).eq(0)] = 1
+
+        if y_mask is None:
+            T_feat = durations.sum().int()
+        else:
+            T_feat = y_mask.size(-1)
+
+        t = torch.arange(0, T_feat).unsqueeze(0).repeat(B, 1).to(x)
+        if y_mask is not None:
+            t = t * y_mask.float()
+
+        c = durations.cumsum(dim=-1) - durations / 2
+        energy = -1 * self.delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2
+
+        if padding_mask is not None:
+            energy = energy.masked_fill(
+                ~padding_mask.materialize().unsqueeze(1).repeat(1, T_feat, 1),
+                -float("inf"),
+            )
+
+        p_attn = F.softmax(energy, dim=2).to(x)  # (B, T_feat, T_text)
+        x = torch.matmul(p_attn, x)
+        return x, out_lens
+
+
 class VariancePredictor(Module):
     """Represents the duration/pitch/energy predictor as described in
     :cite:t:`https://arxiv.org/pdf/2006.04558.pdf`"""
@@ -70,7 +128,7 @@ class VariancePredictor(Module):
                 var_pred_hidden_dim,
                 var_pred_kernel_size,
                 stride=1,
-                padding=(var_pred_kernel_size - 1) // 2,
+                padding="same",
                 bias=bias,
                 device=device,
                 dtype=dtype,
@@ -90,7 +148,7 @@ class VariancePredictor(Module):
                 var_pred_hidden_dim,
                 var_pred_kernel_size,
                 stride=1,
-                padding=1,
+                padding="same",
                 bias=bias,
                 device=device,
                 dtype=dtype,
@@ -114,7 +172,7 @@ class VariancePredictor(Module):
     def forward(
         self,
         seqs: Tensor,
-        padding_mask: Optional[PaddingMask],
+        padding_mask: Optional[PaddingMask] = None,
         film_cond_emb: Optional[Tensor] = None,
     ) -> Tensor:
         # Ensure that we do not leak padded positions in the convolution layer.
@@ -164,53 +222,103 @@ class VarianceAdaptor(Module):
     """Represent the Variance adaptor as described in
     :cite:t:`https://arxiv.org/pdf/2006.04558.pdf`"""
 
-    duration_predictor: VariancePredictor
+    duration_predictor: Optional[VariancePredictor]
     pitch_predictor: Optional[VariancePredictor]
+    vuv_predictor: Optional[VariancePredictor]
     energy_predictor: Optional[VariancePredictor]
-    hard_upsampling: HardUpsampling
+    length_regulator: Union[HardUpsampling, GaussianUpsampling]
 
     def __init__(
         self,
-        duration_predictor: VariancePredictor,
+        duration_predictor: Optional[VariancePredictor] = None,
         pitch_predictor: Optional[VariancePredictor] = None,
+        embed_pitch: Optional[Conv1d] = None,
+        vuv_predictor: Optional[VariancePredictor] = None,
         energy_predictor: Optional[VariancePredictor] = None,
+        embed_energy: Optional[Conv1d] = None,
+        add_variance_parallel: bool = True,
+        upsampling_type: Literal["gaussian", "hard"] = "hard",
+        use_film: bool = False,
+        film_cond_dim: Optional[int] = None,
     ):
         super().__init__()
 
-        self.duration_predictor = duration_predictor
+        if duration_predictor:
+            self.duration_predictor = duration_predictor
+        else:
+            self.register_module("duration_predictor", None)
 
         if pitch_predictor:
             self.pitch_predictor = pitch_predictor
+            self.embed_pitch = embed_pitch
         else:
             self.register_module("pitch_predictor", None)
+            self.register_module("embed_pitch", None)
+
+        if vuv_predictor:
+            self.vuv_predictor = vuv_predictor
+        else:
+            self.register_module("vuv_predictor", None)
 
         if energy_predictor:
             self.energy_predictor = energy_predictor
+            self.embed_energy = embed_energy
         else:
             self.register_module("energy_predictor", None)
+            self.register_module("embed_energy", None)
+
+        self.add_variance_parallel = add_variance_parallel
 
-        self.hard_upsampling = HardUpsampling()
+        if upsampling_type == "gaussian":
+            self.length_regulator = GaussianUpsampling()
+        else:
+            self.length_regulator = HardUpsampling()
 
     def forward(
         self,
         seqs: Tensor,
         padding_mask: Optional[PaddingMask],
+        durations: Optional[Tensor] = None,
         duration_factor: float = 1.0,
         min_duration: int = 0,
         film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, PaddingMask]:
-        log_durations = self.duration_predictor(seqs, padding_mask, film_cond_emb)
-
-        durations = torch.clamp(
-            torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
-            min=min_duration,
-        )
 
-        # We need to apply the padding_mask again since we clamp by min_duration.
-        durations = apply_padding_mask(durations, padding_mask, pad_value=0)
-
-        # TODO: Implement pitch, energy predictors.
-        # TODO: Implement GaussianUpsampling.
-        seqs, seq_lens = self.hard_upsampling(seqs, durations)
+        if self.duration_predictor is not None:
+            log_durations = self.duration_predictor(seqs, padding_mask, film_cond_emb)
+            durations = torch.clamp(
+                torch.round((torch.exp(log_durations) - 1) * duration_factor).long(),
+                min=min_duration,
+            )
+            # We need to apply the padding_mask again since we clamp by min_duration.
+            durations = apply_padding_mask(durations, padding_mask, pad_value=0)
+
+        assert durations is not None
+
+        if self.pitch_predictor is not None:
+            pitch_out = self.pitch_predictor(seqs, padding_mask, film_cond_emb)
+            if self.vuv_predictor is not None:
+                vuv_out = self.vuv_predictor(seqs, padding_mask, film_cond_emb)
+                pitch_out = pitch_out * (torch.sigmoid(vuv_out) >= 0.5)
+
+            assert self.embed_pitch is not None
+            pitch_embed = self.embed_pitch(pitch_out.unsqueeze(1)).transpose(1, 2)
+            if not self.add_variance_parallel:
+                seqs = seqs + pitch_embed
+
+        if self.energy_predictor is not None:
+            energy_out = self.energy_predictor(seqs, padding_mask, film_cond_emb)
+
+            assert self.embed_energy is not None
+            energy_embed = self.embed_energy(energy_out.unsqueeze(1)).transpose(1, 2)
+            if self.add_variance_parallel:
+                seqs = seqs + pitch_embed + energy_embed
+            else:
+                seqs = seqs + energy_embed
+
+        if isinstance(self.length_regulator, GaussianUpsampling):
+            seqs, seq_lens = self.length_regulator(seqs, durations, padding_mask)
+        else:
+            seqs, seq_lens = self.length_regulator(seqs, durations)
 
         return seqs, PaddingMask(seq_lens, batch_seq_len=seqs.size(1))
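
As a sanity check on GaussianUpsampling above, the following self-contained sketch (not part of the diff) reproduces its frame-to-token attention math with toy shapes:

import torch
import torch.nn.functional as F

# One sequence of 3 tokens with hidden size 2, expanded to sum(durations) = 6 frames.
x = torch.tensor([[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]])   # (B, T_text, C)
durations = torch.tensor([[2, 1, 3]])                       # (B, T_text)
delta = 0.1

T_feat = int(durations.sum())
t = torch.arange(T_feat, dtype=torch.float32).view(1, -1)   # output frame indices (B, T_feat)
c = durations.cumsum(dim=-1) - durations / 2                # token centres (B, T_text)

# Gaussian attention of each output frame over the input tokens, then expansion.
energy = -delta * (t.unsqueeze(-1) - c.unsqueeze(1)) ** 2   # (B, T_feat, T_text)
p_attn = F.softmax(energy, dim=2)
upsampled = torch.matmul(p_attn, x)                          # (B, T_feat, C) -> (1, 6, 2)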

+ 32 - 1
src/seamless_communication/models/unity/loader.py

@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Dict, List, Mapping, Union, final
+from typing import Any, Dict, List, Mapping, Tuple, Union, final
 
 import torch
 from fairseq2.assets import AssetStore, asset_store, download_manager
@@ -445,3 +445,34 @@ class UnitYUnitTokenizerLoader:
 
 
 load_unity_unit_tokenizer = UnitYUnitTokenizerLoader(asset_store)
+
+
+class GcmvnStatsLoader:
+    """Loads GCMVN stats (mean & std) for ProsodyUnitY & PretsselModel"""
+
+    def __init__(self, asset_store: AssetStore) -> None:
+        """
+        :param asset_store:
+            The asset store to retrieve the model information.
+        """
+        self.asset_store = asset_store
+
+    def __call__(
+        self, model_name_or_card: Union[str, AssetCard]
+    ) -> Tuple[List[float], List[float]]:
+        """
+        :param model_name_or_card:
+            The name of the model or an already loaded AssetCard
+        """
+
+        if isinstance(model_name_or_card, AssetCard):
+            card = model_name_or_card
+        else:
+            card = self.asset_store.retrieve_card(model_name_or_card)
+
+        gcmvn_stats: Dict[str, List[float]] = card.field("gcmvn_stats").as_(dict)
+
+        return gcmvn_stats["mean"], gcmvn_stats["std"]
+
+
+load_gcmvn_stats = GcmvnStatsLoader(asset_store)
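
A short sketch (not from the diff) of how load_gcmvn_stats is meant to be used, assuming the pretssel_v1 card added in this commit is registered in the asset store; the normalization step is illustrative.

import torch
from seamless_communication.models.unity import load_gcmvn_stats

gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")

mean = torch.tensor(gcmvn_mean)   # (n_mels,)
std = torch.tensor(gcmvn_std)     # (n_mels,)

# Global CMVN of an (N, T, n_mels) mel-spectrogram batch (hypothetical input).
feats = torch.randn(1, 120, mean.numel())
feats = (feats - mean) / std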

+ 11 - 4
src/seamless_communication/models/unity/model.py

@@ -20,7 +20,7 @@ from torch import Tensor
 from torch.nn import Module
 
 from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
-from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 
 
@@ -334,7 +334,7 @@ class UnitYNART2UModel(Module):
     model_dim: int
     encoder: Optional[TransformerEncoder]
     decoder_frontend: NARDecoderFrontend
-    decoder: NARTransformerDecoder
+    decoder: FeedForwardTransformer
     final_proj: Projection
     target_vocab_info: VocabularyInfo
     prosody_proj: Optional[Projection]
@@ -343,7 +343,7 @@ class UnitYNART2UModel(Module):
         self,
         encoder: Optional[TransformerEncoder],
         decoder_frontend: NARDecoderFrontend,
-        decoder: NARTransformerDecoder,
+        decoder: FeedForwardTransformer,
         final_proj: Projection,
         target_vocab_info: VocabularyInfo,
         prosody_proj: Optional[Projection] = None,
@@ -381,6 +381,7 @@ class UnitYNART2UModel(Module):
         text_decoder_output: Tensor,
         text_decoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
+        duration_factor: float = 1.0,
         film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[SequenceModelOutput, Optional[PaddingMask]]:
         encoder_output, encoder_padding_mask = self.encode(
@@ -394,6 +395,7 @@ class UnitYNART2UModel(Module):
             encoder_output,
             encoder_padding_mask,
             text_seqs,
+            duration_factor,
             film_cond_emb,
         )
 
@@ -414,12 +416,17 @@ class UnitYNART2UModel(Module):
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
+        duration_factor: float = 1.0,
         film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         # encoder_output: (N, S, M)
         # text_seqs: (N, S)
         seqs, padding_mask = self.decoder_frontend(
-            encoder_output, encoder_padding_mask, text_seqs, film_cond_emb
+            encoder_output,
+            encoder_padding_mask,
+            text_seqs,
+            duration_factor,
+            film_cond_emb,
         )
 
         return self.decoder(seqs, padding_mask, film_cond_emb=film_cond_emb)  # type: ignore[no-any-return]
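
To illustrate the new duration_factor pass-through, this toy computation (not from the diff) mirrors the clamp/round recipe in VarianceAdaptor.forward, so factors above 1.0 stretch the output and factors below 1.0 compress it:

import torch

log_durations = torch.tensor([[0.5, 1.2, 0.1]])   # hypothetical duration predictor output

def to_frames(log_d: torch.Tensor, duration_factor: float, min_duration: int = 1) -> torch.Tensor:
    # Same formula as in VarianceAdaptor.forward above.
    return torch.clamp(
        torch.round((torch.exp(log_d) - 1) * duration_factor).long(),
        min=min_duration,
    )

print(to_frames(log_durations, 1.0))   # tensor([[1, 2, 1]])
print(to_frames(log_durations, 1.5))   # tensor([[1, 3, 1]])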

+ 2 - 0
src/seamless_communication/models/unity/nar_decoder_frontend.py

@@ -302,6 +302,7 @@ class NARDecoderFrontend(Module):
         encoder_output: Tensor,
         encoder_padding_mask: Optional[PaddingMask],
         text_seqs: Optional[Tensor],
+        duration_factor: float = 1.0,
         film_cond_emb: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Optional[PaddingMask]]:
         assert text_seqs is not None
@@ -323,6 +324,7 @@ class NARDecoderFrontend(Module):
         seqs, padding_mask = self.variance_adaptor(
             seqs,
             encoder_padding_mask,
+            duration_factor=duration_factor,
             min_duration=1,
             film_cond_emb=film_cond_emb,
         )

+ 11 - 11
src/seamless_communication/models/unity/t2u_builder.py

@@ -38,17 +38,17 @@ from fairseq2.typing import DataType, Device
 from torch.nn import GELU, ReLU
 
 from seamless_communication.models.unity.char_tokenizer import load_unity_char_tokenizer
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+from seamless_communication.models.unity.fft_decoder_layer import (
+    Conv1dBlock,
+    FeedForwardTransformerLayer,
+)
 from seamless_communication.models.unity.length_regulator import (
     VarianceAdaptor,
     VariancePredictor,
 )
 from seamless_communication.models.unity.model import UnitYNART2UModel, UnitYT2UModel
-from seamless_communication.models.unity.nar_decoder import NARTransformerDecoder
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
-from seamless_communication.models.unity.nar_decoder_layer import (
-    Conv1dBlock,
-    NARTransformerDecoderLayer,
-)
 
 
 @dataclass
@@ -252,7 +252,7 @@ def _expressivity_nar() -> UnitYT2UConfig:
     nar_decoder_config = NARDecoderConfig(
         model_name_or_card="seamless_expressivity",
         char_vocabulary_size=10904,
-        char_max_seq_len=4000,
+        char_max_seq_len=10000,
         conv1d_kernel_size=7,
         conv1d_inner_dim=1024,
         conv1d_dropout_p=0.1,
@@ -262,7 +262,7 @@ def _expressivity_nar() -> UnitYT2UConfig:
 
     return UnitYT2UConfig(
         model_dim=1024,
-        unit_max_seq_len=4000,
+        unit_max_seq_len=10000,
         target_vocab_info=VocabularyInfo(
             size=10005, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
         ),
@@ -631,21 +631,21 @@ class UnitYNART2UBuilder:
             dtype=self.dtype,
         )
 
-    def build_decoder(self) -> NARTransformerDecoder:
+    def build_decoder(self) -> FeedForwardTransformer:
         """Build a Transformer decoder."""
 
         num_layers = self.config.num_decoder_layers
 
         layers = [self.build_decoder_layer() for _ in range(num_layers)]
 
-        return NARTransformerDecoder(
+        return FeedForwardTransformer(
             layers,
             norm_order=TransformerNormOrder.PRE,
             device=self.device,
             dtype=self.dtype,
         )
 
-    def build_decoder_layer(self) -> NARTransformerDecoderLayer:
+    def build_decoder_layer(self) -> FeedForwardTransformerLayer:
         """Build a Transformer decoder layer."""
 
         assert self.config.nar_decoder_config is not None
@@ -661,7 +661,7 @@ class UnitYNART2UBuilder:
             dtype=self.dtype,
         )
 
-        return NARTransformerDecoderLayer(
+        return FeedForwardTransformerLayer(
             self_attn,
             conv1d,
             dropout_p=self.config.dropout_p,

+ 1 - 1
tests/integration/models/test_pretssel_vocoder.py

@@ -19,7 +19,7 @@ def test_pretssel_vocoder() -> None:
     sample_rate = 16_000
 
     vocoder = load_mel_vocoder_model(
-        "vocoder_pretssel", device=device, dtype=torch.float32
+        "vocoder_mel", device=device, dtype=torch.float32
     )
 
     url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"