
Make a public-facing watermarked vocoder (PretsselVocoder) (#97)

* initial commit

* add audiocraft to requirements

* all in one file

* cleanup

* remove unused imports

* cleanup and add doc string

* add checkpoint to model card instead

* add pretssel hifigan

* update

* update

* update

* update

* update

* --amend

* Update src/seamless_communication/models/vocoder/builder.py

Co-authored-by: Kaushik Ram Sadagopan <krs@fb.com>

* update

* add pretssel vocoder

* update builder and loader

* Implement PretsselModel & its inference

* update checkpoint handling code

* add mel_vocoder

* refactor inference

* minor fix

* minor fix

* minor fix

* mypy pytest isort black formatting

* change padding to 'same'

* minor renaming

* functional melvocoder + watermark

* remove pretssel from the public-facing module

* checkpoint code

* update checkpoint and fix typos in builder

* update integration test

* make languages updatable automatically

* blend PostNet

* update test cases with new mixing logic of wm layers

* linting

* remove obsolete state dict keys in the checkpoint

* debug the wm deltas error

* linting

* linting

* correct typo

* exclude scripts/convert_mel_hifigan_chkpt.py from the PR

* fix typo

* exclude non-related linting (we will do it in another PR)

* exclude non-related linting (we will do it in another PR)

* exclude non-related linting (we will do it in another PR)

* exclude non-related linting and make it in a separate PR

* exclude non-related linting (we will do it in another PR)

* obfuscate the code further by making torch.tanh embeddable in conv1d, so the watermarking cannot be removed by deleting a single block of code from the vocoder, only multiple blocks at multiple places

* update watermarking.py as a standalone script

* add temporary comments to support the PR review

* fix typos

* fix linting

* update with new loader API from fairseq2

* re-sync ecapa_tdnn with main

* re-sync ecapa_tdnn_builder with main

* fix typo in card

* update compile_checkpoint to new loader API

* Yilin's comments

* check if gcmvn_mean is loaded

* update deps of unity; update test case

* update cuda setting for watermarking

* (1) remove the unity deps from this PR - we will put them in another PR; (2) address the last comments from Alex and Yilin

* tmp

* enable batched watermark inference

* revise

* update config for 24kHz audio

* update model config for 24kHz expected sample rate

* fix model card and checkpoint compilation script

* update test case to 24khz

* linting

* integrate Kaushik's PR #134

* update model card info, update deps on unity

* revise sample_rate

* remove internal comments

---------

Co-authored-by: hady elsahar <hadyelsahar@meta.com>
Co-authored-by: Changhan Wang <changhan@fb.com>
Co-authored-by: Changhan Wang <wangchanghan@gmail.com>
Co-authored-by: Kaushik Ram Sadagopan <krs@fb.com>
Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Co-authored-by: Yilin Yang <yilinyang721@gmail.com>
Tuan Tran 1 year ago
parent
commit
9f6ade6ee4

+ 6 - 6
scripts/convert_pretssel_hifigan_chkpt.py → scripts/convert_mel_hifigan_chkpt.py

@@ -9,10 +9,9 @@ out_channels -> upsample_initial_channel
 """
 
 
-def main():
-    chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
-    cfg = f"{chkpt_root}/config.yml"
-    # TODO: display cfg
+def main() -> None:
+    # chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
+    chkpt_root = "/checkpoint/mjhwang/experiments/231112-mel_vocoder-ai_speech_24khz/train_train_highquality_speech_20231111_no16khz_100000_hifigan.v1_8gpu_adapt"
     chkpt = torch.load(f"{chkpt_root}/checkpoint-400000steps.pkl")
     del chkpt["model"]["discriminator"]
     conv_seq_map = {
@@ -21,7 +20,7 @@ def main():
         ".1.weight_v": ".weight_v",
     }
 
-    def update_key(k):
+    def update_key(k: str) -> str:
         if k.startswith("input_conv"):
             k = k.replace("input_conv", "conv_pre")
         elif k.startswith("upsamples"):
@@ -50,7 +49,8 @@ def main():
     for k in ["optimizer", "scheduler", "steps", "epochs"]:
         del chkpt[k]
 
-    out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
+    # out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
+    out_path = "/large_experiments/seamless/workstream/expressivity/oss/checkpoints/melhifigan_20231121.pt"
     torch.save(chkpt, out_path)
 
 

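The hunk above only shows the two renaming rules; how they combine on a single key is easy to sketch (a hypothetical helper, assuming a plain dict state dict and the `update_key` / `conv_seq_map` defined in the script):

```python
# Hypothetical helper sketching how the two renaming rules combine,
# assuming a plain dict state dict; update_key and conv_seq_map are the
# ones defined in convert_mel_hifigan_chkpt.py.
def rename_state_dict(state_dict: dict) -> dict:
    renamed = {}
    for key, value in state_dict.items():
        new_key = update_key(key)  # e.g. "input_conv.1.weight_v" -> "conv_pre.1.weight_v"
        for old, new in conv_seq_map.items():
            new_key = new_key.replace(old, new)  # "conv_pre.1.weight_v" -> "conv_pre.weight_v"
        renamed[new_key] = value
    return renamed
```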
+ 208 - 0
scripts/watermarking/compile_chkpt.py

@@ -0,0 +1,208 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+#
+# The rules to blend the p2v decoder, mel-vocoder and the watermarking:
+#
+# Step 1) Make the big sequential module `layers` that consists of:
+#    - PostNet (last layer of the p2v decoder) : 5 layers
+#    - mel-vocoder layers (conv_pre, ups, resblocks, conv_post): 18 layers
+#    - watermarking encoder and decoder: 32 layers
+#
+# Step 2) Take the 32 watermarker layers, split them into 4 blocks of
+# 8 layers, and mix these blocks into the preceding layers.
+#
+# The final mixed architecture SPVM (Spaghetti Pretssel Vocoder Model):
+#
+#     <P2V: Post Net>
+#           |
+# <Block 1 of Watermarker> ------
+#           |                   |
+#          \/                   |
+#  <Melvocoder: Conv_pre>       |
+#           | (skip)            |
+# <Block 2 of Watermarker> -----|
+#           |                   |
+#          \/                   |
+# <Melvocoder: Upsampler>       |
+#           | (skip)            |
+# <Block 3 of Watermarker> -----|
+#           |                   |
+#          \/                   |
+# <Melvocoder: Resblocks>       |
+#           | (skip)            |
+# <Block 4 of Watermarker> -----|
+#           |                   |
+#          \/                   |
+#  <Melvocoder: Conv_post>      |
+#           |                   |
+#           | ------------------|
+#           |
+#          \/
+#    watermarked wavs
+
+from pathlib import Path
+from argparse import ArgumentParser
+from typing import Any, Mapping, Match
+
+import torch
+from fairseq2.models.utils.checkpoint import (
+    convert_fairseq_checkpoint,
+    convert_model_state_dict,
+    load_checkpoint,
+)
+
+
+def pretssel_key_map() -> Mapping[str, str]:
+    """
+    The rule for renaming the layers of Pretssel model checkpoint:
+        - Merge decoder.postnet into `layers`
+    """
+    from seamless_communication.models.pretssel.loader import _fairseq_key_map  # noqa
+
+    key_map = _fairseq_key_map(None)  # type: ignore[arg-type]
+    del key_map[r"^decoder\.postnet\."]
+    key_map[r"^decoder\.postnet\.convolutions\."] = r"layers."
+    return key_map
+
+
+def vocoder_key_map() -> Mapping[str, Any]:
+    """
+    Rename layers in the mel-vocoder checkpoint. We flatten the vocoder arch and put everything
+    into `layers`, where the `postnet_size` PostNet layers are already present. In other words:
+        - conv_pre -> layers.<postnet_size + watermark_size / 4>
+        - ups.i -> layers.<postnet_size + 1 + i + watermark_size / 2>
+        - resblocks.i -> layers.<postnet_size + 1 + ups_size + i + 3 * watermark_size / 4>
+        - conv_post -> layers.<postnet_size + 1 + ups_size + resblocks_size + watermark_size>
+    """
+
+    return {
+        # fmt: off
+        # postnet_size = 5, 1st wm block = 8 -> 13
+        r"^conv_pre\.":               r"layers.13.",                                 # noqa, 2nd wm block = 8 -> +8
+        r"^ups\.([0-9]+)\.":          lambda x: f"layers.{int(x.group(1)) + 22}.",   # noqa, ups_size = 4, 3rd wm block = 8 -> +12
+        r"^resblocks\.([0-9]+)\.":    lambda x: f"layers.{int(x.group(1)) + 34}.",   # noqa, resblocks_size = 12, 4th wm block = 8 -> +20
+        r"^conv_post\.":              r"layers.54.",
+        # fmt: on
+    }
+
+
+def wm_key_map() -> Mapping[Any, Any]:
+    """
+    Flatten the watermarker encoder and decoder into the same sequential `layers` (step 1), split the
+    watermarker into 4 blocks and mix them into the layers of the p2v decoder and mel-vocoder:
+        - encoder.model.[0-7] --> layers.<postnet_size + i> (5 --> 12)
+        - encoder.model.[8-15] --> layers.<postnet_size + conv_pre_size + i> (14 --> 21)
+        - decoder.model.[0-7] --> layers.<postnet_size + conv_pre_size + encoder_size + ups_size + i> (26 --> 33)
+        - decoder.model.[8-15] --> layers.<postnet_size + conv_pre_size + encoder_size + ups_size + resblocks_size + i> (46 --> 53)
+    """
+
+    def encoder_layer_index(match_obj: Match[str]) -> str:
+        idx = int(match_obj.group(1))
+        # First half of the encoder is after the PostNet
+        if idx < 8:
+            # postnet_size = 5
+            return f"layers.{idx + 5}."
+
+        # Second half of the encoder goes after the mel-vocoder:conv_pre
+        else:
+            # postnet = 5, conv_pre = 1 --> +6
+            return f"layers.{idx + 6}."
+
+    def decoder_layer_index(match_obj: Match[str]) -> str:
+        idx = int(match_obj.group(1))
+        # First half of the decoder is after the mel-vocoder:ups
+        if idx < 8:
+            # postnet 5, conv_pre 1, encoder 16, ups 4 --> +26
+            return f"layers.{idx + 26}."
+        else:
+            # postnet 5, conv_pre 1, encoder 16, ups 4, resblock 12 -> +38
+            return f"layers.{idx + 38}."
+
+    return {
+        r"^encoder\.model\.([0-9]+)\.": encoder_layer_index,
+        r"^decoder\.model\.([0-9]+)\.": decoder_layer_index,
+    }
+
+
+def combine_chkpts(pretssel_file: str, vocoder_file: str, out_path: str) -> None:
+    """Combine the pretssel and melhifigan into one model"""
+    pretssel_chkpt = load_checkpoint(pretssel_file)
+    pretssel_chkpt = convert_fairseq_checkpoint(pretssel_chkpt, pretssel_key_map())
+
+    vocoder_chkpt = load_checkpoint(vocoder_file)
+    vocoder_chkpt = convert_fairseq_checkpoint(vocoder_chkpt, vocoder_key_map())
+
+    wm_ckpt = load_checkpoint(
+        "/large_experiments/seamless/nllb/watermarking/checkpoints/ckpt_e9d0008c.th",
+    )
+    # wm_ckpt is not a fairseq2 checkpoint so we have to handle it differently
+    wm_ckpt = convert_model_state_dict(wm_ckpt, wm_key_map())
+
+    # Merge the state dicts
+    ckpt = pretssel_chkpt
+    state_dict = ckpt["model"]
+    for vocoder_key in vocoder_chkpt["model"]:
+        state_dict[vocoder_key] = vocoder_chkpt["model"][vocoder_key]
+
+    for wm_key, wm_val in wm_ckpt.items():
+        if wm_key.startswith("layers."):
+            state_dict[wm_key] = wm_val
+
+    # Remove obsolete layers
+    keys_to_delete = [
+        "encoder.embed_positions._float_tensor",
+        "decoder.embed_positions._float_tensor",
+        "enc_emb_proj.weight",
+        "enc_emb_proj.bias",
+    ]
+    keys_to_delete.extend(
+        [
+            key
+            for key in state_dict
+            if key.startswith("decoder.var_adaptor.duration_predictor")
+        ]
+    )
+    for key in keys_to_delete:
+        if key in state_dict:
+            del state_dict[key]
+
+    out_path = "/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt"
+    model_mapping_metafile = Path(out_path).with_suffix(".arch")
+    with open(model_mapping_metafile, "w", encoding="utf-8") as o:
+        o.write(vocoder_key_map.__doc__)  # type: ignore
+        o.write("\n")
+        o.write(wm_key_map.__doc__)  # type: ignore
+        o.write("\n")
+    torch.save(ckpt, out_path)
+
+
+if __name__ == "__main__":
+    # fmt: off
+    parser = ArgumentParser(description="Compile watermarking into p2v decoder and vocoder")
+    parser.add_argument(
+        "--pretssel",
+        default="/checkpoint/mjhwang/experiments/230930-noiseaug_p2v-mls_multilingual_6lang/231005-noiseaug_p2v-mls_multilingual_6lang-alignfix.config_v2.langemb1.vuv_logit1.denoise.ngpu16/checkpoint_best.pt",
+        type=str,
+        help="Path to the Pretssel model checkpoint",
+    )
+    parser.add_argument(
+        "--vocoder",
+        # default="/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt",
+        default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/melhifigan_20231121.pt",
+        type=str,
+        help="Path to the mel-vocoder checkpoint",
+    )
+    parser.add_argument(
+        "--output",
+        default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt",
+        # default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-20231121.pt",
+        type=str,
+        help="Path to the output watermarked model checkpoint",
+    )
+    # fmt: on
+    args = parser.parse_args()
+    combine_chkpts(args.pretssel, args.vocoder, args.output)

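The three key maps above encode one flat indexing scheme; laying the blocks out in order makes the arithmetic checkable (a standalone sketch, not part of the script):

```python
# Sanity check for the mixed layout described in the header comment:
# lay the blocks out in order and print the index range of each.
blocks = [
    ("postnet", 5),
    ("wm_encoder[0:8]", 8),
    ("conv_pre", 1),
    ("wm_encoder[8:16]", 8),
    ("ups", 4),
    ("wm_decoder[0:8]", 8),
    ("resblocks", 12),
    ("wm_decoder[8:16]", 8),
    ("conv_post", 1),
]
start = 0
for name, size in blocks:
    print(f"{name}: layers.{start}..layers.{start + size - 1}")
    start += size
# conv_pre lands at layers.13, ups at layers.22..25, resblocks at
# layers.34..45 and conv_post at layers.54, matching the key maps above.
```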
+ 44 - 0
scripts/watermarking/seamlesswatermark.yaml

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+name: seamlesswatermark
+model_type: seanet
+checkpoint: "/large_experiments/seamless/nllb/watermarking/checkpoints/ckpt_e9d0008c.th"
+watermarker_model:
+  channels: 1
+  sample_rate: 16000
+seanet:
+  activation: ELU
+  activation_params:
+    alpha: 1.0
+  causal: false
+  channels: 1
+  compress: 2
+  decoder:
+    final_activation: null
+    final_activation_params: null
+    trim_right_ratio: 1.0
+  detector: {}
+  dilation_base: 2
+  dimension: 128
+  disable_norm_outer_blocks: 0
+  encoder: {}
+  kernel_size: 7
+  last_kernel_size: 7
+  lstm: 2
+  n_filters: 32
+  n_residual_layers: 1
+  norm: weight_norm
+  norm_params: {}
+  pad_mode: constant
+  ratios:
+  - 8
+  - 5
+  - 4
+  - 2
+  residual_kernel_size: 3
+  true_skip: true

+ 236 - 0
scripts/watermarking/watermarking.py

@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# The original implementation for the watermarking
+# This is not open-sourced and only kept here for future reference
+# mypy: ignore-errors
+
+import math
+from argparse import ArgumentParser, ArgumentTypeError
+from pathlib import Path
+from typing import Any, Dict, Union, cast
+
+import audiocraft
+import omegaconf
+import torch
+import torch.nn as nn
+import torchaudio
+from audiocraft.modules.seanet import SEANetEncoder
+from audiocraft.utils.utils import dict_from_config
+from fairseq2.typing import DataType, Device
+
+
+class SEANetEncoderKeepDimension(SEANetEncoder):
+    """
+    Same architecture as the SEANet encoder, with an extra transposed
+    convolution that projects the output back to the input frame rate,
+    so the output has the same number of frames as the input.
+
+    Args:
+        output_hidden_dim (int): Number of channels of the projected output.
+    """
+
+    def __init__(self, output_hidden_dim: int = 8, *args, **kwargs):  # type: ignore
+        super().__init__(*args, **kwargs)
+        self.output_hidden_dim = output_hidden_dim
+        # Adding a reverse convolution layer
+        self.reverse_convolution = nn.ConvTranspose1d(
+            in_channels=self.dimension,
+            out_channels=self.output_hidden_dim,
+            kernel_size=math.prod(self.ratios),
+            stride=math.prod(self.ratios),
+            padding=0,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_nframes = x.shape[-1]
+        x = self.model(x)
+        x = self.reverse_convolution(x)
+        # make sure dim didn't change
+        x = x[:, :, :orig_nframes]
+        return x
+
+
+class Watermarker(nn.Module):
+    """
+    Initialize the Watermarker model.
+
+    Args:
+        encoder (nn.Module): Watermark Encoder.
+        decoder (nn.Module): Watermark Decoder.
+        detector (nn.Module): Watermark Detector.
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+    """
+
+    sample_rate: int = 0
+    channels: int = 0
+    encoder: SEANetEncoder
+    decoder: SEANetEncoder
+    detector: SEANetEncoderKeepDimension
+
+    def __init__(
+        self,
+        encoder: SEANetEncoder,
+        decoder: SEANetEncoder,
+        detector: SEANetEncoderKeepDimension,
+        sample_rate: int,
+        channels: int,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.detector = detector
+        self.sample_rate = sample_rate
+        self.channels = channels
+
+    def get_watermark(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Get the watermark from a batch of audio input.
+
+        Args:
+            x (torch.Tensor): Input audio tensor with dimensions [batch size, channels = 1, frames].
+
+        Returns:
+            torch.Tensor: Output watermark with the same dimensionality as the input.
+        """
+        hidden = self.encoder(x)
+        # assert dim in = dim out
+        watermark = self.decoder(hidden)[:, :, : x.size(-1)]
+        return watermark
+
+    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Detect the watermark in a batch of audio input.
+
+        Args:
+            x (torch.Tensor): Input audio tensor with dimensions
+            [batch size, channels = 1, frames].
+
+        Returns:
+            torch.Tensor: Predictions of the classifier for watermark
+            with dimensions [bsz, classes = 2, frames].
+            For each frame, the detector outputs probabilities of
+            non-watermarked class (class id 0) and
+            the probability of "watermarked" class (class id 1).
+            To do inference, you can use output[:, 1, :]
+            to get probabilities of input audio being watermarked.
+        """
+        return self.detector(x)
+
+
+def model_from_checkpoint(
+    checkpoint_path: Union[Path, str] = Path(__file__).parent
+    / "seamlesswatermark.yaml",
+    device: Union[torch.device, str] = "cpu",
+    dtype: DataType = torch.float32,
+) -> Watermarker:
+    """Instantiate a Watermarker model from a given checkpoint path.
+
+    Example usage:
+    >>> from watermarking.watermarking import *
+    >>> cfg = "seamlesswatermark.yaml"
+    >>> url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
+    >>> urllib.request.urlretrieve(url, "random.wav")
+    >>> wav, _ = torchaudio.load("random.wav")
+    >>> wav = wav.unsqueeze(0)  # add bsz dimension
+
+    # code starts here
+    >>> model = model_from_checkpoint(cfg, device = wav.device)
+
+    >>> watermark = model.get_watermark(wav)
+
+    >>> watermarked_audio = wav + watermark
+    >>> detection = model.detect_watermark(watermarked_audio)
+    >>> print(detection[:,1,:])  # print prob of watermarked class # should be > 0.5
+
+    >>> detection = model.detect_watermark(wav)
+    >>> print(detection[:,1,:])  # print prob of watermarked class  # should be < 0.5
+
+    Args:
+        checkpoint_path (Path or str): Path to the yaml config that points to the checkpoint file.
+        device (torch.device or str, optional): Device on which
+        the model is loaded (default is "cpu").
+
+    Returns:
+        Watermarker: An instance of the Watermarker model loaded from the checkpoint.
+    """
+    cfg = omegaconf.OmegaConf.load(checkpoint_path)
+    state: Dict[str, Any] = torch.load(cfg["checkpoint"])
+    watermarking_model = get_watermarking_model(cfg)
+    watermarking_model.load_state_dict(state)
+    watermarking_model = watermarking_model.to(device, dtype=dtype)
+    watermarking_model.eval()
+    return watermarking_model
+
+
+def get_watermarking_model(cfg: omegaconf.DictConfig) -> Watermarker:
+    kwargs = dict_from_config(getattr(cfg, "watermarker_model"))
+    encoder, decoder = get_encodec_autoencoder(cfg)
+    detector = get_detector(cfg)
+    return Watermarker(encoder, decoder, detector, **kwargs)
+
+
+def get_encodec_autoencoder(cfg: omegaconf.DictConfig):
+    kwargs = dict_from_config(getattr(cfg, "seanet"))
+    if hasattr(cfg.seanet, "detector"):
+        kwargs.pop("detector")
+    encoder_override_kwargs = kwargs.pop("encoder")
+    decoder_override_kwargs = kwargs.pop("decoder")
+    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+    encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+    decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+    return encoder, decoder
+
+
+def get_detector(cfg: omegaconf.DictConfig):
+    kwargs = dict_from_config(getattr(cfg, "seanet"))
+    encoder_override_kwargs = kwargs.pop("detector")
+    kwargs.pop("decoder")
+    kwargs.pop("encoder")
+    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+    output_hidden_dim = 8
+    encoder = SEANetEncoderKeepDimension(output_hidden_dim, **encoder_kwargs)
+
+    last_layer = torch.nn.Conv1d(output_hidden_dim, 2, 1)
+    softmax = torch.nn.Softmax(dim=1)
+    detector = torch.nn.Sequential(encoder, last_layer, softmax)
+    return detector
+
+
+def parse_device_arg(value: str) -> Device:
+    try:
+        return Device(value)
+    except RuntimeError:
+        raise ArgumentTypeError(f"'{value}' is not a valid device name.")
+
+
+if __name__ == "__main__":
+    """
+    Example usage:
+    python watermarking.py --device cuda:0 detect [file.wav]
+    """
+    parser = ArgumentParser(description="Handle the watermarking for audios")
+    parser.add_argument(
+        "--device",
+        default="cpu",
+        type=parse_device_arg,
+        help="device on which to run tests (default: %(default)s)",
+    )
+    sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
+    detect_parser = sub_parser.add_parser("detect")
+    wm_parser = sub_parser.add_parser("wm")
+    parser.add_argument("file", type=str, help="Path to the .wav file")
+
+    args = parser.parse_args()
+
+    if args.sub_cmd == "detect":
+        model = model_from_checkpoint(device=args.device)
+        wav, _ = torchaudio.load(args.file)
+        wav = wav.unsqueeze(0)
+        wav = wav.to(args.device)
+        detection = model.detect_watermark(wav)
+        print(detection[:, 1, :])

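Building on the `detect_watermark` docstring, a clip-level decision over the frame-wise probabilities could look like this (a sketch with a hypothetical input file and a 0.5 threshold, not part of the script):

```python
import torchaudio

# Load the watermarker from the default yaml next to the script.
model = model_from_checkpoint(device="cpu")

wav, _ = torchaudio.load("random.wav")  # hypothetical input file
wav = wav.unsqueeze(0)                  # [bsz=1, channels=1, frames]

detection = model.detect_watermark(wav)  # [bsz, classes=2, frames]
frame_probs = detection[:, 1, :]         # per-frame "watermarked" probability
is_watermarked = frame_probs.mean(dim=-1) > 0.5  # simple clip-level vote
print(is_watermarked)
```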
+ 182 - 0
src/seamless_communication/cards/vocoder_pretssel.yaml

@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: vocoder_pretssel
+model_type: vocoder_pretssel
+model_arch: 24khz
+checkpoint: "file:///large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt"
+sample_rate: 24000
+model_config:
+  langs:
+    - cmn
+    - deu
+    - eng
+    - fra
+    - ita
+    - spa
+  gcmvn_stats:
+    mean:
+      - 9.023406257490224
+      - 9.406622923058864
+      - 10.554165334059368
+      - 11.475190058682356
+      - 12.179117104099705
+      - 12.603782921407062
+      - 12.769632747861747
+      - 12.714276772934083
+      - 12.747612172560233
+      - 12.750373688097946
+      - 12.948050207790237
+      - 13.121829398704277
+      - 13.40130828476734
+      - 13.58028050886195
+      - 13.601835409305883
+      - 13.608734047373218
+      - 13.538274892335826
+      - 13.391518457210937
+      - 13.382843811359622
+      - 13.0524299456858
+      - 12.785193828396269
+      - 12.876608812372632
+      - 12.59571918874957
+      - 12.674484745567813
+      - 12.57325195345546
+      - 12.651938120109422
+      - 12.556821722150424
+      - 12.639338348530158
+      - 12.610449431411217
+      - 12.639992872912376
+      - 12.697503827987052
+      - 12.754788270377214
+      - 12.837605043617405
+      - 12.964379088501497
+      - 13.11997048142582
+      - 13.267395589173432
+      - 13.384668687260483
+      - 13.495000208959356
+      - 13.606835320307384
+      - 13.578073476073252
+      - 13.689796531497368
+      - 13.643079802391588
+      - 13.7340755472615
+      - 13.735199777666043
+      - 13.79347692248429
+      - 13.875183654243305
+      - 13.967272256671393
+      - 14.058507936754117
+      - 14.114704594203507
+      - 14.156211337193277
+      - 14.14747081594401
+      - 14.173917097974343
+      - 14.22330474758318
+      - 14.251272943225572
+      - 14.230904505178053
+      - 14.226937644205396
+      - 14.222223350670225
+      - 14.211638354996317
+      - 14.208930098405544
+      - 14.19476983404041
+      - 14.2195925729048
+      - 14.16490878238837
+      - 14.115436751205117
+      - 14.039442767347872
+      - 13.976934063901625
+      - 13.917068116556464
+      - 13.856293662219073
+      - 13.773769842100085
+      - 13.706245521082796
+      - 13.685052933361192
+      - 13.68570131643094
+      - 13.714811890011152
+      - 13.751451253935347
+      - 13.772212258132148
+      - 13.76013448427468
+      - 13.702368406557508
+      - 13.600406368803617
+      - 13.369574889658164
+      - 12.998399608309988
+      - 12.443732902848723
+    std:
+      - 3.729248515707457
+      - 4.001623098079929
+      - 4.570009061358065
+      - 4.811572361201577
+      - 5.010239923828185
+      - 5.152145212706857
+      - 5.223885876119451
+      - 5.224443623432338
+      - 5.161790275239061
+      - 5.098988232815804
+      - 5.090890035509122
+      - 5.130345212529546
+      - 5.165849688173366
+      - 5.164761699263693
+      - 5.131177988219367
+      - 5.085522051815558
+      - 5.035829108165894
+      - 4.987478975310455
+      - 4.932652442855969
+      - 4.8650037198748075
+      - 4.799238163232527
+      - 4.727086345775988
+      - 4.646858066575789
+      - 4.5733249959652715
+      - 4.51685060334288
+      - 4.467449073425149
+      - 4.4296881304192075
+      - 4.4028775449713775
+      - 4.397905653025904
+      - 4.3862594566308015
+      - 4.366485847923521
+      - 4.344483498393771
+      - 4.324692736391383
+      - 4.310481738978154
+      - 4.3053492473916
+      - 4.3035205126659655
+      - 4.2987898577000605
+      - 4.287403454800855
+      - 4.27087296372773
+      - 4.25387490294079
+      - 4.233513102251301
+      - 4.212047255068752
+      - 4.1810370158214445
+      - 4.186014591107853
+      - 4.194806047136222
+      - 4.2183377208747075
+      - 4.249293562464735
+      - 4.268847210561774
+      - 4.270455756367186
+      - 4.25811368227528
+      - 4.245975115347766
+      - 4.23058010369271
+      - 4.203075111087773
+      - 4.20123812057283
+      - 4.187143614375688
+      - 4.172633823274146
+      - 4.162541203161947
+      - 4.156022884601996
+      - 4.1618428838805706
+      - 4.157259439238067
+      - 4.139859013016601
+      - 4.150685014911159
+      - 4.152025499126372
+      - 4.165010788120131
+      - 4.15179422331336
+      - 4.137041631098819
+      - 4.10861757770052
+      - 4.119916019361405
+      - 4.131749366642117
+      - 4.119438578634397
+      - 4.100095269698108
+      - 4.073900009963118
+      - 4.0580796715728855
+      - 4.050916705279105
+      - 4.037976834115189
+      - 4.023757063156459
+      - 3.9987849927993353
+      - 3.989251079820668
+      - 3.9464430977885256
+      - 3.8673932921278995

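These per-bin statistics are consumed downstream via `load_gcmvn_stats`, which reads them back from the card so the fbank features can be normalized before the prosody encoder; a minimal sketch mirroring the usage in `pretssel_inference.py` below:

```python
import torch

from seamless_communication.models.unity import load_gcmvn_stats

# Read the mean/std lists from the vocoder_pretssel card above.
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats("vocoder_pretssel")
gcmvn_mean = torch.tensor(_gcmvn_mean)  # 80 values, matching num_mel_bins=80
gcmvn_std = torch.tensor(_gcmvn_std)

fbank = torch.randn(100, 80)  # dummy [frames, mel-bins] features
gcmvn_fbank = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
```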
+ 182 - 0
src/seamless_communication/cards/vocoder_pretssel_16khz.yaml

@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: vocoder_pretssel_16khz
+model_type: vocoder_pretssel
+model_arch: 16khz
+checkpoint: "file:///large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-16khz.pt"
+sample_rate: 16000
+model_config:
+  langs:
+    - cmn
+    - deu
+    - eng
+    - fra
+    - ita
+    - spa
+  gcmvn_stats:
+    mean:
+      - 9.023406257490224
+      - 9.406622923058864
+      - 10.554165334059368
+      - 11.475190058682356
+      - 12.179117104099705
+      - 12.603782921407062
+      - 12.769632747861747
+      - 12.714276772934083
+      - 12.747612172560233
+      - 12.750373688097946
+      - 12.948050207790237
+      - 13.121829398704277
+      - 13.40130828476734
+      - 13.58028050886195
+      - 13.601835409305883
+      - 13.608734047373218
+      - 13.538274892335826
+      - 13.391518457210937
+      - 13.382843811359622
+      - 13.0524299456858
+      - 12.785193828396269
+      - 12.876608812372632
+      - 12.59571918874957
+      - 12.674484745567813
+      - 12.57325195345546
+      - 12.651938120109422
+      - 12.556821722150424
+      - 12.639338348530158
+      - 12.610449431411217
+      - 12.639992872912376
+      - 12.697503827987052
+      - 12.754788270377214
+      - 12.837605043617405
+      - 12.964379088501497
+      - 13.11997048142582
+      - 13.267395589173432
+      - 13.384668687260483
+      - 13.495000208959356
+      - 13.606835320307384
+      - 13.578073476073252
+      - 13.689796531497368
+      - 13.643079802391588
+      - 13.7340755472615
+      - 13.735199777666043
+      - 13.79347692248429
+      - 13.875183654243305
+      - 13.967272256671393
+      - 14.058507936754117
+      - 14.114704594203507
+      - 14.156211337193277
+      - 14.14747081594401
+      - 14.173917097974343
+      - 14.22330474758318
+      - 14.251272943225572
+      - 14.230904505178053
+      - 14.226937644205396
+      - 14.222223350670225
+      - 14.211638354996317
+      - 14.208930098405544
+      - 14.19476983404041
+      - 14.2195925729048
+      - 14.16490878238837
+      - 14.115436751205117
+      - 14.039442767347872
+      - 13.976934063901625
+      - 13.917068116556464
+      - 13.856293662219073
+      - 13.773769842100085
+      - 13.706245521082796
+      - 13.685052933361192
+      - 13.68570131643094
+      - 13.714811890011152
+      - 13.751451253935347
+      - 13.772212258132148
+      - 13.76013448427468
+      - 13.702368406557508
+      - 13.600406368803617
+      - 13.369574889658164
+      - 12.998399608309988
+      - 12.443732902848723
+    std:
+      - 3.729248515707457
+      - 4.001623098079929
+      - 4.570009061358065
+      - 4.811572361201577
+      - 5.010239923828185
+      - 5.152145212706857
+      - 5.223885876119451
+      - 5.224443623432338
+      - 5.161790275239061
+      - 5.098988232815804
+      - 5.090890035509122
+      - 5.130345212529546
+      - 5.165849688173366
+      - 5.164761699263693
+      - 5.131177988219367
+      - 5.085522051815558
+      - 5.035829108165894
+      - 4.987478975310455
+      - 4.932652442855969
+      - 4.8650037198748075
+      - 4.799238163232527
+      - 4.727086345775988
+      - 4.646858066575789
+      - 4.5733249959652715
+      - 4.51685060334288
+      - 4.467449073425149
+      - 4.4296881304192075
+      - 4.4028775449713775
+      - 4.397905653025904
+      - 4.3862594566308015
+      - 4.366485847923521
+      - 4.344483498393771
+      - 4.324692736391383
+      - 4.310481738978154
+      - 4.3053492473916
+      - 4.3035205126659655
+      - 4.2987898577000605
+      - 4.287403454800855
+      - 4.27087296372773
+      - 4.25387490294079
+      - 4.233513102251301
+      - 4.212047255068752
+      - 4.1810370158214445
+      - 4.186014591107853
+      - 4.194806047136222
+      - 4.2183377208747075
+      - 4.249293562464735
+      - 4.268847210561774
+      - 4.270455756367186
+      - 4.25811368227528
+      - 4.245975115347766
+      - 4.23058010369271
+      - 4.203075111087773
+      - 4.20123812057283
+      - 4.187143614375688
+      - 4.172633823274146
+      - 4.162541203161947
+      - 4.156022884601996
+      - 4.1618428838805706
+      - 4.157259439238067
+      - 4.139859013016601
+      - 4.150685014911159
+      - 4.152025499126372
+      - 4.165010788120131
+      - 4.15179422331336
+      - 4.137041631098819
+      - 4.10861757770052
+      - 4.119916019361405
+      - 4.131749366642117
+      - 4.119438578634397
+      - 4.100095269698108
+      - 4.073900009963118
+      - 4.0580796715728855
+      - 4.050916705279105
+      - 4.037976834115189
+      - 4.023757063156459
+      - 3.9987849927993353
+      - 3.989251079820668
+      - 3.9464430977885256
+      - 3.8673932921278995

+ 1 - 1
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -385,7 +385,7 @@ def main() -> None:
         text_generation_opts=text_generation_opts,
         unit_generation_opts=unit_generation_opts,
         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-        output_path=Path(args.output_path),
+        output_path=args.output_path,
         gcmvn_mean=torch.tensor(gcmvn_mean, device=device, dtype=dtype),
         gcmvn_std=torch.tensor(gcmvn_std, device=device, dtype=dtype),
         pretssel_model=args.pretssel_model,

+ 373 - 0
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -0,0 +1,373 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import logging
+from argparse import Namespace
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import Module
+import torchaudio
+from fairseq2.assets import asset_store
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import Collater, DataPipeline, FileMapper, SequenceData
+from fairseq2.data.audio import (
+    AudioDecoder,
+    WaveformToFbankConverter,
+    WaveformToFbankOutput,
+)
+from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
+from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.typing import DataType, Device
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
+from torch import Tensor
+from tqdm import tqdm
+
+from seamless_communication.models.unity import UnitTokenizer
+from seamless_communication.cli.m4t.evaluate.evaluate import (
+    adjust_output_for_corrupted_inputs,
+    count_lines,
+)
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import BatchedSpeechOutput, Translator
+from seamless_communication.models.unity import (
+    load_gcmvn_stats,
+    load_unity_text_tokenizer,
+    load_unity_unit_tokenizer,
+)
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+class PretsselGenerator(Module):
+    def __init__(
+        self,
+        pretssel_name_or_card: Union[str, AssetCard],
+        unit_tokenizer: UnitTokenizer,
+        device: Device,
+        dtype: DataType = torch.float16,
+    ):
+        super().__init__()
+        # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
+
+        self.device = device
+        self.dtype = dtype
+
+        self.pretssel_model = load_pretssel_vocoder_model(
+            pretssel_name_or_card,
+            device=device,
+            dtype=dtype,
+        )
+        self.pretssel_model.eval()
+
+        vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
+        self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+
+        self.unit_tokenizer = unit_tokenizer
+        self.unit_collate = Collater(pad_value=unit_tokenizer.vocab_info.pad_idx)
+        self.duration_collate = Collater(pad_value=0)
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        units: List[List[int]],
+        tgt_lang: str,
+        prosody_encoder_input: SequenceData,
+    ) -> BatchedSpeechOutput:
+        audio_wavs = []
+        unit_eos_token = torch.tensor(
+            [self.unit_tokenizer.vocab_info.eos_idx],
+            device=self.device,
+        )
+
+        prosody_input_seqs = prosody_encoder_input["seqs"]
+        prosody_input_lens = prosody_encoder_input["seq_lens"]
+
+        for i, u in enumerate(units):
+            unit = torch.tensor(u).to(unit_eos_token)
+
+            # adjust the control symbols for the embedding
+            unit += 4
+            unit = torch.cat([unit, unit_eos_token], dim=0)
+
+            unit, duration = torch.unique_consecutive(unit, return_counts=True)
+
+            # adjust for the last eos token
+            duration[-1] = 0
+
+            duration *= 2
+
+            prosody_input_seq = prosody_input_seqs[i][:prosody_input_lens[i]]
+
+            audio_wav = self.pretssel_model(
+                unit,
+                tgt_lang,
+                prosody_input_seq,
+                durations=duration.unsqueeze(0),
+            )
+
+            audio_wavs.append(audio_wav)
+
+        return BatchedSpeechOutput(
+            units=units,
+            audio_wavs=audio_wavs,
+            sample_rate=self.output_sample_rate,
+        )
+
+
+def build_data_pipeline(
+    args: Namespace,
+    text_tokenizer: TextTokenizer,
+    device: Device,
+    dtype: DataType,
+    gcmvn_mean: Tensor,
+    gcmvn_std: Tensor,
+) -> DataPipeline:
+    with open(args.data_file, "r") as f:
+        header = f.readline().strip("\n").split("\t")
+
+    n_parallel = 4
+
+    split_tsv = StrSplitter(names=header)
+
+    pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)
+
+    assert args.audio_root_dir is not None
+
+    map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)
+
+    pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+
+    decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
+    convert_to_fbank = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=False,
+        device=device,
+        dtype=dtype,
+    )
+
+    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
+        fbank = data["fbank"]
+        std, mean = torch.std_mean(fbank, dim=0)
+        data["fbank"] = fbank.subtract(mean).divide(std)
+        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+        return data
+
+    pipeline_builder.map(
+        [decode_audio, convert_to_fbank, normalize_fbank],
+        selector="audio.data",
+        num_parallel_calls=n_parallel,
+    )
+
+    pipeline_builder.bucket(bucket_size=args.batch_size)
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Running PretsselModel inference")
+    parser.add_argument("data_file", type=Path, help="Data file (.tsv) to be evaluated.")
+
+    parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=Path,
+        help="Root directory for the audio filenames in the data file.",
+        default="",
+    )
+    parser.add_argument(
+        "--ref_field",
+        type=str,
+        help="Reference target text field to compute the BLEU score against.",
+        default="tgt_text",
+    )
+    parser.add_argument(
+        "--duration_factor",
+        type=float,
+        help="The duration factor for NAR T2U model. Expressivity model uses 1.1",
+        default=1.1,
+    )
+    args = parser.parse_args()
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
+
+    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
+    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
+    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
+
+    pipeline = build_data_pipeline(
+        args, text_tokenizer, device, dtype, gcmvn_mean, gcmvn_std
+    )
+
+    translator = Translator(
+        args.model_name,
+        vocoder_name_or_card=None,
+        device=device,
+        text_tokenizer=text_tokenizer,
+        dtype=dtype,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    pretssel_generator = PretsselGenerator(
+        args.vocoder_name,
+        unit_tokenizer=unit_tokenizer,
+        device=device,
+        dtype=dtype,
+    )
+
+    total_steps = count_lines(args.data_file) - 1
+    progress_bar = tqdm(total=total_steps)
+
+    output_path = args.output_path / args.data_file.stem
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    waveforms_dir = output_path / "waveform"
+    waveforms_dir.mkdir(parents=True, exist_ok=True)
+
+    hyps = []
+    refs = []
+
+    with contextlib.ExitStack() as stack:
+        hyp_file = stack.enter_context(
+            open(output_path / f"text_output-{args.data_file.stem}.txt", "w")
+        )
+        unit_file = stack.enter_context(
+            open(output_path / f"unit_output-{args.data_file.stem}.txt", "w")
+        )
+
+        sample_id = 0
+        for example in pipeline:
+            valid_sequences: Optional[Tensor] = None
+            src = example["audio"]["data"]["fbank"]
+            # Skip corrupted audio tensors.
+            valid_sequences = ~torch.any(
+                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
+            )
+            if not valid_sequences.all():
+                logger.warning(
+                    f"Sample IDs {sample_id} to {sample_id + args.batch_size} has some corrupted input."
+                )
+                src["seqs"] = src["seqs"][valid_sequences]
+                src["seq_lens"] = src["seq_lens"][valid_sequences]
+
+            # Skip performing inference when the input is entirely corrupted.
+            if src["seqs"].numel() > 0:
+                prosody_encoder_input = example["audio"]["data"]["gcmvn_fbank"]
+                text_output, unit_output = translator.predict(
+                    src,
+                    args.task,
+                    args.tgt_lang,
+                    src_lang=args.src_lang,
+                    text_generation_opts=text_generation_opts,
+                    unit_generation_opts=unit_generation_opts,
+                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+                    duration_factor=args.duration_factor,
+                    prosody_encoder_input=prosody_encoder_input,
+                )
+
+                assert unit_output is not None
+                speech_output = pretssel_generator.predict(
+                    unit_output.units,
+                    tgt_lang=args.tgt_lang,
+                    prosody_encoder_input=prosody_encoder_input,
+                )
+
+            else:
+                text_output = []
+                speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+
+            if valid_sequences is not None and not valid_sequences.all():
+                text_output, speech_output = adjust_output_for_corrupted_inputs(  # type: ignore[assignment]
+                    valid_sequences,
+                    text_output,
+                    speech_output,
+                )
+
+            hyps += [str(s) for s in text_output]
+            refs += [str(s) for s in example[args.ref_field]]
+
+            for i in range(len(text_output)):
+                t = text_output[i]
+                idx = str(example["id"][i])
+                hyp_file.write(f"{t}\n")
+
+                u = speech_output.units[i]
+                str_units = [str(i) for i in u]
+                unit_file.write(" ".join(str_units) + "\n")
+                torchaudio.save(
+                    waveforms_dir / f"{idx}_pred.wav",
+                    speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                    sample_rate=speech_output.sample_rate,
+                )
+
+                sample_id += 1
+                progress_bar.update(1)
+
+    progress_bar.close()
+    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
+
+    assert len(hyps) == len(refs)
+    if len(hyps) > 0:
+        if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):
+            tokenizer = "char"
+        else:
+            tokenizer = "13a"
+
+        bleu = BLEU(tokenize=tokenizer)
+        score = bleu.corpus_score(hyps, [refs])
+        bleu_filename = output_path / f"{args.data_file.stem}_text_output_bleu.json"
+        with open(bleu_filename, "w") as f:
+            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
+        logger.info(score.format(signature=bleu.get_signature()))
+
+
+if __name__ == "__main__":
+    main()

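The unit preprocessing in `PretsselGenerator.predict` above is easy to trace on a toy sequence (a sketch with dummy unit ids; eos_idx=2 matches the `VocabularyInfo` in the builder):

```python
import torch

units = [0, 0, 0, 17, 17, 42]  # dummy unit ids
eos_idx = 2                    # unit_tokenizer.vocab_info.eos_idx

unit = torch.tensor(units) + 4                     # shift past the control symbols
unit = torch.cat([unit, torch.tensor([eos_idx])])  # append EOS
unit, duration = torch.unique_consecutive(unit, return_counts=True)
duration[-1] = 0                                   # zero out the EOS duration
duration *= 2                                      # double durations, as in predict()
print(unit)      # tensor([ 4, 21, 46,  2])
print(duration)  # tensor([6, 4, 2, 0])
```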
+ 1 - 1
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -424,7 +424,7 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         text_generation_opts=text_generation_opts,
         unit_generation_opts=unit_generation_opts,
         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-        output_path=Path(args.output_path),
+        output_path=args.output_path,
     )
     # fmt: on
     logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")

+ 3 - 2
src/seamless_communication/cli/m4t/predict/predict.py

@@ -6,6 +6,7 @@
 import argparse
 import logging
 from argparse import Namespace
+from pathlib import Path
 from typing import Tuple
 
 import torch
@@ -35,7 +36,7 @@ def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.Argumen
     )
     parser.add_argument(
         "--output_path",
-        type=str,
+        type=Path,
         help="Path to save the generated audio.",
         default=None,
     )
@@ -167,7 +168,7 @@ def set_generation_opts(
     return text_generation_opts, unit_generation_opts
 
 
-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="M4T inference on supported tasks using Translator."
     )

+ 5 - 0
src/seamless_communication/models/generator/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 506 - 0
src/seamless_communication/models/generator/builder.py

@@ -0,0 +1,506 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
+from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
+from fairseq2.nn.projection import Linear
+from fairseq2.nn.transformer import (
+    MultiheadAttention,
+    StandardMultiheadAttention,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+from torch.nn import Conv1d
+
+from seamless_communication.models.generator.ecapa_tdnn_builder import (
+    EcapaTDNNBuilder,
+    EcapaTDNNConfig,
+    ecapa_tdnn_archs,
+)
+from seamless_communication.models.generator.vocoder import (
+    PretsselDecoderFrontend,
+    PretsselEncoderFrontend,
+    PretsselVocoder,
+)
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+from seamless_communication.models.unity.fft_decoder_layer import (
+    Conv1dBlock,
+    FeedForwardTransformerLayer,
+)
+from seamless_communication.models.unity.length_regulator import (
+    VarianceAdaptor,
+    VariancePredictor,
+)
+from seamless_communication.models.unity.t2u_builder import VariancePredictorConfig
+
+
+@dataclass
+class PretsselEncoderFrontendConfig:
+    prosody_encoder_config: EcapaTDNNConfig
+    dropout: float
+    lang_embed_dim: Optional[int] = None
+
+
+@dataclass
+class FFTLayerConfig:
+    attention_heads: int
+    hidden_dim: int
+    kernel_size: int
+    dropout: float
+    conv1d_dropout: float
+    film_cond_dim: int
+    use_film: bool = False
+
+
+@dataclass
+class PretsselDecoderFrontendConfig:
+    upsampling_type: Literal["gaussian", "hard"]
+    variance_predictor_config: VariancePredictorConfig
+    add_variance_parallel: bool
+
+
+@dataclass
+class VocoderConfig:
+    """Holds the configuration of a Vocoder model."""
+
+    encoder_frontend_config: PretsselEncoderFrontendConfig
+    fft_layer_config: FFTLayerConfig
+    decoder_frontend_config: PretsselDecoderFrontendConfig
+    pn_conv_dim: int
+    pn_layers: int
+    pn_conv_kernel_size: int
+    pn_dropout: float
+    vocab_info: VocabularyInfo
+    model_dim: int
+    max_seq_len: int
+    encoder_layers: int
+    decoder_layers: int
+    mel_dim: int
+    langs: List  # type: ignore[type-arg]
+    upsample_rates: List[int]
+    upsample_kernel_sizes: List[int]
+    upsample_initial_channel: int
+    resblock_kernel_sizes: List[int]
+    resblock_dilation_sizes: List[List[int]]
+    channels: int
+    dimension: int
+    n_filters: int
+    ratios: List[int]
+    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"]
+    norm_params: Dict[str, Any]
+    kernel_size: int
+    last_kernel_size: int
+    residual_kernel_size: int
+    causal: bool
+    pad_mode: str
+    true_skip: bool
+    compress: int
+    lstm: int
+    disable_norm_outer_blocks: int
+    trim_right_ratio: float
+    gcmvn_stats: Dict[str, List]  # type: ignore[type-arg]
+
+
+vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_pretssel")
+
+
+vocoder_arch = vocoder_archs.decorator
+
+
+def pretssel_config() -> (
+    Tuple[PretsselEncoderFrontendConfig, FFTLayerConfig, PretsselDecoderFrontendConfig]
+):
+    prosody_encoder_config = ecapa_tdnn_archs.get_config("base")
+
+    encoder_frontend_config = PretsselEncoderFrontendConfig(
+        prosody_encoder_config=prosody_encoder_config,
+        dropout=0.2,
+        lang_embed_dim=64,
+    )
+
+    fft_layer_config = FFTLayerConfig(
+        attention_heads=2,
+        hidden_dim=1024,
+        kernel_size=9,
+        dropout=0.0,
+        conv1d_dropout=0.2,
+        use_film=True,
+        film_cond_dim=576,
+    )
+
+    variance_predictor_config = VariancePredictorConfig(
+        var_pred_hidden_dim=512,
+        var_pred_kernel_size=5,
+        var_pred_dropout=0.5,
+        use_film=True,
+        film_cond_dim=576,
+    )
+
+    decoder_frontend_config = PretsselDecoderFrontendConfig(
+        upsampling_type="gaussian",
+        variance_predictor_config=variance_predictor_config,
+        add_variance_parallel=True,
+    )
+    return (
+        encoder_frontend_config,
+        fft_layer_config,
+        decoder_frontend_config,
+    )
+
+
+@vocoder_arch("16khz")
+def _16khz_vocoder() -> VocoderConfig:
+    (
+        encoder_frontend_config,
+        fft_layer_config,
+        decoder_frontend_config,
+    ) = pretssel_config()
+
+    return VocoderConfig(
+        encoder_frontend_config=encoder_frontend_config,
+        fft_layer_config=fft_layer_config,
+        decoder_frontend_config=decoder_frontend_config,
+        pn_conv_dim=512,
+        pn_layers=5,
+        pn_conv_kernel_size=5,
+        pn_dropout=0.5,
+        vocab_info=VocabularyInfo(
+            size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        model_dim=256,
+        max_seq_len=4000,
+        encoder_layers=4,
+        decoder_layers=4,
+        mel_dim=80,
+        langs=[],
+        upsample_rates=[5, 4, 4, 2],
+        upsample_kernel_sizes=[10, 8, 8, 4],
+        upsample_initial_channel=512,
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        channels=1,
+        dimension=128,
+        n_filters=32,
+        ratios=[8, 5, 4, 2],
+        norm="weight_norm",
+        norm_params={},
+        kernel_size=7,
+        last_kernel_size=7,
+        residual_kernel_size=3,
+        causal=False,
+        pad_mode="constant",
+        true_skip=True,
+        compress=2,
+        lstm=2,
+        disable_norm_outer_blocks=0,
+        trim_right_ratio=1.0,
+        gcmvn_stats={},
+    )
+
+
+@vocoder_arch("24khz")
+def _24khz_vocoder() -> VocoderConfig:
+    (
+        encoder_frontend_config,
+        fft_layer_config,
+        decoder_frontend_config,
+    ) = pretssel_config()
+
+    return VocoderConfig(
+        encoder_frontend_config=encoder_frontend_config,
+        fft_layer_config=fft_layer_config,
+        decoder_frontend_config=decoder_frontend_config,
+        pn_conv_dim=512,
+        pn_layers=5,
+        pn_conv_kernel_size=5,
+        pn_dropout=0.5,
+        vocab_info=VocabularyInfo(
+            size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        model_dim=256,
+        max_seq_len=4000,
+        encoder_layers=4,
+        decoder_layers=4,
+        mel_dim=80,
+        langs=[],
+        upsample_rates=[5, 4, 4, 3],
+        upsample_kernel_sizes=[10, 8, 8, 6],
+        upsample_initial_channel=512,
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        channels=1,
+        dimension=128,
+        n_filters=32,
+        ratios=[8, 5, 4, 2],
+        norm="weight_norm",
+        norm_params={},
+        kernel_size=7,
+        last_kernel_size=7,
+        residual_kernel_size=3,
+        causal=False,
+        pad_mode="constant",
+        true_skip=True,
+        compress=2,
+        lstm=2,
+        disable_norm_outer_blocks=0,
+        trim_right_ratio=1.0,
+        gcmvn_stats={},
+    )
+
+
+class PretsselVocoderBuilder:
+    config: VocoderConfig
+    prosody_encoder_builder: EcapaTDNNBuilder
+    device: Optional[Device] = None
+    dtype: Optional[DataType] = None
+
+    def __init__(
+        self,
+        config: VocoderConfig,
+        prosody_encoder_builder: EcapaTDNNBuilder,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+        self.prosody_encoder_builder = prosody_encoder_builder
+        self.device, self.dtype = device, dtype
+
+    def build_embed_tokens(self) -> StandardEmbedding:
+        """Build a unit embedding table."""
+
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_fft(self, num_layers: int) -> FeedForwardTransformer:
+        """Build a Transformer encoder."""
+
+        layers = [self.build_fft_layer() for _ in range(num_layers)]
+
+        return FeedForwardTransformer(
+            layers,
+            norm_order=TransformerNormOrder.POST,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_fft_layer(self) -> FeedForwardTransformerLayer:
+        """Build a Transformer decoder layer."""
+
+        self_attn = self.build_attention(self.config.fft_layer_config.attention_heads)
+
+        conv1d = Conv1dBlock(
+            self.config.model_dim,
+            self.config.fft_layer_config.hidden_dim,
+            self.config.fft_layer_config.kernel_size,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return FeedForwardTransformerLayer(
+            self_attn,
+            conv1d,
+            dropout_p=0.0,  # fairseq1 doesn't have this
+            conv1d_dropout_p=self.config.fft_layer_config.conv1d_dropout,
+            use_film=self.config.fft_layer_config.use_film,
+            film_cond_dim=self.config.fft_layer_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.fft_layer_config.dropout)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_variance_adaptor(
+        self,
+        decoder_frontend_config: PretsselDecoderFrontendConfig,
+    ) -> VarianceAdaptor:
+        """Build a variance adaptor module."""
+
+        variance_predictor_config = decoder_frontend_config.variance_predictor_config
+
+        pitch_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_pitch = Conv1d(1, self.config.model_dim, kernel_size=1)
+
+        vuv_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        energy_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_energy = Conv1d(1, self.config.model_dim, kernel_size=1)
+
+        variance_adaptor = VarianceAdaptor(
+            duration_predictor=None,
+            pitch_predictor=pitch_predictor,
+            embed_pitch=embed_pitch,
+            vuv_predictor=vuv_predictor,
+            energy_predictor=energy_predictor,
+            embed_energy=embed_energy,
+            add_variance_parallel=decoder_frontend_config.add_variance_parallel,
+            upsampling_type=decoder_frontend_config.upsampling_type,
+        )
+
+        return variance_adaptor
+
+    def build_model(self) -> PretsselVocoder:
+        """build the pretssel vocoder."""
+        prosody_encoder = self.prosody_encoder_builder.build_model()
+        embed_tokens = self.build_embed_tokens()
+
+        embed_positions = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.max_seq_len,
+            _legacy_pad_idx=self.config.vocab_info.pad_idx,
+            device=self.device,
+        )
+        lang_to_index = {lang: i for i, lang in enumerate(self.config.langs)}
+        encoder_frontend = PretsselEncoderFrontend(
+            prosody_encoder,
+            embed_tokens,
+            embed_positions,
+            lang_to_index,
+            lang_embed_dim=self.config.encoder_frontend_config.lang_embed_dim,
+            dropout_p=self.config.encoder_frontend_config.dropout,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        encoder = self.build_fft(self.config.encoder_layers)
+
+        variance_adaptor = self.build_variance_adaptor(
+            self.config.decoder_frontend_config
+        )
+
+        decoder_frontend = PretsselDecoderFrontend(
+            variance_adaptor,
+            embed_positions,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        decoder = self.build_fft(self.config.decoder_layers)
+
+        final_proj = Linear(
+            self.config.model_dim,
+            self.config.mel_dim,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        gcmvn_mean = gcmvn_std = None
+        if self.config.gcmvn_stats is not None:
+            gcmvn_mean = self.config.gcmvn_stats["mean"]
+            gcmvn_std = self.config.gcmvn_stats["std"]
+
+        vocoder = PretsselVocoder(
+            encoder_frontend=encoder_frontend,
+            encoder=encoder,
+            decoder_frontend=decoder_frontend,
+            decoder=decoder,
+            final_proj=final_proj,
+            pn_n_channels=self.config.pn_conv_dim,
+            pn_kernel_size=self.config.pn_conv_kernel_size,
+            pn_layers=self.config.pn_layers,
+            pn_dropout=self.config.pn_dropout,
+            upsample_rates=self.config.upsample_rates,
+            upsample_kernel_sizes=self.config.upsample_kernel_sizes,
+            upsample_initial_channel=self.config.upsample_initial_channel,
+            resblock_kernel_sizes=self.config.resblock_kernel_sizes,
+            resblock_dilation_sizes=self.config.resblock_dilation_sizes,
+            channels=self.config.channels,
+            dimension=self.config.dimension,
+            n_filters=self.config.n_filters,
+            ratios=self.config.ratios,
+            norm=self.config.norm,
+            norm_params=self.config.norm_params,
+            kernel_size=self.config.kernel_size,
+            last_kernel_size=self.config.last_kernel_size,
+            residual_kernel_size=self.config.residual_kernel_size,
+            causal=self.config.causal,
+            pad_mode=self.config.pad_mode,
+            true_skip=self.config.true_skip,
+            compress=self.config.compress,
+            lstm=self.config.lstm,
+            disable_norm_outer_blocks=self.config.disable_norm_outer_blocks,
+            trim_right_ratio=self.config.trim_right_ratio,
+            gcmvn_mean=gcmvn_mean,
+            gcmvn_std=gcmvn_std,
+        )
+        vocoder.to(dtype=self.dtype, device=self.device)
+        return vocoder
+
+
+def create_vocoder_model(
+    config: VocoderConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> PretsselVocoder:
+    prosody_encoder_builder = EcapaTDNNBuilder(
+        config.encoder_frontend_config.prosody_encoder_config,
+        device=device,
+        dtype=dtype,
+    )
+    return PretsselVocoderBuilder(
+        config, prosody_encoder_builder, device=device, dtype=dtype
+    ).build_model()
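
For reference, a minimal sketch (not part of this diff) of building the 24 kHz vocoder straight from its registered config; the gcmvn stats below are dummy values standing in for the mean/std normally read from the model card:

```python
import torch

from seamless_communication.models.generator.builder import (
    _24khz_vocoder,
    create_vocoder_model,
)

config = _24khz_vocoder()
# Dummy stats; the real mean/std come from the checkpoint's model card.
config.gcmvn_stats = {"mean": [0.0] * config.mel_dim, "std": [1.0] * config.mel_dim}

vocoder = create_vocoder_model(config, dtype=torch.float32)
vocoder.eval()
```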

+ 474 - 0
src/seamless_communication/models/generator/ecapa_tdnn.py

@@ -0,0 +1,474 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq2.nn.padding import PaddingMask, to_padding_mask
+from torch import Tensor
+from torch.nn import Conv1d, LayerNorm, Module, ModuleList, ReLU, Sigmoid, Tanh, init
+
+
+class ECAPA_TDNN(Module):
+    """
+    Represents the ECAPA-TDNN model described in the paper
+    :cite:t:`https://doi.org/10.48550/arxiv.2005.07143`.
+
+    Arguments
+    ---------
+    :param channels:
+        Output channels for TDNN/SERes2Net layer.
+    :param kernel_sizes:
+        List of kernel sizes for each layer.
+    :param dilations:
+        List of dilations for kernels in each layer.
+    :param attention_channels:
+        The number of attention channels in the attentive statistics pooling.
+    :param res2net_scale:
+        The scale of the Res2Net blocks.
+    :param se_channels:
+        The number of squeeze channels in the SE blocks.
+    :param global_context:
+        Whether the attentive statistics pooling uses global context.
+    :param groups:
+        List of groups for kernels in each layer.
+    :param embed_dim:
+        The dimensionality of the output embedding.
+    :param input_dim:
+        The number of input feature channels.
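+
+    Example
+    -------
+    >>> # Illustrative: hyper-parameters of the "base" arch in ecapa_tdnn_builder.py
+    >>> model = ECAPA_TDNN(
+    ...     channels=[512, 512, 512, 512, 1536],
+    ...     kernel_sizes=[5, 3, 3, 3, 1],
+    ...     dilations=[1, 2, 3, 4, 1],
+    ...     attention_channels=128,
+    ...     res2net_scale=8,
+    ...     se_channels=128,
+    ...     global_context=True,
+    ...     groups=[1, 1, 1, 1, 1],
+    ...     embed_dim=512,
+    ...     input_dim=80,
+    ... )
+    >>> x = torch.rand(8, 120, 80)  # (batch, time, channel)
+    >>> model(x).shape
+    torch.Size([8, 512])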
+    """
+
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        dilations: List[int],
+        attention_channels: int,
+        res2net_scale: int,
+        se_channels: int,
+        global_context: bool,
+        groups: List[int],
+        embed_dim: int,
+        input_dim: int,
+    ):
+        super().__init__()
+        assert len(channels) == len(kernel_sizes) == len(dilations)
+        self.channels = channels
+        self.embed_dim = embed_dim
+        self.blocks = ModuleList()
+
+        self.blocks.append(
+            TDNNBlock(
+                input_dim,
+                channels[0],
+                kernel_sizes[0],
+                dilations[0],
+                groups[0],
+            )
+        )
+
+        # SE-Res2Net layers
+        for i in range(1, len(channels) - 1):
+            self.blocks.append(
+                SERes2NetBlock(
+                    channels[i - 1],
+                    channels[i],
+                    res2net_scale=res2net_scale,
+                    se_channels=se_channels,
+                    kernel_size=kernel_sizes[i],
+                    dilation=dilations[i],
+                    groups=groups[i],
+                )
+            )
+
+        # Multi-layer feature aggregation
+        self.mfa = TDNNBlock(
+            channels[-1],
+            channels[-1],
+            kernel_sizes[-1],
+            dilations[-1],
+            groups=groups[-1],
+        )
+
+        # Attentive Statistical Pooling
+        self.asp = AttentiveStatisticsPooling(
+            channels[-1],
+            attention_channels=attention_channels,
+            global_context=global_context,
+        )
+        self.asp_norm = LayerNorm(channels[-1] * 2, eps=1e-12)
+
+        # Final linear transformation
+        self.fc = Conv1d(
+            in_channels=channels[-1] * 2,
+            out_channels=embed_dim,
+            kernel_size=1,
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters and buffers of the module."""
+
+        def encoder_init(m: Module) -> None:
+            if isinstance(m, Conv1d):
+                init.xavier_uniform_(m.weight, init.calculate_gain("relu"))
+
+        self.apply(encoder_init)
+
+    def forward(
+        self,
+        x: Tensor,
+        padding_mask: Optional[PaddingMask] = None,
+    ) -> Tensor:
+        """Returns the embedding vector.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape (batch, time, channel).
+        """
+        # Minimize transpose for efficiency
+        x = x.transpose(1, 2)
+
+        xl = []
+        for layer in self.blocks:
+            x = layer(x, padding_mask=padding_mask)
+            xl.append(x)
+
+        # Multi-layer feature aggregation
+        x = torch.cat(xl[1:], dim=1)
+        x = self.mfa(x)
+
+        # Attentive Statistical Pooling
+        x = self.asp(x, padding_mask=padding_mask)
+        x = self.asp_norm(x.transpose(1, 2)).transpose(1, 2)
+
+        # Final linear transformation
+        x = self.fc(x)
+
+        x = x.transpose(1, 2).squeeze(1)  # B x C
+        return F.normalize(x, dim=-1)
+
+
+class TDNNBlock(Module):
+    """An implementation of TDNN.
+
+    Arguments
+    ----------
+    :param in_channels : int
+        Number of input channels.
+    :param out_channels : int
+        The number of output channels.
+    :param kernel_size : int
+        The kernel size of the TDNN blocks.
+    :param dilation : int
+        The dilation of the TDNN block.
+    :param groups: int
+        The groups size of the TDNN blocks.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilation: int,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=dilation * (kernel_size - 1) // 2,
+            groups=groups,
+        )
+        self.activation = ReLU()
+        self.norm = LayerNorm(out_channels, eps=1e-12)
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        x = self.activation(self.conv(x))
+
+        return self.norm(x.transpose(1, 2)).transpose(1, 2)  # type: ignore[no-any-return]
+
+
+class Res2NetBlock(Module):
+    """An implementation of Res2NetBlock w/ dilation.
+
+    Arguments
+    ---------
+    :param in_channels : int
+        The number of channels expected in the input.
+    :param out_channels : int
+        The number of output channels.
+    :param scale : int
+        The scale of the Res2Net block.
+    :param kernel_size: int
+        The kernel size of the Res2Net block.
+    :param dilation : int
+        The dilation of the Res2Net block.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        scale: int = 8,
+        kernel_size: int = 3,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        assert in_channels % scale == 0
+        assert out_channels % scale == 0
+
+        in_channel = in_channels // scale
+        hidden_channel = out_channels // scale
+        self.blocks = ModuleList(
+            [
+                TDNNBlock(
+                    in_channel,
+                    hidden_channel,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                )
+                for i in range(scale - 1)
+            ]
+        )
+        self.scale = scale
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        y = []
+        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
+            if i == 0:
+                y_i = x_i
+            elif i == 1:
+                y_i = self.blocks[i - 1](x_i)
+            else:
+                y_i = self.blocks[i - 1](x_i + y_i)
+            y.append(y_i)
+
+        y_tensor = torch.cat(y, dim=1)
+        return y_tensor
+
+
+class SEBlock(Module):
+    """An implementation of squeeze-and-excitation block.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of input channels.
+    se_channels : int
+        The number of output channels after squeeze.
+    out_channels : int
+        The number of output channels.
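+
+    Example
+    -------
+    >>> x = torch.rand(8, 64, 120)  # (batch, channel, time)
+    >>> se = SEBlock(in_channels=64, se_channels=16, out_channels=64)
+    >>> se(x).shape
+    torch.Size([8, 64, 120])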
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        se_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+
+        self.conv1 = Conv1d(
+            in_channels=in_channels, out_channels=se_channels, kernel_size=1
+        )
+        self.relu = ReLU(inplace=True)
+        self.conv2 = Conv1d(
+            in_channels=se_channels, out_channels=out_channels, kernel_size=1
+        )
+        self.sigmoid = Sigmoid()
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        if padding_mask is not None:
+            mask = padding_mask.materialize().unsqueeze(1)
+            s = (x * mask).sum(dim=2, keepdim=True) / padding_mask.seq_lens[
+                :, None, None
+            ]
+        else:
+            s = x.mean(dim=2, keepdim=True)
+
+        s = self.relu(self.conv1(s))
+        s = self.sigmoid(self.conv2(s))
+
+        return s * x
+
+
+class AttentiveStatisticsPooling(Module):
+    """This class implements an attentive statistic pooling layer for each channel.
+    It returns the concatenated mean and std of the input tensor.
+
+    Arguments
+    ---------
+    channels: int
+        The number of input channels.
+    attention_channels: int
+        The number of attention channels.
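+
+    Example
+    -------
+    >>> x = torch.rand(8, 64, 120)  # (batch, channel, time)
+    >>> asp = AttentiveStatisticsPooling(64, attention_channels=32)
+    >>> asp(x).shape  # concatenated mean and std per channel
+    torch.Size([8, 128, 1])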
+    """
+
+    def __init__(
+        self, channels: int, attention_channels: int = 128, global_context: bool = True
+    ):
+        super().__init__()
+
+        self.eps = 1e-12
+        self.global_context = global_context
+        if global_context:
+            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
+        else:
+            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
+
+        self.tanh = Tanh()
+        self.conv = Conv1d(
+            in_channels=attention_channels, out_channels=channels, kernel_size=1
+        )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Calculates mean and std for a batch (input tensor).
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape [N, C, L].
+        """
+        L = x.shape[-1]
+
+        def _compute_statistics(
+            x: Tensor, m: Tensor, dim: int = 2, eps: float = self.eps
+        ) -> Tuple[Tensor, Tensor]:
+            mean = (m * x).sum(dim)
+            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
+            return mean, std
+
+        # Make binary mask of shape [N, 1, L]
+        if padding_mask is not None:
+            mask = padding_mask.materialize()
+        else:
+            mask = to_padding_mask(torch.IntTensor([L]), L).repeat(x.shape[0], 1).to(x)
+        mask = mask.unsqueeze(1)
+
+        # Expand the temporal context of the pooling layer by allowing the
+        # self-attention to look at global properties of the utterance.
+        if self.global_context:
+            # torch.std is unstable for backward computation
+            # https://github.com/pytorch/pytorch/issues/4320
+            total = mask.sum(dim=2, keepdim=True).to(x)
+            mean, std = _compute_statistics(x, mask / total)
+            mean = mean.unsqueeze(2).repeat(1, 1, L)
+            std = std.unsqueeze(2).repeat(1, 1, L)
+            attn = torch.cat([x, mean, std], dim=1)
+        else:
+            attn = x
+
+        # Apply layers
+        attn = self.conv(self.tanh(self.tdnn(attn)))
+
+        # Filter out zero-paddings
+        attn = attn.masked_fill(mask == 0, float("-inf"))
+
+        attn = F.softmax(attn, dim=2)
+        mean, std = _compute_statistics(x, attn)
+        # Append mean and std of the batch
+        pooled_stats = torch.cat((mean, std), dim=1)
+        pooled_stats = pooled_stats.unsqueeze(2)
+
+        return pooled_stats
+
+
+class SERes2NetBlock(Module):
+    """An implementation of building block in ECAPA-TDNN, i.e.,
+    TDNN-Res2Net-TDNN-SEBlock.
+
+    Arguments
+    ----------
+    out_channels: int
+        The number of output channels.
+    res2net_scale: int
+        The scale of the Res2Net block.
+    kernel_size: int
+        The kernel size of the TDNN blocks.
+    dilation: int
+        The dilation of the Res2Net block.
+    groups: int
+    Number of blocked connections from input channels to output channels.
+
+    Example
+    -------
+    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
+    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
+    >>> out = conv(x).transpose(1, 2)
+    >>> out.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        res2net_scale: int = 8,
+        se_channels: int = 128,
+        kernel_size: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.tdnn1 = TDNNBlock(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.res2net_block = Res2NetBlock(
+            out_channels,
+            out_channels,
+            res2net_scale,
+            kernel_size,
+            dilation,
+        )
+        self.tdnn2 = TDNNBlock(
+            out_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+        self.shortcut = None
+        if in_channels != out_channels:
+            self.shortcut = Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+            )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        residual = x
+        if self.shortcut:
+            residual = self.shortcut(x)
+
+        x = self.tdnn1(x)
+        x = self.res2net_block(x)
+        x = self.tdnn2(x)
+        x = self.se_block(x, padding_mask=padding_mask)
+
+        return x + residual

+ 112 - 0
src/seamless_communication/models/generator/ecapa_tdnn_builder.py

@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
+
+
+@dataclass
+class EcapaTDNNConfig:
+    channels: List[int]
+    kernel_sizes: List[int]
+    dilations: List[int]
+    attention_channels: int
+    res2net_scale: int
+    se_channels: int
+    global_context: bool
+    groups: List[int]
+    embed_dim: int
+    input_dim: int
+
+
+ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+
+ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
+
+
+@ecapa_tdnn_arch("base")
+def _base_ecapa_tdnn() -> EcapaTDNNConfig:
+    return EcapaTDNNConfig(
+        channels=[512, 512, 512, 512, 1536],
+        kernel_sizes=[5, 3, 3, 3, 1],
+        dilations=[1, 2, 3, 4, 1],
+        attention_channels=128,
+        res2net_scale=8,
+        se_channels=128,
+        global_context=True,
+        groups=[1, 1, 1, 1, 1],
+        embed_dim=512,
+        input_dim=80,
+    )
+
+
+class EcapaTDNNBuilder:
+    """
+    Builder module for ECAPA_TDNN model
+    """
+
+    config: EcapaTDNNConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: EcapaTDNNConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> ECAPA_TDNN:
+        """Build a model."""
+        model = ECAPA_TDNN(
+            self.config.channels,
+            self.config.kernel_sizes,
+            self.config.dilations,
+            self.config.attention_channels,
+            self.config.res2net_scale,
+            self.config.se_channels,
+            self.config.global_context,
+            self.config.groups,
+            self.config.embed_dim,
+            self.config.input_dim,
+        )
+        model.to(device=self.device, dtype=self.dtype)
+        return model
+
+
+def create_ecapa_tdnn_model(
+    config: EcapaTDNNConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> ECAPA_TDNN:
+    """Create a ECAPA_TDNN model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+
+    return EcapaTDNNBuilder(config, device=device, dtype=dtype).build_model()
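
A minimal sketch (not part of this diff) of how the builder might be used to embed a batch of 80-dimensional mel features with the `base` architecture:

```python
import torch

from seamless_communication.models.generator.ecapa_tdnn_builder import (
    _base_ecapa_tdnn,
    create_ecapa_tdnn_model,
)

config = _base_ecapa_tdnn()
model = create_ecapa_tdnn_model(config, dtype=torch.float32)

feats = torch.rand(2, 100, config.input_dim)  # (batch, time, mel bins)
embeddings = model(feats)  # -> shape (2, config.embed_dim)
```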

+ 29 - 0
src/seamless_communication/models/generator/loader.py

@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+
+from seamless_communication.models.generator.builder import (
+    VocoderConfig,
+    create_vocoder_model,
+    vocoder_archs,
+)
+from seamless_communication.models.generator.vocoder import PretsselVocoder
+
+load_pretssel_vocoder_config = ConfigLoader[VocoderConfig](asset_store, vocoder_archs)
+
+
+load_pretssel_vocoder_model = ModelLoader[PretsselVocoder, VocoderConfig](
+    asset_store,
+    download_manager,
+    load_pretssel_vocoder_config,
+    create_vocoder_model,
+    restrict_checkpoints=False,
+)
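
A minimal sketch (not part of this diff) of calling the loader; the card name below is hypothetical and must match a vocoder card registered in the asset store:

```python
import torch

from seamless_communication.models.generator.loader import (
    load_pretssel_vocoder_model,
)

vocoder = load_pretssel_vocoder_model(
    "vocoder_pretssel",  # hypothetical card name
    device=torch.device("cpu"),
    dtype=torch.float32,
)
vocoder.eval()
```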

+ 452 - 0
src/seamless_communication/models/generator/streamable.py

@@ -0,0 +1,452 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import warnings
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+import torch
+from fairseq2.typing import DataType, Device
+from torch.nn import (
+    ELU,
+    LSTM,
+    Conv1d,
+    ConvTranspose1d,
+    GroupNorm,
+    Identity,
+    Module,
+    Sequential,
+)
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm  # type: ignore[attr-defined]
+
+CONV_NORMALIZATIONS = frozenset(
+    ["none", "weight_norm", "spectral_norm", "time_group_norm"]
+)
+
+
+def apply_parametrization_norm(
+    module: Module,
+    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"] = "none",
+) -> Module:
+    if norm == "weight_norm":
+        return weight_norm(module)
+    elif norm == "spectral_norm":
+        return spectral_norm(module)
+    else:
+        # `norm` was already checked against CONV_NORMALIZATIONS, so any
+        # other choice doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(  # type: ignore[no-untyped-def]
+    module: Module,
+    causal: bool = False,
+    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"] = "none",
+    **norm_kwargs,
+) -> Module:
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or raise an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "time_group_norm":
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, torch.nn.modules.conv._ConvNd)
+        return GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return Identity()
+
+
+def get_extra_padding_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
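+
+
+# Worked example (illustrative): with length=7, kernel_size=4, stride=2 and
+# padding_total=2, n_frames = (7 - 4 + 2) / 2 + 1 = 3.5, so we need
+# ceil(3.5) = 4 frames; ideal_length = (4 - 1) * 2 + (4 - 2) = 8, hence one
+# extra sample of padding is returned.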
+
+
+def pad_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> torch.Tensor:
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you remove the padding, we are missing one time step!
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))  # noqa
+
+
+def pad1d(
+    x: torch.Tensor,
+    paddings: Tuple[int, int],
+    mode: str = "constant",
+    value: float = 0.0,
+) -> torch.Tensor:
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happen.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == "reflect":
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: Tuple[int, int]) -> torch.Tensor:
+    """Remove padding from x, handling properly zero padding. Only for 1d!"""
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left:end]
+
+
+class NormConv1d(Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_kwargs: Dict[str, Any] = {},
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.conv: Module = apply_parametrization_norm(
+            Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            ),
+            norm,
+        )
+        self.norm: Module = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose1d(Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_kwargs: Dict[str, Any] = {},
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(
+            ConvTranspose1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                device=device,
+                dtype=dtype,
+            ),
+            norm,
+        )
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class StreamableConv1d(Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_kwargs: Dict[str, Any] = {},
+        pad_mode: str = "reflect",
+        activation: Optional[Module] = None,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        # Warn the user about an unusual combination of stride and dilation
+        if stride > 1 and dilation > 1:
+            warnings.warn(
+                "StreamableConv1d has been initialized with stride > 1 and dilation > 1"
+                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+            )
+        self.activation = activation
+        self.conv = NormConv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            causal=causal,
+            norm=norm,
+            norm_kwargs=norm_kwargs,
+            device=device,
+            dtype=dtype,
+        )
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.activation:
+            x = self.activation(x)
+        kernel_size: int = self.conv.conv.kernel_size[0]  # type: ignore[index,assignment]
+        stride: int = self.conv.conv.stride[0]  # type: ignore[index,assignment]
+        dilation = self.conv.conv.dilation[0]  # type: ignore[index]
+        kernel_size = (  # type: ignore[assignment]
+            kernel_size - 1
+        ) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(
+            x, kernel_size, stride, padding_total
+        )
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(
+                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
+            )
+        return self.conv(x)  # type: ignore[no-any-return]
+
+
+class StreamableConvTranspose1d(Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        trim_right_ratio: float = 1.0,
+        norm_kwargs: Dict[str, Any] = {},
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            causal=causal,
+            norm=norm,
+            norm_kwargs=norm_kwargs,
+            device=device,
+            dtype=dtype,
+        )
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert (
+            self.causal or self.trim_right_ratio == 1.0
+        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        kernel_size: int = self.convtr.convtr.kernel_size[0]  # type: ignore[index,assignment]
+        stride: int = self.convtr.convtr.stride[0]  # type: ignore[index,assignment]
+        padding_total = kernel_size - stride
+
+        y: torch.Tensor = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
+
+
+class StreamableLSTM(Module):
+    """LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+
+    def __init__(
+        self,
+        dimension: int,
+        num_layers: int = 2,
+        skip: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.skip = skip
+        self.lstm = LSTM(dimension, dimension, num_layers, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y  # type: ignore[no-any-return]
+
+
+class StreamableResnetBlock(Module):
+    """custom Residual block model with streamable convnet.
+
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation_params (dict): Parameters to provide to the (ELU) activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        kernel_sizes: List[int] = [3, 1],
+        dilations: List[int] = [1, 1],
+        activation_params: Dict[str, Any] = {"alpha": 1.0},
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_params: Dict[str, Any] = {},
+        causal: bool = False,
+        pad_mode: str = "reflect",
+        compress: int = 2,
+        true_skip: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        assert len(kernel_sizes) == len(
+            dilations
+        ), "Number of kernel sizes should match number of dilations"
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                ELU(**activation_params),
+                StreamableConv1d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                    norm=norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    device=device,
+                    dtype=dtype,
+                ),
+            ]
+        self.block = Sequential(*block)
+        self.shortcut: Module
+        if true_skip:
+            self.shortcut = Identity()
+        else:
+            self.shortcut = StreamableConv1d(
+                dim,
+                dim,
+                kernel_size=1,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.shortcut(x) + self.block(x)  # type: ignore[no-any-return]
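
A small shape check (not part of this diff) illustrating the padding logic above: with stride 1, a causal `StreamableConv1d` left-pads by `kernel_size - 1`, so the output keeps the input's time resolution:

```python
import torch

from seamless_communication.models.generator.streamable import StreamableConv1d

conv = StreamableConv1d(1, 32, kernel_size=7, causal=True, pad_mode="constant")
y = conv(torch.randn(2, 1, 240))  # left-padded by 6 samples internally
print(y.shape)  # torch.Size([2, 32, 240])
```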

+ 582 - 0
src/seamless_communication/models/generator/vocoder.py

@@ -0,0 +1,582 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq2.nn.embedding import Embedding, StandardEmbedding
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.position_encoder import PositionEncoder
+from fairseq2.nn.projection import Projection
+from fairseq2.typing import DataType, Device
+from torch.nn import (
+    ELU,
+    BatchNorm1d,
+    Conv1d,
+    ConvTranspose1d,
+    Dropout,
+    Module,
+    ModuleList,
+    Parameter,
+    Sequential,
+    Tanh,
+    init,
+)
+from torch.nn.utils.weight_norm import remove_weight_norm, weight_norm
+
+from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+from seamless_communication.models.unity.length_regulator import VarianceAdaptor
+from seamless_communication.models.vocoder.hifigan import (
+    LRELU_SLOPE,
+    ResBlock,
+    init_weights,
+)
+
+from .streamable import (
+    StreamableConv1d,
+    StreamableConvTranspose1d,
+    StreamableLSTM,
+    StreamableResnetBlock,
+)
+
+ELU_PARAMS: Dict[str, Any] = {"alpha": 1.0}
+
+
+class PretsselEncoderFrontend(Module):
+    """
+    Represents the encoder frontend, including the prosody encoder and language embedding.
+    """
+
+    prosody_encoder: ECAPA_TDNN
+    embed_tokens: Embedding
+    embed_positions: PositionEncoder
+    pos_emb_alpha: Parameter
+    embed_lang: Embedding
+    dropout: Dropout
+
+    def __init__(
+        self,
+        prosody_encoder: ECAPA_TDNN,
+        embed_tokens: Embedding,
+        embed_positions: PositionEncoder,
+        lang_to_index: Dict[str, int],
+        lang_embed_dim: Optional[int],
+        dropout_p: float,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.prosody_encoder = prosody_encoder
+
+        self.embed_tokens = embed_tokens
+
+        self.embed_positions = embed_positions
+        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+        self.lang_to_index = lang_to_index
+
+        if lang_embed_dim is not None:
+            self.embed_lang = StandardEmbedding(
+                len(lang_to_index), lang_embed_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("embed_lang", None)
+
+        self.dropout = Dropout(dropout_p)
+
+        self.device = device
+        self.dtype = dtype
+
+    def forward(
+        self,
+        seqs: torch.Tensor,
+        padding_mask: Optional[PaddingMask],
+        prosody_input_seqs: torch.Tensor,
+        prosody_padding_mask: Optional[PaddingMask],
+        tgt_lang: str,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        prosody_embs = self.prosody_encoder(
+            prosody_input_seqs,
+            prosody_padding_mask,
+        ).unsqueeze(1)
+
+        if self.embed_lang is not None:
+            lang_index = self.lang_to_index[tgt_lang]
+            lang_index_tensor = (
+                torch.Tensor([lang_index]).to(seqs).repeat(seqs.size(0), 1)
+            )
+            lang_embeds = self.embed_lang(lang_index_tensor)
+            prosody_embs = torch.cat([prosody_embs, lang_embeds], dim=-1)
+
+        seqs = self.embed_tokens(seqs)
+        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+        seqs = self.dropout(seqs)
+
+        return seqs, prosody_embs
+
+
+class PretsselDecoderFrontend(Module):
+    """Represent Decoder frontend, including VarianceAdaptor & Positional embedding"""
+
+    variance_adaptor: VarianceAdaptor
+    embed_positions: PositionEncoder
+    pos_emb_alpha: Parameter
+
+    def __init__(
+        self,
+        variance_adaptor: VarianceAdaptor,
+        embed_positions: PositionEncoder,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.variance_adaptor = variance_adaptor
+        self.embed_positions = embed_positions
+        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+        self.device = device
+        self.dtype = dtype
+
+    def forward(
+        self,
+        seqs: torch.Tensor,
+        padding_mask: PaddingMask,
+        durations: Optional[torch.Tensor] = None,
+        duration_factor: float = 1.0,
+        min_duration: int = 0,
+        film_cond_emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, PaddingMask]:
+        seqs, padding_mask, _ = self.variance_adaptor(
+            seqs, padding_mask, durations, duration_factor, min_duration, film_cond_emb
+        )
+
+        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+
+        return seqs, padding_mask
+
+
+class PretsselVocoder(Module):
+    """The expressivity-preserving vocoder"""
+
+    encoder_frontend: PretsselEncoderFrontend
+    encoder: FeedForwardTransformer
+    decoder_frontend: PretsselDecoderFrontend
+    decoder: FeedForwardTransformer
+    final_proj: Projection
+
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        encoder_frontend: PretsselEncoderFrontend,
+        encoder: FeedForwardTransformer,
+        decoder_frontend: PretsselDecoderFrontend,
+        decoder: FeedForwardTransformer,
+        final_proj: Projection,
+        pn_n_channels: int,
+        pn_kernel_size: int,
+        pn_layers: int,
+        pn_dropout: float,
+        upsample_rates: List[int],
+        upsample_kernel_sizes: List[int],
+        upsample_initial_channel: int,
+        resblock_kernel_sizes: List[int],
+        resblock_dilation_sizes: List[List[int]],
+        mel_dim: int = 80,
+        add_ups_out_pad: bool = True,
+        channels: int = 1,
+        dimension: int = 128,
+        n_filters: int = 32,
+        ratios: List[int] = [8, 5, 4, 2],
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_params: Dict[str, Any] = {},
+        kernel_size: int = 7,
+        last_kernel_size: int = 7,
+        residual_kernel_size: int = 3,
+        causal: bool = False,
+        pad_mode: str = "constant",
+        true_skip: bool = True,
+        compress: int = 2,
+        lstm: int = 0,
+        disable_norm_outer_blocks: int = 0,
+        trim_right_ratio: float = 1.0,
+        gcmvn_mean: Optional[List[float]] = None,
+        gcmvn_std: Optional[List[float]] = None,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.encoder_frontend = encoder_frontend
+        self.encoder = encoder
+        self.decoder_frontend = decoder_frontend
+        self.decoder = decoder
+        self.final_proj = final_proj
+        mult = 1
+        stream_layers: List[Module] = [
+            StreamableConv1d(
+                channels,
+                mult * n_filters,
+                kernel_size,
+                norm="none" if disable_norm_outer_blocks >= 1 else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                activation=Tanh(),
+                device=device,
+                dtype=dtype,
+            )
+        ]
+        # Downsample from the raw audio scale
+        for i, ratio in enumerate(list(reversed(ratios))):
+            block_norm = "none" if disable_norm_outer_blocks >= i + 2 else norm
+            stream_layers.append(
+                StreamableResnetBlock(
+                    mult * n_filters,
+                    kernel_sizes=[residual_kernel_size, 1],
+                    dilations=[1, 1],
+                    norm=block_norm,
+                    norm_params=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    compress=compress,
+                    true_skip=true_skip,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            stream_layers.append(ELU(**ELU_PARAMS))
+            stream_layers.append(
+                StreamableConv1d(
+                    mult * n_filters,
+                    mult * n_filters * 2,
+                    kernel_size=ratio * 2,
+                    stride=ratio,
+                    norm=block_norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            mult *= 2
+
+        stream_layers.append(StreamableLSTM(mult * n_filters, num_layers=lstm))
+        stream_layers.append(ELU(**ELU_PARAMS))
+        n_blocks = len(ratios) + 2
+        stream_layers.append(
+            StreamableConv1d(
+                mult * n_filters,
+                dimension,
+                last_kernel_size,
+                norm="none" if disable_norm_outer_blocks == n_blocks else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+        )
+        stream_layers.append(
+            StreamableConv1d(
+                dimension,
+                mult * n_filters,
+                kernel_size,
+                norm="none" if disable_norm_outer_blocks == n_blocks else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+        )
+        stream_layers.append(
+            StreamableLSTM(
+                mult * n_filters, num_layers=lstm, device=device, dtype=dtype
+            )
+        )
+
+        # Upsample back to the raw audio scale
+        for i, ratio in enumerate(ratios):
+            block_norm = (
+                "none" if disable_norm_outer_blocks >= n_blocks - (i + 1) else norm
+            )
+            stream_layers.append(ELU(**ELU_PARAMS))
+            stream_layers.append(
+                StreamableConvTranspose1d(
+                    mult * n_filters,
+                    mult * n_filters // 2,
+                    kernel_size=ratio * 2,
+                    stride=ratio,
+                    norm=block_norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    trim_right_ratio=trim_right_ratio,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            stream_layers.append(
+                StreamableResnetBlock(
+                    mult * n_filters // 2,
+                    kernel_sizes=[residual_kernel_size, 1],
+                    dilations=[1, 1],
+                    norm=block_norm,
+                    norm_params=norm_params,
+                    activation_params=ELU_PARAMS,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    compress=compress,
+                    true_skip=true_skip,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            mult //= 2
+
+        stream_layers.append(ELU(**ELU_PARAMS))
+        stream_layers.append(
+            StreamableConv1d(
+                n_filters,
+                channels,
+                last_kernel_size,
+                norm="none" if disable_norm_outer_blocks >= 1 else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+        )
+        self.n_streams = len(stream_layers)
+        chunk_size = self.n_streams // 4
+        stream_idx = 0
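+        # The streamable layers are deliberately split into four chunks and
+        # interleaved with the PostNet and HiFi-GAN layers registered below,
+        # so the watermarking path cannot be removed by deleting one
+        # contiguous block of `self.layers`.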
+
+        self.pn_layers = pn_layers
+        self.layers = ModuleList()
+        assert pn_kernel_size % 2 == 1
+        for i in range(pn_layers):
+            cur_layers = (
+                [
+                    Conv1d(
+                        mel_dim if i == 0 else pn_n_channels,
+                        pn_n_channels if i < pn_layers - 1 else mel_dim,
+                        kernel_size=pn_kernel_size,
+                        padding="same",
+                        device=device,
+                        dtype=dtype,
+                    ),
+                    BatchNorm1d(
+                        pn_n_channels if i < pn_layers - 1 else mel_dim,
+                        device=device,
+                        dtype=dtype,
+                    ),
+                ]
+                + ([Tanh()] if i < pn_layers - 1 else [])
+                + [Dropout(pn_dropout)]
+            )
+            self.layers.append(Sequential(*cur_layers))
+        self.reset_parameters()
+        self.layers.extend(stream_layers[:chunk_size])
+        stream_idx += chunk_size
+        self.layers.append(
+            weight_norm(
+                Conv1d(
+                    mel_dim if mel_dim is not None else 80,
+                    upsample_initial_channel,
+                    7,
+                    1,
+                    padding="same",
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+        )
+        self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])  # noqa
+        stream_idx += chunk_size
+
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        ups = ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            out_pad = u % 2 if add_ups_out_pad else 0
+            ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2 + out_pad,
+                        output_padding=out_pad,
+                        device=device,
+                        dtype=dtype,
+                    )
+                )
+            )
+        ups.apply(init_weights)
+        self.layers.extend(ups)
+        self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])  # noqa
+        stream_idx += chunk_size
+
+        for i in range(self.num_upsamples):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+                self.layers.append(
+                    ResBlock(
+                        ch,
+                        k,
+                        d,
+                    ).to(device, dtype=dtype)
+                )
+        self.layers.extend(stream_layers[stream_idx:])
+
+        conv_post = weight_norm(
+            Conv1d(ch, 1, 7, 1, padding=3, device=device, dtype=dtype)
+        )
+        conv_post.apply(init_weights)
+        self.layers.append(conv_post)
+        for u, k in zip(upsample_rates, upsample_kernel_sizes):
+            assert k == 2 * u, (k, u)
+
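+        # Per-mel-bin normalization statistics; zero-initialized here and
+        # assumed to be overwritten when the checkpoint is loaded.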
+        mean = torch.zeros((mel_dim,), dtype=torch.float)
+        scale = torch.zeros((mel_dim,), dtype=torch.float)
+        self.register_buffer("mean", mean)
+        self.register_buffer("scale", scale)
+
+        self.gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)
+        self.gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)
+
+    def reset_parameters(self) -> None:
+        for i in range(self.pn_layers):
+            init.xavier_uniform_(
+                self.layers[i][0].weight,
+                init.calculate_gain("tanh" if i < self.pn_layers - 1 else "linear"),
+            )
+
+    def gcmvn_denormalize(self, x: torch.Tensor) -> torch.Tensor:
+        if self.gcmvn_mean is None or self.gcmvn_std is None:
+            raise ValueError("gcmvn_mean or gcmvn_std is not set")
+
+        assert (
+            x.ndim == 3
+            and x.shape[2] == self.gcmvn_mean.shape[0]
+            and x.shape[2] == self.gcmvn_std.shape[0]
+        )
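+        # Denormalize: x * std + mean, broadcast over batch and time.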
+        gcmvn_mean = self.gcmvn_mean.to(x)
+        gcmvn_std = self.gcmvn_std.to(x)
+        x = x * gcmvn_std.view(1, 1, -1).expand_as(x)  # type: ignore[attr-defined]
+        return x + gcmvn_mean.view(1, 1, -1).expand_as(x)  # type: ignore[attr-defined,no-any-return]
+
+    def forward(
+        self,
+        seqs: torch.Tensor,
+        tgt_lang: str,
+        prosody_input_seqs: torch.Tensor,
+        padding_mask: Optional[PaddingMask] = None,
+        prosody_padding_mask: Optional[PaddingMask] = None,
+        durations: Optional[torch.Tensor] = None,
+        duration_factor: float = 1.0,
+        min_duration: int = 0,
+        normalize_before: bool = True,
+    ) -> torch.Tensor:
+        # Add a batch dimension for pretssel if the inputs are unbatched
+        if seqs.ndim < 3:
+            seqs = seqs.unsqueeze(0)
+        if prosody_input_seqs.ndim < 3:
+            prosody_input_seqs = prosody_input_seqs.unsqueeze(0)
+        seqs, cond_embs = self.encoder_frontend(
+            seqs,
+            padding_mask,
+            prosody_input_seqs,
+            prosody_padding_mask,
+            tgt_lang,
+        )
+        seqs, padding_mask = self.encoder(seqs, padding_mask, cond_embs)
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, padding_mask, durations, duration_factor, min_duration, cond_embs
+        )
+        seqs, padding_mask = self.decoder(seqs, padding_mask, cond_embs)
+        seqs = self.final_proj(seqs)
+
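+        # PostNet: compute a residual over the projected features (added back below).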
+        pn = seqs.transpose(1, 2)  # B x T x C -> B x C x T
+        for i in range(self.pn_layers):
+            pn = self.layers[i](pn)
+        pn = pn.transpose(1, 2)
+
+        x = seqs + pn
+        x = self.gcmvn_denormalize(x).squeeze(0)
+        if normalize_before:
+            x = (x - self.mean) / self.scale
+
+        x = x.transpose(1, 0).unsqueeze(0)
+        chunk_size = self.n_streams // 4
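+        # Index bookkeeping: self.layers holds, in order, the PostNet convs,
+        # stream chunk 1, the input conv, stream chunk 2, the upsamplers,
+        # stream chunk 3, the ResBlocks, stream chunk 4, and conv_post.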
+        x = self.layers[self.pn_layers + chunk_size](x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.layers[i + self.pn_layers + 1 + 2 * chunk_size](x)
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.layers[
+                        i * self.num_kernels
+                        + j
+                        + self.pn_layers
+                        + 3 * chunk_size
+                        + self.num_upsamples
+                        + 1
+                    ](x)
+                else:
+                    xs += self.layers[
+                        i * self.num_kernels
+                        + j
+                        + self.pn_layers
+                        + 3 * chunk_size
+                        + self.num_upsamples
+                        + 1
+                    ](x)
+            x = xs / self.num_kernels  # type: ignore
+        x = F.leaky_relu(x)
+        x = self.layers[
+            self.pn_layers
+            + self.n_streams
+            + self.num_upsamples * (1 + self.num_kernels)
+            + 1
+        ](x)
+        skip_output = x
+        h = skip_output
+
+        for i1 in range(self.pn_layers, self.pn_layers + chunk_size):
+            h = self.layers[i1](h)
+        i1 += 2
+        for i2 in range(i1, i1 + chunk_size):
+            h = self.layers[i2](h)
+        i2 = i2 + self.num_upsamples + 1
+
+        for i3 in range(i2, i2 + chunk_size):
+            h = self.layers[i3](h)
+        i3 = i3 + (self.num_upsamples * self.num_kernels) + 1
+        for i4 in range(i3, i3 + chunk_size):
+            h = self.layers[i4](h)
+        h = h[:, :, : x.size(-1)]
+
+        h += torch.tanh(skip_output).squeeze(0)
+        return h
+
+    def remove_weight_norm(self) -> None:
+        # Mirror the interleaved layer layout used in forward()
+        chunk_size = self.n_streams // 4
+        remove_weight_norm(self.layers[self.pn_layers + chunk_size])
+        for j in range(self.num_upsamples):
+            remove_weight_norm(self.layers[self.pn_layers + 2 * chunk_size + 1 + j])
+        for k in range(self.num_upsamples * self.num_kernels):
+            self.layers[
+                self.pn_layers + 3 * chunk_size + self.num_upsamples + 1 + k
+            ].remove_weight_norm()
+        remove_weight_norm(
+            self.layers[
+                self.pn_layers
+                + self.n_streams
+                + self.num_upsamples * (1 + self.num_kernels)
+                + 1
+            ]
+        )

+ 1 - 1
src/seamless_communication/models/unity/builder.py

@@ -26,7 +26,7 @@ from fairseq2.nn.transformer import (
 from fairseq2.typing import DataType, Device, override
 from torch.nn import GELU, ReLU
 
-from seamless_communication.models.pretssel import (
+from seamless_communication.models.generator.ecapa_tdnn_builder import (
     EcapaTDNNBuilder,
     EcapaTDNNConfig,
     ecapa_tdnn_archs,

+ 8 - 2
src/seamless_communication/models/unity/loader.py

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Mapping, Tuple, Union
 
 import torch
 from fairseq2.assets import AssetStore, asset_store, download_manager
-from fairseq2.assets.card import AssetCard
+from fairseq2.assets.card import AssetCard, AssetCardFieldNotFoundError
 from fairseq2.models.nllb import NllbConfig
 from fairseq2.models.nllb.loader import NllbTokenizerLoader
 from fairseq2.models.utils import ConfigLoader, ModelLoader
@@ -459,7 +459,13 @@ class GcmvnStatsLoader:
         else:
             card = self.asset_store.retrieve_card(model_name_or_card)
 
-        gcmvn_stats: Dict[str, List[float]] = card.field("gcmvn_stats").as_(dict)
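+        # Some cards keep gcmvn_stats inside the model_config override rather
+        # than as a top-level field.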
+        try:
+            gcmvn_stats: Dict[str, List[float]] = card.field("gcmvn_stats").as_(dict)
+        except AssetCardFieldNotFoundError:
+            model_override = card.field("model_config").as_(dict)
+            gcmvn_stats = model_override["gcmvn_stats"]
 
         return gcmvn_stats["mean"], gcmvn_stats["std"]
 

+ 1 - 1
src/seamless_communication/models/unity/model.py

@@ -19,7 +19,7 @@ from overrides import final as finaloverride
 from torch import Tensor
 from torch.nn import Module
 
-from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
+from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
 from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 

+ 154 - 0
tests/integration/models/test_watermarked_vocoder.py

@@ -0,0 +1,154 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from pathlib import Path
+from typing import Final, List, Optional, cast
+import torch
+from fairseq2.typing import Device
+from fairseq2.data import Collater, SequenceData
+from fairseq2.data.audio import AudioDecoderOutput
+from torch.nn import Module
+
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.models.unity.loader import load_gcmvn_stats
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+from tests.common import (
+    assert_close,
+    convert_to_collated_fbank,
+)
+
+
+N_MEL_BINS = 80
+
+# fmt: off
+REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
+# fmt: on
+
+
+def load_watermarking_model() -> Optional[Module]:
+    import importlib.util
+
+    # Run in CPU mode until pretssel's inconsistent behaviour is fixed
+    device = Device("cpu")
+    dtype = torch.float32
+    wm_py_file = Path(__file__).parents[3] / "scripts/watermarking/watermarking.py"
+    assert wm_py_file.is_file()
+    wm_spec = importlib.util.spec_from_file_location("watermark.f1", wm_py_file)
+    assert wm_spec, f"Module not found: {wm_py_file}"
+    wm_py_module = importlib.util.module_from_spec(wm_spec)
+    assert wm_py_module, f"Invalid Python module file: {wm_py_file}"
+    sys.modules["watermark.f1"] = wm_py_module
+    assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
+    wm_spec.loader.exec_module(wm_py_module)
+
+    return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
+
+
+def test_pretssel_vocoder_watermarking(
+    example_rate16k_audio: AudioDecoderOutput,
+) -> None:
+    """
+    Test that the watermarked pretssel vocoder generates the same output
+    as the non-watermarked (pretssel_generator)
+    """
+    audio = example_rate16k_audio
+
+    # Run in CPU mode until pretssel's inconsistent behaviour is fixed
+    device = Device("cpu")
+    dtype = torch.float32
+    audio["waveform"] = audio["waveform"].to(device, dtype=dtype)
+    feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
+    feat = feat.to(device, dtype=dtype)
+    # Run the watermarked vocoding
+    # TODO: Build a generator API for the watermarked vocoder
+    vocoder = load_pretssel_vocoder_model(
+        "vocoder_pretssel", device=device, dtype=dtype
+    )
+
+    units = torch.tensor(REF_FRA_UNITS, device=device, dtype=torch.int64)
+
+    # Offset the units by the 4 control symbols preceding them in the embedding table
+    units += 4
+
+    # eos_idx = 2 in the VocabularyInfo setting for base pretssel_vocoder
+    unit_eos_token = torch.tensor([2], device=device)
+    units = torch.cat([units, unit_eos_token], dim=0)
+    units, duration = torch.unique_consecutive(units, return_counts=True)
+
+    # adjust for the last eos token
+    duration[-1] = 0
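+    # Each unit is assumed to cover two feature frames, hence the doubling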
+    duration *= 2
+
+    # bos_idx=0 in base VocabularyInfo
+    duration_collate = Collater(pad_value=0)
+    duration_seqs = duration_collate(duration)
+
+    with torch.no_grad():
+        vocoder.eval()
+        wav_wm = vocoder(
+            seqs=units,
+            tgt_lang="fra",
+            prosody_input_seqs=feat,
+            durations=duration_seqs["seqs"],
+            normalize_before=True,
+        )
+
+    # torchaudio.save("wm.wav", wav_wm.squeeze(0).float().cpu(), sample_rate=16000)
+
+    # Run the non-watermarked vocoder using pretssel generator
+    gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
+    gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)  # type: ignore[assignment]
+    gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)  # type: ignore[assignment]
+
+    generator = PretsselGenerator(
+        "seamless_expressivity",
+        "vocoder_mel_24khz",
+        "pretssel_v1",
+        gcmvn_mean=gcmvn_mean,  # type: ignore[arg-type]
+        gcmvn_std=gcmvn_std,  # type: ignore[arg-type]
+        device=device,
+        dtype=dtype,
+    )
+
+    # PretsselGenerator expects a batch of units
+    unit_list: List[List[int]] = [REF_FRA_UNITS]
+    prosody_input_seqs = SequenceData(
+        is_ragged=False,
+        seqs=feat.unsqueeze(0),  # add batch dim
+        seq_lens=torch.tensor([feat.size(0)]),
+    )
+    speech_output = generator.predict(
+        unit_list,
+        tgt_lang="fra",
+        prosody_encoder_input=prosody_input_seqs,
+    )
+    wav = speech_output.audio_wavs[0].unsqueeze(0)
+
+    # torchaudio.save("mel.wav", wav.float().cpu(), sample_rate=16000)
+
+    # Run the watermark model separately after the PretsselGenerator
+    watermarker = load_watermarking_model()
+    wm = watermarker.get_watermark(wav)  # type: ignore
+    wav_wm_hat = wav + wm
+
+    # Test that the watermark is detectable
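+    # detection[:, 1, :] is assumed to be the per-frame watermark probability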
+    detection = watermarker.detect_watermark(wav_wm)  # type: ignore
+    assert torch.all(detection[:, 1, :] > 0.5)
+
+    # Remove the batch dimension and compare the overlapping frames for parity
+    wav_wm = wav_wm.squeeze(0)
+    wav_wm_hat = wav_wm_hat.squeeze(0)
+
+    nframes = min(wav_wm_hat.size(1), wav_wm.size(1))
+    assert_close(
+        wav_wm[:, :nframes],
+        wav_wm_hat[:, :nframes],
+        atol=0.0,
+        rtol=5.0,
+    )