
Make a public-facing watermarked vocoder (PretsselVocoder) (#97)

* initial commit

* add audiocraft to requirements

* all in one file

* cleanup

* remove unused imports

* cleanup and add doc string

* add checkpoint to model card instead

* add pretssel hifigan

* update

* update

* update

* update

* update

* --amend

* Update src/seamless_communication/models/vocoder/builder.py

Co-authored-by: Kaushik Ram Sadagopan <krs@fb.com>

* update

* add pretssel vocoder

* update builder and loader

* Implement PretsselModel & its inference

* update checkpoint handling code

* add mel_vocoder

* refactor inference

* minor fix

* minor fix

* minor fix

* mypy pytest isort black formatting

* change padding to 'same'

* minor renaming

* functional melvocoder + watermark

* remove pretssel from the public-facing module

* checkpoint code

* update checkpoint and fix typos in builder

* update integration test

* make languages updatable automatically

* blend PostNet

* update test cases with new mixing logic of wm layers

* linting

* remove obsolete state dict keys in the checkpoint

* debug the wm deltas error

* linting

* linting

* correct typo

* exclude scripts/convert_mel_hifigan_chkpt.py from the PR

* fix typo

* exclude non-related linting (we will do it in another PR)

* exclude non-related linting (we will do it in another PR)

* exclude non-related linting (we will do it in another PR)

* exclude non-related linting and make it in a separate PR

* exclude non-related linting (we will do it in another PR)

* obfuscate the code further by making torch.tanh embeddable in conv1d, so the watermark cannot be removed by deleting one block of code from the vocoder, only multiple blocks at multiple places

* update watermarking.py as a standalone script

* add temporary comments to support the PR review

* fix typos

* fix linting

* update with new loader API from fairseq2

* re-sync ecapa_tdnn with main

* re-sync ecapa_tdnn_builder with main

* fix typo in card

* update compile_checkpoint to new loader API

* Yilin's comments

* check if gcmvn_mean is loaded

* update deps of unity; update test case

* update cuda setting for watermarking

* (1) remove the unity deps from this PR - we will put them in another PR; (2) address the last comments from Alex and Yilin

* tmp

* enable batched watermark inference

* revise

* update config for 24khz audio

* update model config for the expected 24khz rate

* fix model card and checkpoint compilation script

* update test case to 24khz

* linting

* integrate Kaushik's PR #134

* update model card info, update deps on unity

* revise sample_rate

* remove internal comments

---------

Co-authored-by: hady elsahar <hadyelsahar@meta.com>
Co-authored-by: Changhan Wang <changhan@fb.com>
Co-authored-by: Changhan Wang <wangchanghan@gmail.com>
Co-authored-by: Kaushik Ram Sadagopan <krs@fb.com>
Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Co-authored-by: Yilin Yang <yilinyang721@gmail.com>
Tuan Tran 1 year ago
commit 9f6ade6ee4

+ 6 - 6
scripts/convert_pretssel_hifigan_chkpt.py → scripts/convert_mel_hifigan_chkpt.py

@@ -9,10 +9,9 @@ out_channels -> upsample_initial_channel
 """
 """
 
 
 
 
-def main():
-    chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
-    cfg = f"{chkpt_root}/config.yml"
-    # TODO: display cfg
+def main() -> None:
+    # chkpt_root = "/checkpoint/mjhwang/experiments/231007-mel_vocoder-mls_multilingual_6lang/train_mls_multilingual_6lang_subset_hifigan.v1_8gpu_adapt"
+    chkpt_root = "/checkpoint/mjhwang/experiments/231112-mel_vocoder-ai_speech_24khz/train_train_highquality_speech_20231111_no16khz_100000_hifigan.v1_8gpu_adapt"
     chkpt = torch.load(f"{chkpt_root}/checkpoint-400000steps.pkl")
     chkpt = torch.load(f"{chkpt_root}/checkpoint-400000steps.pkl")
     del chkpt["model"]["discriminator"]
     del chkpt["model"]["discriminator"]
     conv_seq_map = {
     conv_seq_map = {
@@ -21,7 +20,7 @@ def main():
         ".1.weight_v": ".weight_v",
         ".1.weight_v": ".weight_v",
     }
     }
 
 
-    def update_key(k):
+    def update_key(k: str) -> str:
         if k.startswith("input_conv"):
         if k.startswith("input_conv"):
             k = k.replace("input_conv", "conv_pre")
             k = k.replace("input_conv", "conv_pre")
         elif k.startswith("upsamples"):
         elif k.startswith("upsamples"):
@@ -50,7 +49,8 @@ def main():
     for k in ["optimizer", "scheduler", "steps", "epochs"]:
     for k in ["optimizer", "scheduler", "steps", "epochs"]:
         del chkpt[k]
         del chkpt[k]
 
 
-    out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
+    # out_path = "/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt"
+    out_path = "/large_experiments/seamless/workstream/expressivity/oss/checkpoints/melhifigan_20231121.pt"
     torch.save(chkpt, out_path)
     torch.save(chkpt, out_path)
 
 
 
 

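The renaming above amounts to stripping the extra Sequential index that the weight_norm wrapper introduced and mapping the fairseq layer names onto their fairseq2 counterparts. A minimal sketch of that idea, using toy keys (the upsamples/resblocks handling is elided here, as it is in the hunk above):

conv_seq_map = {
    ".1.weight_g": ".weight_g",
    ".1.weight_v": ".weight_v",
}


def update_key(k: str) -> str:
    if k.startswith("input_conv"):
        k = k.replace("input_conv", "conv_pre")
    # upsamples/resblocks renaming elided, as in the hunk above
    for old, new in conv_seq_map.items():
        k = k.replace(old, new)  # drop the Sequential index added by weight_norm
    return k


print(update_key("input_conv.1.weight_g"))  # -> conv_pre.weight_g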
+ 208 - 0
scripts/watermarking/compile_chkpt.py

@@ -0,0 +1,208 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+#
+# The rules to blend the p2v decoder, mel-vocoder and the watermarking:
+#
+# Step 1) Make the big sequential module `layers` that consists of:
+#    - PostNet (last layer of the p2v decoder) : 5 layers
+#    - mel-vocoder layers (conv_pre, ups, resblocks, conv_post): 18 layers
+#    - watermarking encoder and decoder: 32 layers
+#
+# Step 2) Take the last 32 layers of the watermarking, split into 4 blocks of
+# 8 layers. Mix these blocks into the previous layers
+#
+# The final mixed architecture SPVM (Spaghetti Pretssel Vocoder Model):
+#
+#     <P2V: Post Net>
+#           |
+# <Block 1 of Watermarker> ------
+#           |                   |
+#          \/                   |
+#  <Melvocoder: Conv_pre>       |
+#           | (skip)            |
+# <Block 2 of Watermarker> -----|
+#           |                   |
+#          \/                   |
+# <Melvocoder: Upsampler>       |
+#           | (skip)            |
+# <Block 3 of Watermarker> -----|
+#           |                   |
+#          \/                   |
+# <Melvocoder: Resblocks>       |
+#           | (skip)            |
+# <Block 4 of Watermarker> -----|
+#           |                   |
+#          \/                   |
+#  <Melvocoder: Conv_post>      |
+#           |                   |
+#           | ------------------|
+#           |
+#          \/
+#    watermarked wavs
+
+from pathlib import Path
+from argparse import ArgumentParser
+from typing import Any, Mapping, Match
+
+import torch
+from fairseq2.models.utils.checkpoint import (
+    convert_fairseq_checkpoint,
+    convert_model_state_dict,
+    load_checkpoint,
+)
+
+
+def pretssel_key_map() -> Mapping[str, str]:
+    """
+    The rule for renaming the layers of Pretssel model checkpoint:
+        - Merge decoder.postnet into `layers`
+    """
+    from seamless_communication.models.pretssel.loader import _fairseq_key_map  # noqa
+
+    key_map = _fairseq_key_map(None)  # type: ignore[arg-type]
+    del key_map[r"^decoder\.postnet\."]
+    key_map[r"^decoder\.postnet\.convolutions\."] = r"layers."
+    return key_map
+
+
+def vocoder_key_map() -> Mapping[str, Any]:
+    """
+    Rename layers in the mel-vocoder checkpoint. We flatten the vocoder arch and put everything
+    into `layers`, in which `postnet_size` layers from the PostNet are already present. In other words:
+        - conv_pre -> layers.<postnet_size + watermark_size / 4>
+        - ups.i -> layers.<postnet_size + 1 + i + watermark_size / 2>
+        - resblocks.i -> layers.<postnet_size + 1 + ups_size + i + 3 * watermark_size / 4>
+        - conv_post.i -> layers.<postnet_size + 1 + ups_size + resblocks_size + i + watermark_size>
+    """
+
+    return {
+        # fmt: off
+        # postnet_size = 5, 1st wm block = 8 -> 13
+        r"^conv_pre\.":               r"layers.13.",                                 # noqa, 2nd wm block = 8 -> +8
+        r"^ups\.([0-9]+)\.":          lambda x: f"layers.{int(x.group(1)) + 22}.",   # noqa, ups_size = 4, 3rd wm block = 8 -> +12
+        r"^resblocks\.([0-9]+)\.":    lambda x: f"layers.{int(x.group(1)) + 34}.",   # noqa, resblocks_size = 12, 4th wm block = 8 -> +20
+        r"^conv_post\.":              r"layers.54.",
+        # fmt: on
+    }
+
+
+def wm_key_map() -> Mapping[Any, Any]:
+    """
+    Flatten all encoders and decoders into one sequential layer (step 1), split the watermarker
+    into 4 blocks, and mix them into the layers of the p2v decoder and mel-vocoder (step 2):
+        - encoder.model.[0-7] --> layers.<postnet_size + i> (5 --> 12)
+        - encoder.model.[8-15] --> layers.<postnet_size + 9> (14 --> 21)
+        - decoder.model.[0-7] --> layers.<postnet_size + vocoder_ups_size + conv_pre + 16> (26 -> 33)
+        - decoder.model.[8-15] --> layers.<postnet_size + vocoder_ups_size + conv_pre + resblock_size + 24> (46 -> 53)
+    """
+
+    def encoder_layer_index(match_obj: Match[str]) -> str:
+        idx = int(match_obj.group(1))
+        # First half of the encoder is after the PostNet
+        if idx < 8:
+            # postnet_size = 5
+            return f"layers.{idx + 5}."
+
+        # Second half of the encoder goes after the mel-vocoder:conv_pre
+        else:
+            # postnet = 5, conv_pre = 1 --> +6
+            return f"layers.{idx + 6}."
+
+    def decoder_layer_index(match_obj: Match[str]) -> str:
+        idx = int(match_obj.group(1))
+        # First half of the decoder is after the mel-vocoder:ups
+        if idx < 8:
+            # postnet 5, conv_pre 1, encoder 16, ups 4 --> +26
+            return f"layers.{idx + 26}."
+        else:
+            # postnet 5, conv_pre 1, encoder 16, ups 4, resblock 12 -> +38
+            return f"layers.{idx + 38}."
+
+    return {
+        r"^encoder\.model\.([0-9]+)\.": encoder_layer_index,
+        r"^decoder\.model\.([0-9]+)\.": decoder_layer_index,
+    }
+
+
+def combine_chkpts(pretssel_file: str, vocoder_file: str, out_path: str) -> None:
+    """Combine the pretssel and melhifigan into one model"""
+    pretssel_chkpt = load_checkpoint(pretssel_file)
+    pretssel_chkpt = convert_fairseq_checkpoint(pretssel_chkpt, pretssel_key_map())
+
+    vocoder_chkpt = load_checkpoint(vocoder_file)
+    vocoder_chkpt = convert_fairseq_checkpoint(vocoder_chkpt, vocoder_key_map())
+
+    wm_ckpt = load_checkpoint(
+        "/large_experiments/seamless/nllb/watermarking/checkpoints/ckpt_e9d0008c.th",
+    )
+    # wm_ckpt is not a fairseq2 checkpoint so we have to handle it differently
+    wm_ckpt = convert_model_state_dict(wm_ckpt, wm_key_map())
+
+    # Merge the state dicts
+    ckpt = pretssel_chkpt
+    state_dict = ckpt["model"]
+    for vocoder_key in vocoder_chkpt["model"]:
+        state_dict[vocoder_key] = vocoder_chkpt["model"][vocoder_key]
+
+    for wm_key, wm_val in wm_ckpt.items():
+        if wm_key.startswith("layers."):
+            state_dict[wm_key] = wm_val
+
+    # Remove obsolete layers
+    keys_to_delete = [
+        "encoder.embed_positions._float_tensor",
+        "decoder.embed_positions._float_tensor",
+        "enc_emb_proj.weight",
+        "enc_emb_proj.bias",
+    ]
+    keys_to_delete.extend(
+        [
+            key
+            for key in state_dict
+            if key.startswith("decoder.var_adaptor.duration_predictor")
+        ]
+    )
+    for key in keys_to_delete:
+        if key in state_dict:
+            del state_dict[key]
+
+    out_path = "/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt"
+    model_mapping_metafile = Path(out_path).with_suffix(".arch")
+    with open(model_mapping_metafile, "w", encoding="utf-8") as o:
+        o.write(vocoder_key_map.__doc__)  # type: ignore
+        o.write("\n")
+        o.write(wm_key_map.__doc__)  # type: ignore
+        o.write("\n")
+    torch.save(ckpt, out_path)
+
+
+if __name__ == "__main__":
+    # fmt: off
+    parser = ArgumentParser(description="Compile watermarking into p2v decoder and vocoder")
+    parser.add_argument(
+        "--pretssel",
+        default="/checkpoint/mjhwang/experiments/230930-noiseaug_p2v-mls_multilingual_6lang/231005-noiseaug_p2v-mls_multilingual_6lang-alignfix.config_v2.langemb1.vuv_logit1.denoise.ngpu16/checkpoint_best.pt",
+        type=str,
+        help="Path to the Pretssel model checkpoint",
+    )
+    parser.add_argument(
+        "--vocoder",
+        # default="/large_experiments/seamless/ust/changhan/checkpoints/fairseq2/pretssel_hifigan.pt",
+        default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/melhifigan_20231121.pt",
+        type=str,
+        help="Path to the mel-vocoder checkpoint",
+    )
+    parser.add_argument(
+        "--output",
+        default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt",
+        # default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-20231121.pt",
+        type=str,
+        help="Path to the output watermarked model checkpoint",
+    )
+    # fmt: on
+    args = parser.parse_args()
+    combine_chkpts(args.pretssel, args.vocoder, args.output)

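The net effect of the three key maps above is a single flat `layers` index space. Those offsets can be re-derived with a small bookkeeping loop; a sketch using the sizes quoted in the comments above (the variable names here are illustrative, not from the repo):

POSTNET = 5      # PostNet convolutions from the P2V decoder
WM_BLOCK = 8     # each of the 4 watermarker blocks (32 layers total)
CONV_PRE = 1
UPS = 4
RESBLOCKS = 12

offsets = {}
cursor = 0
for name, size in [
    ("postnet", POSTNET),
    ("wm_enc_0_7", WM_BLOCK),
    ("conv_pre", CONV_PRE),
    ("wm_enc_8_15", WM_BLOCK),
    ("ups", UPS),
    ("wm_dec_0_7", WM_BLOCK),
    ("resblocks", RESBLOCKS),
    ("wm_dec_8_15", WM_BLOCK),
    ("conv_post", 1),
]:
    offsets[name] = cursor
    cursor += size

print(offsets)
# {'postnet': 0, 'wm_enc_0_7': 5, 'conv_pre': 13, 'wm_enc_8_15': 14, 'ups': 22,
#  'wm_dec_0_7': 26, 'resblocks': 34, 'wm_dec_8_15': 46, 'conv_post': 54}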
+ 44 - 0
scripts/watermarking/seamlesswatermark.yaml

@@ -0,0 +1,44 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+name: seamlesswatermark
+model_type: seanet
+checkpoint: "/large_experiments/seamless/nllb/watermarking/checkpoints/ckpt_e9d0008c.th"
+watermarker_model:
+  channels: 1
+  sample_rate: 16000
+seanet:
+  activation: ELU
+  activation_params:
+    alpha: 1.0
+  causal: false
+  channels: 1
+  compress: 2
+  decoder:
+    final_activation: null
+    final_activation_params: null
+    trim_right_ratio: 1.0
+  detector: {}
+  dilation_base: 2
+  dimension: 128
+  disable_norm_outer_blocks: 0
+  encoder: {}
+  kernel_size: 7
+  last_kernel_size: 7
+  lstm: 2
+  n_filters: 32
+  n_residual_layers: 1
+  norm: weight_norm
+  norm_params: {}
+  pad_mode: constant
+  ratios:
+  - 8
+  - 5
+  - 4
+  - 2
+  residual_kernel_size: 3
+  true_skip: true

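A quick way to read the SEANet ratios in this config: the encoder downsamples by prod(ratios), and SEANetEncoderKeepDimension (in watermarking.py below) upsamples back with a transposed convolution of that stride. A small check, assuming the YAML is read from the repo root:

import math

import omegaconf

cfg = omegaconf.OmegaConf.load("scripts/watermarking/seamlesswatermark.yaml")
hop = math.prod(cfg.seanet.ratios)
print(hop, cfg.watermarker_model.sample_rate / hop)
# 320 50.0 -> one latent frame per 320 samples, i.e. ~50 Hz at 16 kHz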
+ 236 - 0
scripts/watermarking/watermarking.py

@@ -0,0 +1,236 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+# The original implementation for the watermarking
+# This is not open-sourced and only kept here for future reference
+# mypy: ignore-errors
+
+import math
+from argparse import ArgumentParser, ArgumentTypeError
+from pathlib import Path
+from typing import Any, Dict, Union, cast
+
+import audiocraft
+import omegaconf
+import torch
+import torch.nn as nn
+import torchaudio
+from audiocraft.modules.seanet import SEANetEncoder
+from audiocraft.utils.utils import dict_from_config
+from fairseq2.typing import DataType, Device
+
+
+class SEANetEncoderKeepDimension(SEANetEncoder):
+    """
+    Same architecture as the SEANet encoder, with an extra transposed-convolution
+    step that projects the output back to the input's frame count, so the output
+    keeps the time dimension of the input.
+
+    Args:
+        output_hidden_dim (int): Number of output channels of the reverse convolution.
+    """
+
+    def __init__(self, output_hidden_dim: int = 8, *args, **kwargs):  # type: ignore
+        super().__init__(*args, **kwargs)
+        self.output_hidden_dim = output_hidden_dim
+        # Adding a reverse convolution layer
+        self.reverse_convolution = nn.ConvTranspose1d(
+            in_channels=self.dimension,
+            out_channels=self.output_hidden_dim,
+            kernel_size=math.prod(self.ratios),
+            stride=math.prod(self.ratios),
+            padding=0,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        orig_nframes = x.shape[-1]
+        x = self.model(x)
+        x = self.reverse_convolution(x)
+        # make sure dim didn't change
+        x = x[:, :, :orig_nframes]
+        return x
+
+
+class Watermarker(nn.Module):
+    """
+    Initialize the Watermarker model.
+
+    Args:
+        encoder (nn.Module): Watermark Encoder.
+        decoder (nn.Module): Watermark Decoder.
+        detector (nn.Module): Watermark Detector.
+        sample_rate (int): Audio sample rate.
+        channels (int): Number of audio channels.
+    """
+
+    sample_rate: int = 0
+    channels: int = 0
+    encoder: SEANetEncoder
+    decoder: SEANetEncoder
+    detector: SEANetEncoderKeepDimension
+
+    def __init__(
+        self,
+        encoder: SEANetEncoder,
+        decoder: SEANetEncoder,
+        detector: SEANetEncoderKeepDimension,
+        sample_rate: int,
+        channels: int,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.decoder = decoder
+        self.detector = detector
+        self.sample_rate = sample_rate
+        self.channels = channels
+
+    def get_watermark(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Get the watermark from a batch of audio input.
+
+        Args:
+            x (torch.Tensor): Input audio tensor with dimensions [batch size, channels = 1, frames].
+
+        Returns:
+            torch.Tensor: Output watermark with the same dimensionality as the input.
+        """
+        hidden = self.encoder(x)
+        # assert dim in = dim out
+        watermark = self.decoder(hidden)[:, :, : x.size(-1)]
+        return watermark
+
+    def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Detect the watermark in a batch of audio input.
+
+        Args:
+            x (torch.Tensor): Input audio tensor with dimensions
+            [batch size, channels = 1, frames].
+
+        Returns:
+            torch.Tensor: Predictions of the classifier for watermark
+            with dimensions [bsz, classes = 2, frames].
+            For each frame, the detector outputs probabilities of
+            non-watermarked class (class id 0) and
+            the probability of "watermarked" class (class id 1).
+            To do inference, you can use output[:, 1, :]
+            to get probabilities of input audio being watermarked.
+        """
+        return self.detector(x)
+
+
+def model_from_checkpoint(
+    checkpoint_path: Union[Path, str] = Path(__file__).parent
+    / "seamlesswatermark.yaml",
+    device: Union[torch.device, str] = "cpu",
+    dtype: DataType = torch.float32,
+) -> Watermarker:
+    """Instantiate a Watermarker model from a given checkpoint path.
+
+    Example usage:
+    >>> from watermarking.watermarking import *
+    >>> cfg = "seamlesswatermark.yaml"
+    >>> url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
+    >>> urllib.request.urlretrieve(url, "random.wav")
+    >>> wav, _ = torchaudio.load("random.wav")
+    >>> wav = wav.unsqueeze(0)  # add bsz dimension
+
+    # code starts here
+    >>> model = model_from_checkpoint(cfg, device = wav.device)
+
+    >>> watermark = model.get_watermark(wav)
+
+    >>> watermarked_audio = wav + watermark
+    >>> detection = model.detect_watermark(watermarked_audio)
+    >>> print(detection[:,1,:])  # print prob of watermarked class # should be > 0.5
+
+    >>> detection = model.detect_watermark(wav)
+    >>> print(detection[:,1,:])  # print prob of watermarked class  # should be < 0.5
+
+    Args:
+        checkpoint_path (Path or str): Path to the checkpoint file.
+        device (torch.device or str, optional): Device on which
+        the model is loaded (default is "cpu").
+
+    Returns:
+        Watermarker: An instance of the Watermarker model loaded from the checkpoint.
+    """
+    cfg = omegaconf.OmegaConf.load(checkpoint_path)
+    state: Dict[str, Any] = torch.load(cfg["checkpoint"])
+    watermarking_model = get_watermarking_model(cfg)
+    watermarking_model.load_state_dict(state)
+    watermarking_model = watermarking_model.to(device, dtype=dtype)
+    watermarking_model.eval()
+    return watermarking_model
+
+
+def get_watermarking_model(cfg: omegaconf.DictConfig) -> Watermarker:
+    kwargs = dict_from_config(getattr(cfg, "watermarker_model"))
+    encoder, decoder = get_encodec_autoencoder(cfg)
+    detector = get_detector(cfg)
+    return Watermarker(encoder, decoder, detector, **kwargs)
+
+
+def get_encodec_autoencoder(cfg: omegaconf.DictConfig):
+    kwargs = dict_from_config(getattr(cfg, "seanet"))
+    if hasattr(cfg.seanet, "detector"):
+        kwargs.pop("detector")
+    encoder_override_kwargs = kwargs.pop("encoder")
+    decoder_override_kwargs = kwargs.pop("decoder")
+    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+    decoder_kwargs = {**kwargs, **decoder_override_kwargs}
+    encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
+    decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
+    return encoder, decoder
+
+
+def get_detector(cfg: omegaconf.DictConfig):
+    kwargs = dict_from_config(getattr(cfg, "seanet"))
+    encoder_override_kwargs = kwargs.pop("detector")
+    kwargs.pop("decoder")
+    kwargs.pop("encoder")
+    encoder_kwargs = {**kwargs, **encoder_override_kwargs}
+    output_hidden_dim = 8
+    encoder = SEANetEncoderKeepDimension(output_hidden_dim, **encoder_kwargs)
+
+    last_layer = torch.nn.Conv1d(output_hidden_dim, 2, 1)
+    softmax = torch.nn.Softmax(dim=1)
+    detector = torch.nn.Sequential(encoder, last_layer, softmax)
+    return detector
+
+
+def parse_device_arg(value: str) -> Device:
+    try:
+        return Device(value)
+    except RuntimeError:
+        raise ArgumentTypeError(f"'{value}' is not a valid device name.")
+
+
+if __name__ == "__main__":
+    """
+    Example usage:
+    python watermarking.py --device cuda:0 detect [file.wav]
+    """
+    parser = ArgumentParser(description="Handle the watermarking for audios")
+    parser.add_argument(
+        "--device",
+        default="cpu",
+        type=parse_device_arg,
+        help="device on which to run tests (default: %(default)s)",
+    )
+    sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
+    detect_parser = sub_parser.add_parser("detect")
+    wm_parser = sub_parser.add_parser("wm")
+    parser.add_argument("file", type=str, help="Path to the .wav file")
+
+    args = parser.parse_args()
+
+    if args.sub_cmd == "detect":
+        model = model_from_checkpoint(device=args.device)
+        wav, _ = torchaudio.load(args.file)
+        wav = wav.unsqueeze(0)
+        wav = wav.to(args.device)
+        detection = model.detect_watermark(wav)
+        print(detection[:, 1, :])

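As the detect_watermark docstring above notes, the detector returns per-frame probabilities of shape [bsz, 2, frames]. One simple way to turn that into a clip-level decision; the mean-over-frames aggregation and the 0.5 threshold are choices made here for illustration, not something fixed by this PR:

import torch


def is_watermarked(detection: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    # detection: [bsz, 2, frames] softmax output of Watermarker.detect_watermark
    frame_probs = detection[:, 1, :]  # probability of the "watermarked" class
    return frame_probs.mean(dim=-1) > threshold  # one boolean per clip


fake_detection = torch.softmax(torch.randn(2, 2, 100), dim=1)  # stand-in output
print(is_watermarked(fake_detection))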
+ 182 - 0
src/seamless_communication/cards/vocoder_pretssel.yaml

@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: vocoder_pretssel
+model_type: vocoder_pretssel
+model_arch: 24khz
+checkpoint: "file:///large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt"
+sample_rate: 24000
+model_config:
+  langs:
+    - cmn
+    - deu
+    - eng
+    - fra
+    - ita
+    - spa
+  gcmvn_stats:
+    mean:
+      - 9.023406257490224
+      - 9.406622923058864
+      - 10.554165334059368
+      - 11.475190058682356
+      - 12.179117104099705
+      - 12.603782921407062
+      - 12.769632747861747
+      - 12.714276772934083
+      - 12.747612172560233
+      - 12.750373688097946
+      - 12.948050207790237
+      - 13.121829398704277
+      - 13.40130828476734
+      - 13.58028050886195
+      - 13.601835409305883
+      - 13.608734047373218
+      - 13.538274892335826
+      - 13.391518457210937
+      - 13.382843811359622
+      - 13.0524299456858
+      - 12.785193828396269
+      - 12.876608812372632
+      - 12.59571918874957
+      - 12.674484745567813
+      - 12.57325195345546
+      - 12.651938120109422
+      - 12.556821722150424
+      - 12.639338348530158
+      - 12.610449431411217
+      - 12.639992872912376
+      - 12.697503827987052
+      - 12.754788270377214
+      - 12.837605043617405
+      - 12.964379088501497
+      - 13.11997048142582
+      - 13.267395589173432
+      - 13.384668687260483
+      - 13.495000208959356
+      - 13.606835320307384
+      - 13.578073476073252
+      - 13.689796531497368
+      - 13.643079802391588
+      - 13.7340755472615
+      - 13.735199777666043
+      - 13.79347692248429
+      - 13.875183654243305
+      - 13.967272256671393
+      - 14.058507936754117
+      - 14.114704594203507
+      - 14.156211337193277
+      - 14.14747081594401
+      - 14.173917097974343
+      - 14.22330474758318
+      - 14.251272943225572
+      - 14.230904505178053
+      - 14.226937644205396
+      - 14.222223350670225
+      - 14.211638354996317
+      - 14.208930098405544
+      - 14.19476983404041
+      - 14.2195925729048
+      - 14.16490878238837
+      - 14.115436751205117
+      - 14.039442767347872
+      - 13.976934063901625
+      - 13.917068116556464
+      - 13.856293662219073
+      - 13.773769842100085
+      - 13.706245521082796
+      - 13.685052933361192
+      - 13.68570131643094
+      - 13.714811890011152
+      - 13.751451253935347
+      - 13.772212258132148
+      - 13.76013448427468
+      - 13.702368406557508
+      - 13.600406368803617
+      - 13.369574889658164
+      - 12.998399608309988
+      - 12.443732902848723
+    std:
+      - 3.729248515707457
+      - 4.001623098079929
+      - 4.570009061358065
+      - 4.811572361201577
+      - 5.010239923828185
+      - 5.152145212706857
+      - 5.223885876119451
+      - 5.224443623432338
+      - 5.161790275239061
+      - 5.098988232815804
+      - 5.090890035509122
+      - 5.130345212529546
+      - 5.165849688173366
+      - 5.164761699263693
+      - 5.131177988219367
+      - 5.085522051815558
+      - 5.035829108165894
+      - 4.987478975310455
+      - 4.932652442855969
+      - 4.8650037198748075
+      - 4.799238163232527
+      - 4.727086345775988
+      - 4.646858066575789
+      - 4.5733249959652715
+      - 4.51685060334288
+      - 4.467449073425149
+      - 4.4296881304192075
+      - 4.4028775449713775
+      - 4.397905653025904
+      - 4.3862594566308015
+      - 4.366485847923521
+      - 4.344483498393771
+      - 4.324692736391383
+      - 4.310481738978154
+      - 4.3053492473916
+      - 4.3035205126659655
+      - 4.2987898577000605
+      - 4.287403454800855
+      - 4.27087296372773
+      - 4.25387490294079
+      - 4.233513102251301
+      - 4.212047255068752
+      - 4.1810370158214445
+      - 4.186014591107853
+      - 4.194806047136222
+      - 4.2183377208747075
+      - 4.249293562464735
+      - 4.268847210561774
+      - 4.270455756367186
+      - 4.25811368227528
+      - 4.245975115347766
+      - 4.23058010369271
+      - 4.203075111087773
+      - 4.20123812057283
+      - 4.187143614375688
+      - 4.172633823274146
+      - 4.162541203161947
+      - 4.156022884601996
+      - 4.1618428838805706
+      - 4.157259439238067
+      - 4.139859013016601
+      - 4.150685014911159
+      - 4.152025499126372
+      - 4.165010788120131
+      - 4.15179422331336
+      - 4.137041631098819
+      - 4.10861757770052
+      - 4.119916019361405
+      - 4.131749366642117
+      - 4.119438578634397
+      - 4.100095269698108
+      - 4.073900009963118
+      - 4.0580796715728855
+      - 4.050916705279105
+      - 4.037976834115189
+      - 4.023757063156459
+      - 3.9987849927993353
+      - 3.989251079820668
+      - 3.9464430977885256
+      - 3.8673932921278995

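The gcmvn_stats block in this card feeds the global CMVN applied to the prosody-encoder input (see normalize_fbank in pretssel_inference.py further down). A sketch of that use with a random fbank tensor, just to show the shapes; it assumes seamless_communication is installed so the card is registered in the asset store:

import torch

from seamless_communication.models.unity import load_gcmvn_stats

gcmvn_mean, gcmvn_std = load_gcmvn_stats("vocoder_pretssel")
mean, std = torch.tensor(gcmvn_mean), torch.tensor(gcmvn_std)

fbank = torch.randn(120, 80)            # [frames, mel_dim]; dummy features
gcmvn_fbank = fbank.sub(mean).div(std)  # global CMVN with the card's stats
print(gcmvn_fbank.shape)                # torch.Size([120, 80])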
+ 182 - 0
src/seamless_communication/cards/vocoder_pretssel_16khz.yaml

@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+name: vocoder_pretssel_16khz
+model_type: vocoder_pretssel
+model_arch: 16khz
+checkpoint: "file:///large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-16khz.pt"
+sample_rate: 16000
+model_config:
+  langs:
+    - cmn
+    - deu
+    - eng
+    - fra
+    - ita
+    - spa
+  gcmvn_stats:
+    mean:
+      - 9.023406257490224
+      - 9.406622923058864
+      - 10.554165334059368
+      - 11.475190058682356
+      - 12.179117104099705
+      - 12.603782921407062
+      - 12.769632747861747
+      - 12.714276772934083
+      - 12.747612172560233
+      - 12.750373688097946
+      - 12.948050207790237
+      - 13.121829398704277
+      - 13.40130828476734
+      - 13.58028050886195
+      - 13.601835409305883
+      - 13.608734047373218
+      - 13.538274892335826
+      - 13.391518457210937
+      - 13.382843811359622
+      - 13.0524299456858
+      - 12.785193828396269
+      - 12.876608812372632
+      - 12.59571918874957
+      - 12.674484745567813
+      - 12.57325195345546
+      - 12.651938120109422
+      - 12.556821722150424
+      - 12.639338348530158
+      - 12.610449431411217
+      - 12.639992872912376
+      - 12.697503827987052
+      - 12.754788270377214
+      - 12.837605043617405
+      - 12.964379088501497
+      - 13.11997048142582
+      - 13.267395589173432
+      - 13.384668687260483
+      - 13.495000208959356
+      - 13.606835320307384
+      - 13.578073476073252
+      - 13.689796531497368
+      - 13.643079802391588
+      - 13.7340755472615
+      - 13.735199777666043
+      - 13.79347692248429
+      - 13.875183654243305
+      - 13.967272256671393
+      - 14.058507936754117
+      - 14.114704594203507
+      - 14.156211337193277
+      - 14.14747081594401
+      - 14.173917097974343
+      - 14.22330474758318
+      - 14.251272943225572
+      - 14.230904505178053
+      - 14.226937644205396
+      - 14.222223350670225
+      - 14.211638354996317
+      - 14.208930098405544
+      - 14.19476983404041
+      - 14.2195925729048
+      - 14.16490878238837
+      - 14.115436751205117
+      - 14.039442767347872
+      - 13.976934063901625
+      - 13.917068116556464
+      - 13.856293662219073
+      - 13.773769842100085
+      - 13.706245521082796
+      - 13.685052933361192
+      - 13.68570131643094
+      - 13.714811890011152
+      - 13.751451253935347
+      - 13.772212258132148
+      - 13.76013448427468
+      - 13.702368406557508
+      - 13.600406368803617
+      - 13.369574889658164
+      - 12.998399608309988
+      - 12.443732902848723
+    std:
+      - 3.729248515707457
+      - 4.001623098079929
+      - 4.570009061358065
+      - 4.811572361201577
+      - 5.010239923828185
+      - 5.152145212706857
+      - 5.223885876119451
+      - 5.224443623432338
+      - 5.161790275239061
+      - 5.098988232815804
+      - 5.090890035509122
+      - 5.130345212529546
+      - 5.165849688173366
+      - 5.164761699263693
+      - 5.131177988219367
+      - 5.085522051815558
+      - 5.035829108165894
+      - 4.987478975310455
+      - 4.932652442855969
+      - 4.8650037198748075
+      - 4.799238163232527
+      - 4.727086345775988
+      - 4.646858066575789
+      - 4.5733249959652715
+      - 4.51685060334288
+      - 4.467449073425149
+      - 4.4296881304192075
+      - 4.4028775449713775
+      - 4.397905653025904
+      - 4.3862594566308015
+      - 4.366485847923521
+      - 4.344483498393771
+      - 4.324692736391383
+      - 4.310481738978154
+      - 4.3053492473916
+      - 4.3035205126659655
+      - 4.2987898577000605
+      - 4.287403454800855
+      - 4.27087296372773
+      - 4.25387490294079
+      - 4.233513102251301
+      - 4.212047255068752
+      - 4.1810370158214445
+      - 4.186014591107853
+      - 4.194806047136222
+      - 4.2183377208747075
+      - 4.249293562464735
+      - 4.268847210561774
+      - 4.270455756367186
+      - 4.25811368227528
+      - 4.245975115347766
+      - 4.23058010369271
+      - 4.203075111087773
+      - 4.20123812057283
+      - 4.187143614375688
+      - 4.172633823274146
+      - 4.162541203161947
+      - 4.156022884601996
+      - 4.1618428838805706
+      - 4.157259439238067
+      - 4.139859013016601
+      - 4.150685014911159
+      - 4.152025499126372
+      - 4.165010788120131
+      - 4.15179422331336
+      - 4.137041631098819
+      - 4.10861757770052
+      - 4.119916019361405
+      - 4.131749366642117
+      - 4.119438578634397
+      - 4.100095269698108
+      - 4.073900009963118
+      - 4.0580796715728855
+      - 4.050916705279105
+      - 4.037976834115189
+      - 4.023757063156459
+      - 3.9987849927993353
+      - 3.989251079820668
+      - 3.9464430977885256
+      - 3.8673932921278995

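The two cards differ only in model_arch, sample_rate and checkpoint. The arch name selects the upsample rates in generator/builder.py further down, and the output rate is just the mel frame rate times the total upsampling factor; a quick check of that arithmetic, with values copied from this PR's configs:

import math

# upsample_rates from builder.py below; sample rates from the two cards above
archs = {"16khz": ([5, 4, 4, 2], 16_000), "24khz": ([5, 4, 4, 3], 24_000)}

for name, (upsample_rates, sample_rate) in archs.items():
    total_upsampling = math.prod(upsample_rates)
    print(name, total_upsampling, sample_rate / total_upsampling)
# 16khz 160 100.0
# 24khz 240 100.0 -> both run the mel decoder at 100 frames per second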
+ 1 - 1
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -385,7 +385,7 @@ def main() -> None:
         text_generation_opts=text_generation_opts,
         unit_generation_opts=unit_generation_opts,
         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-        output_path=Path(args.output_path),
+        output_path=args.output_path,
         gcmvn_mean=torch.tensor(gcmvn_mean, device=device, dtype=dtype),
         gcmvn_std=torch.tensor(gcmvn_std, device=device, dtype=dtype),
         pretssel_model=args.pretssel_model,

+ 373 - 0
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -0,0 +1,373 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import logging
+from argparse import Namespace
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch.nn import Module
+import torchaudio
+from fairseq2.assets import asset_store
+from fairseq2.assets.card import AssetCard
+from fairseq2.data import Collater, DataPipeline, FileMapper, SequenceData
+from fairseq2.data.audio import (
+    AudioDecoder,
+    WaveformToFbankConverter,
+    WaveformToFbankOutput,
+)
+from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
+from fairseq2.generation import SequenceGeneratorOptions
+from fairseq2.typing import DataType, Device
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
+from torch import Tensor
+from tqdm import tqdm
+
+from seamless_communication.models.unity import UnitTokenizer
+from seamless_communication.cli.m4t.evaluate.evaluate import (
+    adjust_output_for_corrupted_inputs,
+    count_lines,
+)
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import BatchedSpeechOutput, Translator
+from seamless_communication.models.unity import (
+    load_gcmvn_stats,
+    load_unity_text_tokenizer,
+    load_unity_unit_tokenizer,
+)
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+class PretsselGenerator(Module):
+    def __init__(
+        self,
+        pretssel_name_or_card: Union[str, AssetCard],
+        unit_tokenizer: UnitTokenizer,
+        device: Device,
+        dtype: DataType = torch.float16,
+    ):
+        super().__init__()
+        # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
+
+        self.device = device
+        self.dtype = dtype
+
+        self.pretssel_model = load_pretssel_vocoder_model(
+            pretssel_name_or_card,
+            device=device,
+            dtype=dtype,
+        )
+        self.pretssel_model.eval()
+
+        vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
+        self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+
+        self.unit_tokenizer = unit_tokenizer
+        self.unit_collate = Collater(pad_value=unit_tokenizer.vocab_info.pad_idx)
+        self.duration_collate = Collater(pad_value=0)
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        units: List[List[int]],
+        tgt_lang: str,
+        prosody_encoder_input: SequenceData,
+    ) -> BatchedSpeechOutput:
+        audio_wavs = []
+        unit_eos_token = torch.tensor(
+            [self.unit_tokenizer.vocab_info.eos_idx],
+            device=self.device,
+        )
+
+        prosody_input_seqs = prosody_encoder_input["seqs"]
+        prosody_input_lens = prosody_encoder_input["seq_lens"]
+
+        for i, u in enumerate(units):
+            unit = torch.tensor(u).to(unit_eos_token)
+
+            # adjust the control symbols for the embedding
+            unit += 4
+            unit = torch.cat([unit, unit_eos_token], dim=0)
+
+            unit, duration = torch.unique_consecutive(unit, return_counts=True)
+
+            # adjust for the last eos token
+            duration[-1] = 0
+
+            duration *= 2
+
+            prosody_input_seq = prosody_input_seqs[i][:prosody_input_lens[i]]
+
+            audio_wav = self.pretssel_model(
+                unit,
+                tgt_lang,
+                prosody_input_seq,
+                durations=duration.unsqueeze(0),
+            )
+
+            audio_wavs.append(audio_wav)
+
+        return BatchedSpeechOutput(
+            units=units,
+            audio_wavs=audio_wavs,
+            sample_rate=self.output_sample_rate,
+        )
+
+
+def build_data_pipeline(
+    args: Namespace,
+    text_tokenizer: TextTokenizer,
+    device: Device,
+    dtype: DataType,
+    gcmvn_mean: Tensor,
+    gcmvn_std: Tensor,
+) -> DataPipeline:
+    with open(args.data_file, "r") as f:
+        header = f.readline().strip("\n").split("\t")
+
+    n_parallel = 4
+
+    split_tsv = StrSplitter(names=header)
+
+    pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)
+
+    assert args.audio_root_dir is not None
+
+    map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)
+
+    pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+
+    decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
+    convert_to_fbank = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=False,
+        device=device,
+        dtype=dtype,
+    )
+
+    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
+        fbank = data["fbank"]
+        std, mean = torch.std_mean(fbank, dim=0)
+        data["fbank"] = fbank.subtract(mean).divide(std)
+        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+        return data
+
+    pipeline_builder.map(
+        [decode_audio, convert_to_fbank, normalize_fbank],
+        selector="audio.data",
+        num_parallel_calls=n_parallel,
+    )
+
+    pipeline_builder.bucket(bucket_size=args.batch_size)
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Running PretsselModel inference")
+    parser.add_argument("data_file", type=Path, help="Data file (.tsv) to be evaluated.")
+
+    parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=Path,
+        help="Root directory for the audio filenames in the data file.",
+        default="",
+    )
+    parser.add_argument(
+        "--ref_field",
+        type=str,
+        help="Reference target text field to compute the BLEU score against.",
+        default="tgt_text",
+    )
+    parser.add_argument(
+        "--duration_factor",
+        type=float,
+        help="The duration factor for NAR T2U model. Expressivity model uses 1.1",
+        default=1.1,
+    )
+    args = parser.parse_args()
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
+
+    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
+    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
+    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
+
+    pipeline = build_data_pipeline(
+        args, text_tokenizer, device, dtype, gcmvn_mean, gcmvn_std
+    )
+
+    translator = Translator(
+        args.model_name,
+        vocoder_name_or_card=None,
+        device=device,
+        text_tokenizer=text_tokenizer,
+        dtype=dtype,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    pretssel_generator = PretsselGenerator(
+        args.vocoder_name,
+        unit_tokenizer=unit_tokenizer,
+        device=device,
+        dtype=dtype,
+    )
+
+    total_steps = count_lines(args.data_file) - 1
+    progress_bar = tqdm(total=total_steps)
+
+    output_path = args.output_path / args.data_file.stem
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    waveforms_dir = output_path / f"waveform"
+    waveforms_dir.mkdir(parents=True, exist_ok=True)
+
+    hyps = []
+    refs = []
+
+    with contextlib.ExitStack() as stack:
+        hyp_file = stack.enter_context(
+            open(output_path / f"text_output-{args.data_file.stem}.txt", "w")
+        )
+        unit_file = stack.enter_context(
+            open(output_path / f"unit_output-{args.data_file.stem}.txt", "w")
+        )
+
+        sample_id = 0
+        for example in pipeline:
+            valid_sequences: Optional[Tensor] = None
+            src = example["audio"]["data"]["fbank"]
+            # Skip corrupted audio tensors.
+            valid_sequences = ~torch.any(
+                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
+            )
+            if not valid_sequences.all():
+                logger.warning(
+                    f"Sample IDs {sample_id} to {sample_id + args.batch_size} has some corrupted input."
+                )
+                src["seqs"] = src["seqs"][valid_sequences]
+                src["seq_lens"] = src["seq_lens"][valid_sequences]
+
+            # Skip performing inference when the input is entirely corrupted.
+            if src["seqs"].numel() > 0:
+                prosody_encoder_input = example["audio"]["data"]["gcmvn_fbank"]
+                text_output, unit_output = translator.predict(
+                    src,
+                    args.task,
+                    args.tgt_lang,
+                    src_lang=args.src_lang,
+                    text_generation_opts=text_generation_opts,
+                    unit_generation_opts=unit_generation_opts,
+                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+                    duration_factor=args.duration_factor,
+                    prosody_encoder_input=prosody_encoder_input,
+                )
+
+                assert unit_output is not None
+                speech_output = pretssel_generator.predict(
+                    unit_output.units,
+                    tgt_lang=args.tgt_lang,
+                    prosody_encoder_input=prosody_encoder_input,
+                )
+
+            else:
+                text_output = []
+                speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+
+            if valid_sequences is not None and not valid_sequences.all():
+                text_output, speech_output = adjust_output_for_corrupted_inputs(  # type: ignore[assignment]
+                    valid_sequences,
+                    text_output,
+                    speech_output,
+                )
+
+            hyps += [str(s) for s in text_output]
+            refs += [str(s) for s in example[args.ref_field]]
+
+            for i in range(len(text_output)):
+                t = text_output[i]
+                idx = str(example["id"][i])
+                hyp_file.write(f"{t}\n")
+
+                u = speech_output.units[i]
+                str_units = [str(i) for i in u]
+                unit_file.write(" ".join(str_units) + "\n")
+                torchaudio.save(
+                    waveforms_dir / f"{idx}_pred.wav",
+                    speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                    sample_rate=speech_output.sample_rate,
+                )
+
+                sample_id += 1
+                progress_bar.update(1)
+
+    progress_bar.close()
+    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
+
+    assert len(hyps) == len(refs)
+    if len(hyps) > 0:
+        if args.tgt_lang in ("cmn", "jpn", "lao", "mya", "tha"):
+            tokenizer = "char"
+        else:
+            tokenizer = "13a"
+
+        bleu = BLEU(tokenize=tokenizer)
+        score = bleu.corpus_score(hyps, [refs])
+        bleu_filename = output_path / f"{args.data_file.stem}_text_output_bleu.json"
+        with open(bleu_filename, "w") as f:
+            f.write(score.format(signature=str(bleu.get_signature()), is_json=True))
+        logger.info(score.format(signature=bleu.get_signature()))
+
+
+if __name__ == "__main__":
+    main()

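The unit handling inside PretsselGenerator.predict is easy to miss in the loop above; isolated as a sketch (the +4 control-symbol offset, the appended EOS, and the doubled durations are taken from the method; the unit values are dummies, and the eos index assumes eos_idx = 2 as in the builder's VocabularyInfo):

import torch

eos_idx = 2                    # assumed, per VocabularyInfo(eos_idx=2) in builder.py
units = [7, 7, 7, 12, 12, 5]   # dummy units standing in for a T2U output

unit = torch.tensor(units) + 4                     # shift past the control symbols
unit = torch.cat([unit, torch.tensor([eos_idx])])  # append the EOS token
unit, duration = torch.unique_consecutive(unit, return_counts=True)
duration[-1] = 0                                   # the trailing EOS gets no frames
duration *= 2                                      # doubled, as in predict()
print(unit.tolist(), duration.tolist())
# [11, 16, 9, 2] [6, 4, 2, 0]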
+ 1 - 1
src/seamless_communication/cli/m4t/evaluate/evaluate.py

@@ -424,7 +424,7 @@ def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
         text_generation_opts=text_generation_opts,
         unit_generation_opts=unit_generation_opts,
         unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
-        output_path=Path(args.output_path),
+        output_path=args.output_path,
     )
     # fmt: on
     logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")

+ 3 - 2
src/seamless_communication/cli/m4t/predict/predict.py

@@ -6,6 +6,7 @@
 import argparse
 import logging
 from argparse import Namespace
+from pathlib import Path
 from typing import Tuple

 import torch
@@ -35,7 +36,7 @@ def add_inference_arguments(parser: argparse.ArgumentParser) -> argparse.Argumen
     )
     parser.add_argument(
         "--output_path",
-        type=str,
+        type=Path,
         help="Path to save the generated audio.",
         default=None,
     )
@@ -167,7 +168,7 @@ def set_generation_opts(
     return text_generation_opts, unit_generation_opts


-def main():
+def main() -> None:
     parser = argparse.ArgumentParser(
         description="M4T inference on supported tasks using Translator."
     )

+ 5 - 0
src/seamless_communication/models/generator/__init__.py

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.

+ 506 - 0
src/seamless_communication/models/generator/builder.py

@@ -0,0 +1,506 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+from fairseq2.data import VocabularyInfo
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.nn.embedding import StandardEmbedding, init_scaled_embedding
+from fairseq2.nn.position_encoder import SinusoidalPositionEncoder
+from fairseq2.nn.projection import Linear
+from fairseq2.nn.transformer import (
+    MultiheadAttention,
+    StandardMultiheadAttention,
+    TransformerNormOrder,
+    create_default_sdpa,
+)
+from fairseq2.typing import DataType, Device
+from torch.nn import Conv1d
+
+from seamless_communication.models.generator.ecapa_tdnn_builder import (
+    EcapaTDNNBuilder,
+    EcapaTDNNConfig,
+    ecapa_tdnn_archs,
+)
+from seamless_communication.models.generator.vocoder import (
+    PretsselDecoderFrontend,
+    PretsselEncoderFrontend,
+    PretsselVocoder,
+)
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+from seamless_communication.models.unity.fft_decoder_layer import (
+    Conv1dBlock,
+    FeedForwardTransformerLayer,
+)
+from seamless_communication.models.unity.length_regulator import (
+    VarianceAdaptor,
+    VariancePredictor,
+)
+from seamless_communication.models.unity.t2u_builder import VariancePredictorConfig
+
+
+@dataclass
+class PretsselEncoderFrontendConfig:
+    prosody_encoder_config: EcapaTDNNConfig
+    dropout: float
+    lang_embed_dim: Optional[int] = None
+
+
+@dataclass
+class FFTLayerConfig:
+    attention_heads: int
+    hidden_dim: int
+    kernel_size: int
+    dropout: float
+    conv1d_dropout: float
+    film_cond_dim: int
+    use_film: bool = False
+
+
+@dataclass
+class PretsselDecoderFrontendConfig:
+    upsampling_type: Literal["gaussian", "hard"]
+    variance_predictor_config: VariancePredictorConfig
+    add_variance_parallel: bool
+
+
+@dataclass
+class VocoderConfig:
+    """Holds the configuration of a Vocoder model."""
+
+    encoder_frontend_config: PretsselEncoderFrontendConfig
+    fft_layer_config: FFTLayerConfig
+    decoder_frontend_config: PretsselDecoderFrontendConfig
+    pn_conv_dim: int
+    pn_layers: int
+    pn_conv_kernel_size: int
+    pn_dropout: float
+    vocab_info: VocabularyInfo
+    model_dim: int
+    max_seq_len: int
+    encoder_layers: int
+    decoder_layers: int
+    mel_dim: int
+    langs: List  # type: ignore[type-arg]
+    upsample_rates: List[int]
+    upsample_kernel_sizes: List[int]
+    upsample_initial_channel: int
+    resblock_kernel_sizes: List[int]
+    resblock_dilation_sizes: List[List[int]]
+    channels: int
+    dimension: int
+    n_filters: int
+    ratios: List[int]
+    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"]
+    norm_params: Dict[str, Any]
+    kernel_size: int
+    last_kernel_size: int
+    residual_kernel_size: int
+    causal: bool
+    pad_mode: str
+    true_skip: bool
+    compress: int
+    lstm: int
+    disable_norm_outer_blocks: int
+    trim_right_ratio: float
+    gcmvn_stats: Dict[str, List]  # type: ignore[type-arg]
+
+
+vocoder_archs = ArchitectureRegistry[VocoderConfig]("vocoder_pretssel")
+
+
+vocoder_arch = vocoder_archs.decorator
+
+
+def pretssel_config() -> (
+    Tuple[PretsselEncoderFrontendConfig, FFTLayerConfig, PretsselDecoderFrontendConfig]
+):
+    prosody_encoder_config = ecapa_tdnn_archs.get_config("base")
+
+    encoder_frontend_config = PretsselEncoderFrontendConfig(
+        prosody_encoder_config=prosody_encoder_config,
+        dropout=0.2,
+        lang_embed_dim=64,
+    )
+
+    fft_layer_config = FFTLayerConfig(
+        attention_heads=2,
+        hidden_dim=1024,
+        kernel_size=9,
+        dropout=0.0,
+        conv1d_dropout=0.2,
+        use_film=True,
+        film_cond_dim=576,
+    )
+
+    variance_predictor_config = VariancePredictorConfig(
+        var_pred_hidden_dim=512,
+        var_pred_kernel_size=5,
+        var_pred_dropout=0.5,
+        use_film=True,
+        film_cond_dim=576,
+    )
+
+    decoder_frontend_config = PretsselDecoderFrontendConfig(
+        upsampling_type="gaussian",
+        variance_predictor_config=variance_predictor_config,
+        add_variance_parallel=True,
+    )
+    return (
+        encoder_frontend_config,
+        fft_layer_config,
+        decoder_frontend_config,
+    )
+
+
+@vocoder_arch("16khz")
+def _16khz_vocoder() -> VocoderConfig:
+    (
+        encoder_frontend_config,
+        fft_layer_config,
+        decoder_frontend_config,
+    ) = pretssel_config()
+
+    return VocoderConfig(
+        encoder_frontend_config=encoder_frontend_config,
+        fft_layer_config=fft_layer_config,
+        decoder_frontend_config=decoder_frontend_config,
+        pn_conv_dim=512,
+        pn_layers=5,
+        pn_conv_kernel_size=5,
+        pn_dropout=0.5,
+        vocab_info=VocabularyInfo(
+            size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        model_dim=256,
+        max_seq_len=4000,
+        encoder_layers=4,
+        decoder_layers=4,
+        mel_dim=80,
+        langs=[],
+        upsample_rates=[5, 4, 4, 2],
+        upsample_kernel_sizes=[10, 8, 8, 4],
+        upsample_initial_channel=512,
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        channels=1,
+        dimension=128,
+        n_filters=32,
+        ratios=[8, 5, 4, 2],
+        norm="weight_norm",
+        norm_params={},
+        kernel_size=7,
+        last_kernel_size=7,
+        residual_kernel_size=3,
+        causal=False,
+        pad_mode="constant",
+        true_skip=True,
+        compress=2,
+        lstm=2,
+        disable_norm_outer_blocks=0,
+        trim_right_ratio=1.0,
+        gcmvn_stats={},
+    )
+
+
+@vocoder_arch("24khz")
+def _24khz_vocoder() -> VocoderConfig:
+    (
+        encoder_frontend_config,
+        fft_layer_config,
+        decoder_frontend_config,
+    ) = pretssel_config()
+
+    return VocoderConfig(
+        encoder_frontend_config=encoder_frontend_config,
+        fft_layer_config=fft_layer_config,
+        decoder_frontend_config=decoder_frontend_config,
+        pn_conv_dim=512,
+        pn_layers=5,
+        pn_conv_kernel_size=5,
+        pn_dropout=0.5,
+        vocab_info=VocabularyInfo(
+            size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1
+        ),
+        model_dim=256,
+        max_seq_len=4000,
+        encoder_layers=4,
+        decoder_layers=4,
+        mel_dim=80,
+        langs=[],
+        upsample_rates=[5, 4, 4, 3],
+        upsample_kernel_sizes=[10, 8, 8, 6],
+        upsample_initial_channel=512,
+        resblock_kernel_sizes=[3, 7, 11],
+        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+        channels=1,
+        dimension=128,
+        n_filters=32,
+        ratios=[8, 5, 4, 2],
+        norm="weight_norm",
+        norm_params={},
+        kernel_size=7,
+        last_kernel_size=7,
+        residual_kernel_size=3,
+        causal=False,
+        pad_mode="constant",
+        true_skip=True,
+        compress=2,
+        lstm=2,
+        disable_norm_outer_blocks=0,
+        trim_right_ratio=1.0,
+        gcmvn_stats={},
+    )
+
+
+class PretsselVocoderBuilder:
+    config: VocoderConfig
+    prosody_encoder_builder: EcapaTDNNBuilder
+    device: Optional[Device] = None
+    dtype: Optional[DataType] = None
+
+    def __init__(
+        self,
+        config: VocoderConfig,
+        prosody_encoder_builder: EcapaTDNNBuilder,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+        self.prosody_encoder_builder = prosody_encoder_builder
+        self.device, self.dtype = device, dtype
+
+    def build_embed_tokens(self) -> StandardEmbedding:
+        """Build a unit embedding table."""
+
+        return StandardEmbedding(
+            num_embeddings=self.config.vocab_info.size,
+            embedding_dim=self.config.model_dim,
+            init_fn=init_scaled_embedding,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_fft(self, num_layers: int) -> FeedForwardTransformer:
+        """Build a Transformer encoder."""
+
+        layers = [self.build_fft_layer() for _ in range(num_layers)]
+
+        return FeedForwardTransformer(
+            layers,
+            norm_order=TransformerNormOrder.POST,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_fft_layer(self) -> FeedForwardTransformerLayer:
+        """Build a Transformer decoder layer."""
+
+        self_attn = self.build_attention(self.config.fft_layer_config.attention_heads)
+
+        conv1d = Conv1dBlock(
+            self.config.model_dim,
+            self.config.fft_layer_config.hidden_dim,
+            self.config.fft_layer_config.kernel_size,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        return FeedForwardTransformerLayer(
+            self_attn,
+            conv1d,
+            dropout_p=0.0,  # fairseq1 doesn't have this
+            conv1d_dropout_p=self.config.fft_layer_config.conv1d_dropout,
+            use_film=self.config.fft_layer_config.use_film,
+            film_cond_dim=self.config.fft_layer_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_attention(self, num_heads: int) -> MultiheadAttention:
+        """Build a Transformer multi-head attention layer."""
+
+        sdpa = create_default_sdpa(attn_dropout_p=self.config.fft_layer_config.dropout)
+
+        return StandardMultiheadAttention(
+            self.config.model_dim,
+            num_heads,
+            sdpa=sdpa,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+    def build_variance_adaptor(
+        self,
+        decoder_frontend_config: PretsselDecoderFrontendConfig,
+    ) -> VarianceAdaptor:
+        """Build a variance adaptor module."""
+
+        variance_predictor_config = decoder_frontend_config.variance_predictor_config
+
+        pitch_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_pitch = Conv1d(1, self.config.model_dim, kernel_size=1)
+
+        vuv_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        energy_predictor = VariancePredictor(
+            self.config.model_dim,
+            variance_predictor_config.var_pred_hidden_dim,
+            variance_predictor_config.var_pred_kernel_size,
+            variance_predictor_config.var_pred_dropout,
+            use_film=variance_predictor_config.use_film,
+            film_cond_dim=variance_predictor_config.film_cond_dim,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        embed_energy = Conv1d(1, self.config.model_dim, kernel_size=1)
+
+        variance_adaptor = VarianceAdaptor(
+            duration_predictor=None,
+            pitch_predictor=pitch_predictor,
+            embed_pitch=embed_pitch,
+            vuv_predictor=vuv_predictor,
+            energy_predictor=energy_predictor,
+            embed_energy=embed_energy,
+            add_variance_parallel=decoder_frontend_config.add_variance_parallel,
+            upsampling_type=decoder_frontend_config.upsampling_type,
+        )
+
+        return variance_adaptor
+
+    def build_model(self) -> PretsselVocoder:
+        """build the pretssel vocoder."""
+        prosody_encoder = self.prosody_encoder_builder.build_model()
+        embed_tokens = self.build_embed_tokens()
+
+        embed_positions = SinusoidalPositionEncoder(
+            self.config.model_dim,
+            self.config.max_seq_len,
+            _legacy_pad_idx=self.config.vocab_info.pad_idx,
+            device=self.device,
+        )
+        lang_to_index = {l: i for i, l in enumerate(self.config.langs)}
+        encoder_frontend = PretsselEncoderFrontend(
+            prosody_encoder,
+            embed_tokens,
+            embed_positions,
+            lang_to_index,
+            lang_embed_dim=self.config.encoder_frontend_config.lang_embed_dim,
+            dropout_p=self.config.encoder_frontend_config.dropout,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        encoder = self.build_fft(self.config.encoder_layers)
+
+        variance_adaptor = self.build_variance_adaptor(
+            self.config.decoder_frontend_config
+        )
+
+        decoder_frontend = PretsselDecoderFrontend(
+            variance_adaptor,
+            embed_positions,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        decoder = self.build_fft(self.config.decoder_layers)
+
+        final_proj = Linear(
+            self.config.model_dim,
+            self.config.mel_dim,
+            bias=True,
+            device=self.device,
+            dtype=self.dtype,
+        )
+
+        gcmvn_mean = gcmvn_std = None
+        if self.config.gcmvn_stats is not None:
+            gcmvn_mean = self.config.gcmvn_stats["mean"]
+            gcmvn_std = self.config.gcmvn_stats["std"]
+
+        vocoder = PretsselVocoder(
+            encoder_frontend=encoder_frontend,
+            encoder=encoder,
+            decoder_frontend=decoder_frontend,
+            decoder=decoder,
+            final_proj=final_proj,
+            pn_n_channels=self.config.pn_conv_dim,
+            pn_kernel_size=self.config.pn_conv_kernel_size,
+            pn_layers=self.config.pn_layers,
+            pn_dropout=self.config.pn_dropout,
+            upsample_rates=self.config.upsample_rates,
+            upsample_kernel_sizes=self.config.upsample_kernel_sizes,
+            upsample_initial_channel=self.config.upsample_initial_channel,
+            resblock_kernel_sizes=self.config.resblock_kernel_sizes,
+            resblock_dilation_sizes=self.config.resblock_dilation_sizes,
+            channels=self.config.channels,
+            dimension=self.config.dimension,
+            n_filters=self.config.n_filters,
+            ratios=self.config.ratios,
+            norm=self.config.norm,
+            norm_params=self.config.norm_params,
+            kernel_size=self.config.kernel_size,
+            last_kernel_size=self.config.last_kernel_size,
+            residual_kernel_size=self.config.residual_kernel_size,
+            causal=self.config.causal,
+            pad_mode=self.config.pad_mode,
+            true_skip=self.config.true_skip,
+            compress=self.config.compress,
+            lstm=self.config.lstm,
+            disable_norm_outer_blocks=self.config.disable_norm_outer_blocks,
+            trim_right_ratio=self.config.trim_right_ratio,
+            gcmvn_mean=gcmvn_mean,
+            gcmvn_std=gcmvn_std,
+        )
+        vocoder.to(dtype=self.dtype, device=self.device)
+        return vocoder
+
+
+def create_vocoder_model(
+    config: VocoderConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> PretsselVocoder:
+    prosody_encoder_builder = EcapaTDNNBuilder(
+        config.encoder_frontend_config.prosody_encoder_config,
+        device=device,
+        dtype=dtype,
+    )
+    return PretsselVocoderBuilder(
+        config, prosody_encoder_builder, device=device, dtype=dtype
+    ).build_model()
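A note on the two registered architectures above: they differ only in the last upsampling stage, so the product of `upsample_rates` is 5*4*4*2 = 160 for "16khz" and 5*4*4*3 = 240 for "24khz", i.e. each 80-dim mel frame is expanded into 160 or 240 waveform samples. The empty `langs` and `gcmvn_stats` fields are presumably filled in from the asset card at load time. A minimal sketch of the arithmetic (not part of this diff):

    import math

    # Values copied from the two VocoderConfig definitions above.
    upsample_rates = {"16khz": [5, 4, 4, 2], "24khz": [5, 4, 4, 3]}

    for name, rates in upsample_rates.items():
        factor = math.prod(rates)  # 160 for "16khz", 240 for "24khz"
        print(f"{name}: one mel frame -> {factor} output samples")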

+ 474 - 0
src/seamless_communication/models/generator/ecapa_tdnn.py

@@ -0,0 +1,474 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq2.nn.padding import PaddingMask, to_padding_mask
+from torch import Tensor
+from torch.nn import Conv1d, LayerNorm, Module, ModuleList, ReLU, Sigmoid, Tanh, init
+
+
+class ECAPA_TDNN(Module):
+    """
+    Represents the ECAPA-TDNN model described in paper:
+    :cite:t`https://doi.org/10.48550/arxiv.2005.07143`.
+
+    Arguments
+    ---------
+    :param channels:
+        Output channels for TDNN/SERes2Net layer.
+    :param kernel_sizes:
+        List of kernel sizes for each layer.
+    :param dilations:
+        List of dilations for kernels in each layer.
+    :param groups:
+        List of groups for kernels in each layer.
+    """
+
+    def __init__(
+        self,
+        channels: List[int],
+        kernel_sizes: List[int],
+        dilations: List[int],
+        attention_channels: int,
+        res2net_scale: int,
+        se_channels: int,
+        global_context: bool,
+        groups: List[int],
+        embed_dim: int,
+        input_dim: int,
+    ):
+        super().__init__()
+        assert len(channels) == len(kernel_sizes) == len(dilations)
+        self.channels = channels
+        self.embed_dim = embed_dim
+        self.blocks = ModuleList()
+
+        self.blocks.append(
+            TDNNBlock(
+                input_dim,
+                channels[0],
+                kernel_sizes[0],
+                dilations[0],
+                groups[0],
+            )
+        )
+
+        # SE-Res2Net layers
+        for i in range(1, len(channels) - 1):
+            self.blocks.append(
+                SERes2NetBlock(
+                    channels[i - 1],
+                    channels[i],
+                    res2net_scale=res2net_scale,
+                    se_channels=se_channels,
+                    kernel_size=kernel_sizes[i],
+                    dilation=dilations[i],
+                    groups=groups[i],
+                )
+            )
+
+        # Multi-layer feature aggregation
+        self.mfa = TDNNBlock(
+            channels[-1],
+            channels[-1],
+            kernel_sizes[-1],
+            dilations[-1],
+            groups=groups[-1],
+        )
+
+        # Attentive Statistical Pooling
+        self.asp = AttentiveStatisticsPooling(
+            channels[-1],
+            attention_channels=attention_channels,
+            global_context=global_context,
+        )
+        self.asp_norm = LayerNorm(channels[-1] * 2, eps=1e-12)
+
+        # Final linear transformation
+        self.fc = Conv1d(
+            in_channels=channels[-1] * 2,
+            out_channels=embed_dim,
+            kernel_size=1,
+        )
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        """Reset the parameters and buffers of the module."""
+
+        def encoder_init(m: Module) -> None:
+            if isinstance(m, Conv1d):
+                init.xavier_uniform_(m.weight, init.calculate_gain("relu"))
+
+        self.apply(encoder_init)
+
+    def forward(
+        self,
+        x: Tensor,
+        padding_mask: Optional[PaddingMask] = None,
+    ) -> Tensor:
+        """Returns the embedding vector.
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape (batch, time, channel).
+        """
+        # Minimize transpose for efficiency
+        x = x.transpose(1, 2)
+
+        xl = []
+        for layer in self.blocks:
+            x = layer(x, padding_mask=padding_mask)
+            xl.append(x)
+
+        # Multi-layer feature aggregation
+        x = torch.cat(xl[1:], dim=1)
+        x = self.mfa(x)
+
+        # Attentive Statistical Pooling
+        x = self.asp(x, padding_mask=padding_mask)
+        x = self.asp_norm(x.transpose(1, 2)).transpose(1, 2)
+
+        # Final linear transformation
+        x = self.fc(x)
+
+        x = x.transpose(1, 2).squeeze(1)  # B x C
+        return F.normalize(x, dim=-1)
+
+
+class TDNNBlock(Module):
+    """An implementation of TDNN.
+
+    Arguments
+    ----------
+    :param in_channels: int
+        The number of input channels.
+    :param out_channels: int
+        The number of output channels.
+    :param kernel_size: int
+        The kernel size of the TDNN block.
+    :param dilation: int
+        The dilation of the TDNN block.
+    :param groups: int
+        The group size of the TDNN block.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = TDNNBlock(64, 64, kernel_size=3, dilation=1)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        dilation: int,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.conv = Conv1d(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            dilation=dilation,
+            padding=dilation * (kernel_size - 1) // 2,
+            groups=groups,
+        )
+        self.activation = ReLU()
+        self.norm = LayerNorm(out_channels, eps=1e-12)
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        x = self.activation(self.conv(x))
+
+        return self.norm(x.transpose(1, 2)).transpose(1, 2)  # type: ignore[no-any-return]
+
+
+class Res2NetBlock(Module):
+    """An implementation of Res2NetBlock w/ dilation.
+
+    Arguments
+    ---------
+    :param in_channels: int
+        The number of channels expected in the input.
+    :param out_channels: int
+        The number of output channels.
+    :param scale: int
+        The scale of the Res2Net block.
+    :param kernel_size: int
+        The kernel size of the Res2Net block.
+    :param dilation: int
+        The dilation of the Res2Net block.
+
+    Example
+    -------
+    >>> inp_tensor = torch.rand([8, 120, 64]).transpose(1, 2)
+    >>> layer = Res2NetBlock(64, 64, scale=4, dilation=3)
+    >>> out_tensor = layer(inp_tensor).transpose(1, 2)
+    >>> out_tensor.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        scale: int = 8,
+        kernel_size: int = 3,
+        dilation: int = 1,
+    ):
+        super().__init__()
+        assert in_channels % scale == 0
+        assert out_channels % scale == 0
+
+        in_channel = in_channels // scale
+        hidden_channel = out_channels // scale
+        self.blocks = ModuleList(
+            [
+                TDNNBlock(
+                    in_channel,
+                    hidden_channel,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                )
+                for i in range(scale - 1)
+            ]
+        )
+        self.scale = scale
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        y = []
+        for i, x_i in enumerate(torch.chunk(x, self.scale, dim=1)):
+            if i == 0:
+                y_i = x_i
+            elif i == 1:
+                y_i = self.blocks[i - 1](x_i)
+            else:
+                y_i = self.blocks[i - 1](x_i + y_i)
+            y.append(y_i)
+
+        y_tensor = torch.cat(y, dim=1)
+        return y_tensor
+
+
+class SEBlock(Module):
+    """An implementation of squeeze-and-excitation block.
+
+    Arguments
+    ---------
+    in_channels : int
+        The number of input channels.
+    se_channels : int
+        The number of output channels after squeeze.
+    out_channels : int
+        The number of output channels.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        se_channels: int,
+        out_channels: int,
+    ):
+        super().__init__()
+
+        self.conv1 = Conv1d(
+            in_channels=in_channels, out_channels=se_channels, kernel_size=1
+        )
+        self.relu = ReLU(inplace=True)
+        self.conv2 = Conv1d(
+            in_channels=se_channels, out_channels=out_channels, kernel_size=1
+        )
+        self.sigmoid = Sigmoid()
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        if padding_mask is not None:
+            mask = padding_mask.materialize().unsqueeze(1)
+            s = (x * mask).sum(dim=2, keepdim=True) / padding_mask.seq_lens[
+                :, None, None
+            ]
+        else:
+            s = x.mean(dim=2, keepdim=True)
+
+        s = self.relu(self.conv1(s))
+        s = self.sigmoid(self.conv2(s))
+
+        return s * x
+
+
+class AttentiveStatisticsPooling(Module):
+    """This class implements an attentive statistic pooling layer for each channel.
+    It returns the concatenated mean and std of the input tensor.
+
+    Arguments
+    ---------
+    channels: int
+        The number of input channels.
+    attention_channels: int
+        The number of attention channels.
+    """
+
+    def __init__(
+        self, channels: int, attention_channels: int = 128, global_context: bool = True
+    ):
+        super().__init__()
+
+        self.eps = 1e-12
+        self.global_context = global_context
+        if global_context:
+            self.tdnn = TDNNBlock(channels * 3, attention_channels, 1, 1)
+        else:
+            self.tdnn = TDNNBlock(channels, attention_channels, 1, 1)
+
+        self.tanh = Tanh()
+        self.conv = Conv1d(
+            in_channels=attention_channels, out_channels=channels, kernel_size=1
+        )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Calculates mean and std for a batch (input tensor).
+
+        Arguments
+        ---------
+        x : torch.Tensor
+            Tensor of shape [N, C, L].
+        """
+        L = x.shape[-1]
+
+        def _compute_statistics(
+            x: Tensor, m: Tensor, dim: int = 2, eps: float = self.eps
+        ) -> Tuple[Tensor, Tensor]:
+            mean = (m * x).sum(dim)
+            std = torch.sqrt((m * (x - mean.unsqueeze(dim)).pow(2)).sum(dim).clamp(eps))
+            return mean, std
+
+        # Make binary mask of shape [N, 1, L]
+        # mask = to_padding_mask(lengths, max(lengths))
+        if padding_mask is not None:
+            mask = padding_mask.materialize()
+        else:
+            mask = to_padding_mask(torch.IntTensor([L]), L).repeat(x.shape[0], 1).to(x)
+        mask = mask.unsqueeze(1)
+
+        # Expand the temporal context of the pooling layer by allowing the
+        # self-attention to look at global properties of the utterance.
+        if self.global_context:
+            # torch.std is unstable for backward computation
+            # https://github.com/pytorch/pytorch/issues/4320
+            total = mask.sum(dim=2, keepdim=True).to(x)
+            mean, std = _compute_statistics(x, mask / total)
+            mean = mean.unsqueeze(2).repeat(1, 1, L)
+            std = std.unsqueeze(2).repeat(1, 1, L)
+            attn = torch.cat([x, mean, std], dim=1)
+        else:
+            attn = x
+
+        # Apply layers
+        attn = self.conv(self.tanh(self.tdnn(attn)))
+
+        # Filter out zero-paddings
+        attn = attn.masked_fill(mask == 0, float("-inf"))
+
+        attn = F.softmax(attn, dim=2)
+        mean, std = _compute_statistics(x, attn)
+        # Append mean and std of the batch
+        pooled_stats = torch.cat((mean, std), dim=1)
+        pooled_stats = pooled_stats.unsqueeze(2)
+
+        return pooled_stats
+
+
+class SERes2NetBlock(Module):
+    """An implementation of building block in ECAPA-TDNN, i.e.,
+    TDNN-Res2Net-TDNN-SEBlock.
+
+    Arguments
+    ----------
+    out_channels: int
+        The number of output channels.
+    res2net_scale: int
+        The scale of the Res2Net block.
+    kernel_size: int
+        The kernel size of the TDNN blocks.
+    dilation: int
+        The dilation of the Res2Net block.
+    groups: int
+    Number of blocked connections from input channels to output channels.
+
+    Example
+    -------
+    >>> x = torch.rand(8, 120, 64).transpose(1, 2)
+    >>> conv = SERes2NetBlock(64, 64, res2net_scale=4)
+    >>> out = conv(x).transpose(1, 2)
+    >>> out.shape
+    torch.Size([8, 120, 64])
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        res2net_scale: int = 8,
+        se_channels: int = 128,
+        kernel_size: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.tdnn1 = TDNNBlock(
+            in_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.res2net_block = Res2NetBlock(
+            out_channels,
+            out_channels,
+            res2net_scale,
+            kernel_size,
+            dilation,
+        )
+        self.tdnn2 = TDNNBlock(
+            out_channels,
+            out_channels,
+            kernel_size=1,
+            dilation=1,
+            groups=groups,
+        )
+        self.se_block = SEBlock(out_channels, se_channels, out_channels)
+
+        self.shortcut = None
+        if in_channels != out_channels:
+            self.shortcut = Conv1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=1,
+            )
+
+    def forward(self, x: Tensor, padding_mask: Optional[PaddingMask] = None) -> Tensor:
+        """Processes the input tensor x and returns an output tensor."""
+        residual = x
+        if self.shortcut:
+            residual = self.shortcut(x)
+
+        x = self.tdnn1(x)
+        x = self.res2net_block(x)
+        x = self.tdnn2(x)
+        x = self.se_block(x, padding_mask=padding_mask)
+
+        return x + residual
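For reference, a rough usage sketch of ECAPA_TDNN as the prosody encoder (not part of this diff); the hyperparameters are copied from the "base" architecture registered in ecapa_tdnn_builder.py below:

    import torch

    from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN

    # "base" hyperparameters from ecapa_tdnn_builder.py.
    encoder = ECAPA_TDNN(
        channels=[512, 512, 512, 512, 1536],
        kernel_sizes=[5, 3, 3, 3, 1],
        dilations=[1, 2, 3, 4, 1],
        attention_channels=128,
        res2net_scale=8,
        se_channels=128,
        global_context=True,
        groups=[1, 1, 1, 1, 1],
        embed_dim=512,
        input_dim=80,
    )

    mels = torch.rand(2, 120, 80)  # (batch, time, mel_dim)
    embedding = encoder(mels)      # (batch, embed_dim) == (2, 512), L2-normalized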

+ 112 - 0
src/seamless_communication/models/generator/ecapa_tdnn_builder.py

@@ -0,0 +1,112 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+from fairseq2.models.utils.arch_registry import ArchitectureRegistry
+from fairseq2.typing import DataType, Device
+
+from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
+
+
+@dataclass
+class EcapaTDNNConfig:
+    channels: List[int]
+    kernel_sizes: List[int]
+    dilations: List[int]
+    attention_channels: int
+    res2net_scale: int
+    se_channels: int
+    global_context: bool
+    groups: List[int]
+    embed_dim: int
+    input_dim: int
+
+
+ecapa_tdnn_archs = ArchitectureRegistry[EcapaTDNNConfig]("ecapa_tdnn")
+
+ecapa_tdnn_arch = ecapa_tdnn_archs.decorator
+
+
+@ecapa_tdnn_arch("base")
+def _base_ecapa_tdnn() -> EcapaTDNNConfig:
+    return EcapaTDNNConfig(
+        channels=[512, 512, 512, 512, 1536],
+        kernel_sizes=[5, 3, 3, 3, 1],
+        dilations=[1, 2, 3, 4, 1],
+        attention_channels=128,
+        res2net_scale=8,
+        se_channels=128,
+        global_context=True,
+        groups=[1, 1, 1, 1, 1],
+        embed_dim=512,
+        input_dim=80,
+    )
+
+
+class EcapaTDNNBuilder:
+    """
+    Builder module for ECAPA_TDNN model
+    """
+
+    config: EcapaTDNNConfig
+    device: Optional[Device]
+    dtype: Optional[DataType]
+
+    def __init__(
+        self,
+        config: EcapaTDNNConfig,
+        *,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ) -> None:
+        """
+        :param config:
+            The configuration to use.
+        :param device:
+            The device on which to initialize modules.
+        :param dtype:
+            The data type of module parameters and buffers.
+        """
+        self.config = config
+
+        self.device, self.dtype = device, dtype
+
+    def build_model(self) -> ECAPA_TDNN:
+        """Build a model."""
+        model = ECAPA_TDNN(
+            self.config.channels,
+            self.config.kernel_sizes,
+            self.config.dilations,
+            self.config.attention_channels,
+            self.config.res2net_scale,
+            self.config.se_channels,
+            self.config.global_context,
+            self.config.groups,
+            self.config.embed_dim,
+            self.config.input_dim,
+        )
+        model.to(device=self.device, dtype=self.dtype)
+        return model
+
+
+def create_ecapa_tdnn_model(
+    config: EcapaTDNNConfig,
+    device: Optional[Device] = None,
+    dtype: Optional[DataType] = None,
+) -> ECAPA_TDNN:
+    """Create a ECAPA_TDNN model.
+
+    :param config:
+        The configuration to use.
+    :param device:
+        The device on which to initialize modules.
+    :param dtype:
+        The data type of module parameters and buffers.
+    """
+
+    return EcapaTDNNBuilder(config, device=device, dtype=dtype).build_model()
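The same model can also be obtained through the builder; a sketch (not part of this diff) that assumes the ArchitectureRegistry accessor is named `get_config`:

    import torch

    from seamless_communication.models.generator.ecapa_tdnn_builder import (
        create_ecapa_tdnn_model,
        ecapa_tdnn_archs,
    )

    # `get_config` is assumed to be the registry accessor; it is not shown in this diff.
    config = ecapa_tdnn_archs.get_config("base")
    prosody_encoder = create_ecapa_tdnn_model(
        config, device=torch.device("cpu"), dtype=torch.float32
    )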

+ 29 - 0
src/seamless_communication/models/generator/loader.py

@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from fairseq2.assets import asset_store, download_manager
+from fairseq2.models.utils import ConfigLoader, ModelLoader
+
+from seamless_communication.models.generator.builder import (
+    VocoderConfig,
+    create_vocoder_model,
+    vocoder_archs,
+)
+from seamless_communication.models.generator.vocoder import PretsselVocoder
+
+load_pretssel_vocoder_config = ConfigLoader[VocoderConfig](asset_store, vocoder_archs)
+
+
+load_pretssel_vocoder_model = ModelLoader[PretsselVocoder, VocoderConfig](
+    asset_store,
+    download_manager,
+    load_pretssel_vocoder_config,
+    create_vocoder_model,
+    restrict_checkpoints=False,
+)
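A hedged usage sketch of the loader (not part of this diff), assuming fairseq2's ModelLoader is invoked with a card name plus device/dtype keywords; "<pretssel_vocoder_card>" is a placeholder, since the actual asset card name is defined elsewhere:

    import torch

    from seamless_communication.models.generator.loader import load_pretssel_vocoder_model

    vocoder = load_pretssel_vocoder_model(
        "<pretssel_vocoder_card>", device=torch.device("cuda:0"), dtype=torch.float16
    )
    vocoder.eval()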

+ 452 - 0
src/seamless_communication/models/generator/streamable.py

@@ -0,0 +1,452 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import warnings
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypeVar
+
+import torch
+from fairseq2.typing import DataType, Device
+from torch.nn import (
+    ELU,
+    LSTM,
+    Conv1d,
+    ConvTranspose1d,
+    GroupNorm,
+    Identity,
+    Module,
+    Sequential,
+)
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm, weight_norm  # type: ignore[attr-defined]
+
+CONV_NORMALIZATIONS = frozenset(
+    ["none", "weight_norm", "spectral_norm", "time_group_norm"]
+)
+
+
+def apply_parametrization_norm(
+    module: Module,
+    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"] = "none",
+) -> Module:
+    if norm == "weight_norm":
+        return weight_norm(module)
+    elif norm == "spectral_norm":
+        return spectral_norm(module)
+    else:
+        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
+        # choice doesn't need reparametrization.
+        return module
+
+
+def get_norm_module(  # type: ignore[no-untyped-def]
+    module: Module,
+    causal: bool = False,
+    norm: Literal["none", "weight_norm", "spectral_norm", "time_group_norm"] = "none",
+    **norm_kwargs,
+) -> Module:
+    """Return the proper normalization module. If causal is True, this will ensure the returned
+    module is causal, or raise an error if the normalization doesn't support causal evaluation.
+    """
+    assert norm in CONV_NORMALIZATIONS
+    if norm == "time_group_norm":
+        if causal:
+            raise ValueError("GroupNorm doesn't support causal evaluation.")
+        assert isinstance(module, torch.nn.modules.conv._ConvNd)
+        return GroupNorm(1, module.out_channels, **norm_kwargs)
+    else:
+        return Identity()
+
+
+def get_extra_padding_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> int:
+    """See `pad_for_conv1d`."""
+    length = x.shape[-1]
+    n_frames = (length - kernel_size + padding_total) / stride + 1
+    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
+    return ideal_length - length
+
+
+def pad_for_conv1d(
+    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
+) -> torch.Tensor:
+    """Pad for a convolution to make sure that the last window is full.
+    Extra padding is added at the end. This is required to ensure that we can rebuild
+    an output of the same length, as otherwise, even with padding, some time steps
+    might get removed.
+    For instance, with total padding = 4, kernel size = 4, stride = 2:
+        0 0 1 2 3 4 5 0 0   # (0s are padding)
+        1   2   3           # (output frames of a convolution, last 0 is never used)
+        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
+            1 2 3 4         # once you remove the padding, we are missing one time step!
+    """
+    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
+    return F.pad(x, (0, extra_padding))  # noqa
+
+
+def pad1d(
+    x: torch.Tensor,
+    paddings: Tuple[int, int],
+    mode: str = "constant",
+    value: float = 0.0,
+) -> torch.Tensor:
+    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
+    If this is the case, we insert extra 0 padding to the right before the reflection happen.
+    """
+    length = x.shape[-1]
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    if mode == "reflect":
+        max_pad = max(padding_left, padding_right)
+        extra_pad = 0
+        if length <= max_pad:
+            extra_pad = max_pad - length + 1
+            x = F.pad(x, (0, extra_pad))
+        padded = F.pad(x, paddings, mode, value)
+        end = padded.shape[-1] - extra_pad
+        return padded[..., :end]
+    else:
+        return F.pad(x, paddings, mode, value)
+
+
+def unpad1d(x: torch.Tensor, paddings: Tuple[int, int]) -> torch.Tensor:
+    """Remove padding from x, handling properly zero padding. Only for 1d!"""
+    padding_left, padding_right = paddings
+    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
+    assert (padding_left + padding_right) <= x.shape[-1]
+    end = x.shape[-1] - padding_right
+    return x[..., padding_left:end]
+
+
+class NormConv1d(Module):
+    """Wrapper around Conv1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_kwargs: Dict[str, Any] = {},
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.conv: Module = apply_parametrization_norm(
+            Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+            ),
+            norm,
+        )
+        self.norm: Module = get_norm_module(self.conv, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.conv(x)
+        x = self.norm(x)
+        return x
+
+
+class NormConvTranspose1d(Module):
+    """Wrapper around ConvTranspose1d and normalization applied to this conv
+    to provide a uniform interface across normalization approaches.
+    """
+
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_kwargs: Dict[str, Any] = {},
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.convtr = apply_parametrization_norm(
+            ConvTranspose1d(
+                in_channels=in_channels,
+                out_channels=out_channels,
+                kernel_size=kernel_size,
+                stride=stride,
+                device=device,
+                dtype=dtype,
+            ),
+            norm,
+        )
+        self.norm = get_norm_module(self.convtr, causal, norm, **norm_kwargs)
+        self.norm_type = norm
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.convtr(x)
+        x = self.norm(x)
+        return x
+
+
+class StreamableConv1d(Module):
+    """Conv1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = True,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_kwargs: Dict[str, Any] = {},
+        pad_mode: str = "reflect",
+        activation: Optional[Module] = None,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        # warn user on unusual setup between dilation and stride
+        if stride > 1 and dilation > 1:
+            warnings.warn(
+                "StreamableConv1d has been initialized with stride > 1 and dilation > 1"
+                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
+            )
+        self.activation = activation
+        self.conv = NormConv1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+            causal=causal,
+            norm=norm,
+            norm_kwargs=norm_kwargs,
+            device=device,
+            dtype=dtype,
+        )
+        self.causal = causal
+        self.pad_mode = pad_mode
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.activation:
+            x = self.activation(x)
+        kernel_size: int = self.conv.conv.kernel_size[0]  # type: ignore[index,assignment]
+        stride: int = self.conv.conv.stride[0]  # type: ignore[index,assignment]
+        dilation = self.conv.conv.dilation[0]  # type: ignore[index]
+        kernel_size = (  # type: ignore[assignment]
+            kernel_size - 1
+        ) * dilation + 1  # effective kernel size with dilations
+        padding_total = kernel_size - stride
+        extra_padding = get_extra_padding_for_conv1d(
+            x, kernel_size, stride, padding_total
+        )
+        if self.causal:
+            # Left padding for causal
+            x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            x = pad1d(
+                x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
+            )
+        return self.conv(x)  # type: ignore[no-any-return]
+
+
+class StreamableConvTranspose1d(Module):
+    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
+    and normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        causal: bool = False,
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        trim_right_ratio: float = 1.0,
+        norm_kwargs: Dict[str, Any] = {},
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.convtr = NormConvTranspose1d(
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            causal=causal,
+            norm=norm,
+            norm_kwargs=norm_kwargs,
+            device=device,
+            dtype=dtype,
+        )
+        self.causal = causal
+        self.trim_right_ratio = trim_right_ratio
+        assert (
+            self.causal or self.trim_right_ratio == 1.0
+        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
+        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        kernel_size: int = self.convtr.convtr.kernel_size[0]  # type: ignore[index,assignment]
+        stride: int = self.convtr.convtr.stride[0]  # type: ignore[index,assignment]
+        padding_total = kernel_size - stride
+
+        y: torch.Tensor = self.convtr(x)
+
+        # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
+        # removed at the very end, when keeping only the right length for the output,
+        # as removing it here would require also passing the length at the matching layer
+        # in the encoder.
+        if self.causal:
+            # Trim the padding on the right according to the specified ratio
+            # if trim_right_ratio = 1.0, trim everything from right
+            padding_right = math.ceil(padding_total * self.trim_right_ratio)
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        else:
+            # Asymmetric padding required for odd strides
+            padding_right = padding_total // 2
+            padding_left = padding_total - padding_right
+            y = unpad1d(y, (padding_left, padding_right))
+        return y
+
+
+class StreamableLSTM(Module):
+    """LSTM without worrying about the hidden state, nor the layout of the data.
+    Expects input as convolutional layout.
+    """
+
+    def __init__(
+        self,
+        dimension: int,
+        num_layers: int = 2,
+        skip: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.skip = skip
+        self.lstm = LSTM(dimension, dimension, num_layers, device=device, dtype=dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(2, 0, 1)
+        y, _ = self.lstm(x)
+        if self.skip:
+            y = y + x
+        y = y.permute(1, 2, 0)
+        return y  # type: ignore[no-any-return]
+
+
+class StreamableResnetBlock(Module):
+    """custom Residual block model with streamable convnet.
+
+    Args:
+        dim (int): Dimension of the input/output.
+        kernel_sizes (list): List of kernel sizes for the convolutions.
+        dilations (list): List of dilations for the convolutions.
+        activation_params (dict): Parameters to provide to the (ELU) activation function.
+        norm (str): Normalization method.
+        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
+        causal (bool): Whether to use fully causal convolution.
+        pad_mode (str): Padding mode for the convolutions.
+        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
+        true_skip (bool): Whether to use true skip connection or a simple
+            (streamable) convolution as the skip connection.
+    """
+
+    def __init__(
+        self,
+        dim: int,
+        kernel_sizes: List[int] = [3, 1],
+        dilations: List[int] = [1, 1],
+        activation_params: Dict[str, Any] = {"alpha": 1.0},
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_params: Dict[str, Any] = {},
+        causal: bool = False,
+        pad_mode: str = "reflect",
+        compress: int = 2,
+        true_skip: bool = True,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        assert len(kernel_sizes) == len(
+            dilations
+        ), "Number of kernel sizes should match number of dilations"
+        hidden = dim // compress
+        block = []
+        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
+            in_chs = dim if i == 0 else hidden
+            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
+            block += [
+                ELU(**activation_params),
+                StreamableConv1d(
+                    in_chs,
+                    out_chs,
+                    kernel_size=kernel_size,
+                    dilation=dilation,
+                    norm=norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    device=device,
+                    dtype=dtype,
+                ),
+            ]
+        self.block = Sequential(*block)
+        self.shortcut: Module
+        if true_skip:
+            self.shortcut = Identity()
+        else:
+            self.shortcut = StreamableConv1d(
+                dim,
+                dim,
+                kernel_size=1,
+                norm=norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.shortcut(x) + self.block(x)  # type: ignore[no-any-return]
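To make the padding helpers concrete, the arithmetic below reproduces the example from the pad_for_conv1d docstring (length 5, kernel size 4, stride 2, total padding 4); a sketch, not part of this diff:

    import torch

    from seamless_communication.models.generator.streamable import (
        get_extra_padding_for_conv1d,
        pad_for_conv1d,
    )

    x = torch.arange(1.0, 6.0).view(1, 1, 5)  # the "1 2 3 4 5" signal from the docstring
    # n_frames = (5 - 4 + 4) / 2 + 1 = 3.5 -> ideal_length = (4 - 1) * 2 + 0 = 6
    print(get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=4))  # 1
    print(pad_for_conv1d(x, 4, 2, 4).shape[-1])  # 6: one extra zero appended on the right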

+ 582 - 0
src/seamless_communication/models/generator/vocoder.py

@@ -0,0 +1,582 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, List, Literal, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from fairseq2.nn.embedding import Embedding, StandardEmbedding
+from fairseq2.nn.padding import PaddingMask
+from fairseq2.nn.position_encoder import PositionEncoder
+from fairseq2.nn.projection import Projection
+from fairseq2.typing import DataType, Device
+from torch.nn import (
+    ELU,
+    BatchNorm1d,
+    Conv1d,
+    ConvTranspose1d,
+    Dropout,
+    Module,
+    ModuleList,
+    Parameter,
+    Sequential,
+    Tanh,
+    init,
+)
+from torch.nn.utils.weight_norm import remove_weight_norm, weight_norm
+
+from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
+from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
+from seamless_communication.models.unity.length_regulator import VarianceAdaptor
+from seamless_communication.models.vocoder.hifigan import (
+    LRELU_SLOPE,
+    ResBlock,
+    init_weights,
+)
+
+from .streamable import (
+    StreamableConv1d,
+    StreamableConvTranspose1d,
+    StreamableLSTM,
+    StreamableResnetBlock,
+)
+
+ELU_PARAMS: Dict[str, Any] = {"alpha": 1.0}
+
+
+class PretsselEncoderFrontend(Module):
+    """
+    Represents the encoder frontend, including the prosody encoder and language embedding.
+    """
+
+    prosody_encoder: ECAPA_TDNN
+    embed_tokens: Embedding
+    embed_positions: PositionEncoder
+    pos_emb_alpha: Parameter
+    embed_lang: Embedding
+    dropout: Dropout
+
+    def __init__(
+        self,
+        prosody_encoder: ECAPA_TDNN,
+        embed_tokens: Embedding,
+        embed_positions: PositionEncoder,
+        lang_to_index: Dict[str, int],
+        lang_embed_dim: Optional[int],
+        dropout_p: float,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.prosody_encoder = prosody_encoder
+
+        self.embed_tokens = embed_tokens
+
+        self.embed_positions = embed_positions
+        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+        self.lang_to_index = lang_to_index
+
+        if lang_embed_dim is not None:
+            self.embed_lang = StandardEmbedding(
+                len(lang_to_index), lang_embed_dim, device=device, dtype=dtype
+            )
+        else:
+            self.register_module("embed_lang", None)
+
+        self.dropout = Dropout(dropout_p)
+
+        self.device = device
+        self.dtype = dtype
+
+    def forward(
+        self,
+        seqs: torch.Tensor,
+        padding_mask: Optional[PaddingMask],
+        prosody_input_seqs: torch.Tensor,
+        prosody_padding_mask: Optional[PaddingMask],
+        tgt_lang: str,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        prosody_embs = self.prosody_encoder(
+            prosody_input_seqs,
+            prosody_padding_mask,
+        ).unsqueeze(1)
+
+        if self.embed_lang is not None:
+            lang_index = self.lang_to_index[tgt_lang]
+            lang_index_tensor = (
+                torch.Tensor([lang_index]).to(seqs).repeat(seqs.size(0), 1)
+            )
+            lang_embeds = self.embed_lang(lang_index_tensor)
+            prosody_embs = torch.cat([prosody_embs, lang_embeds], dim=-1)
+
+        seqs = self.embed_tokens(seqs)
+        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+        seqs = self.dropout(seqs)
+
+        return seqs, prosody_embs
+
+
+class PretsselDecoderFrontend(Module):
+    """Represent Decoder frontend, including VarianceAdaptor & Positional embedding"""
+
+    variance_adaptor: VarianceAdaptor
+    embed_positions: PositionEncoder
+    pos_emb_alpha: Parameter
+
+    def __init__(
+        self,
+        variance_adaptor: VarianceAdaptor,
+        embed_positions: PositionEncoder,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+
+        self.variance_adaptor = variance_adaptor
+        self.embed_positions = embed_positions
+        self.pos_emb_alpha = Parameter(torch.ones(1, device=device, dtype=dtype))
+
+        self.device = device
+        self.dtype = dtype
+
+    def forward(
+        self,
+        seqs: torch.Tensor,
+        padding_mask: PaddingMask,
+        durations: Optional[torch.Tensor] = None,
+        duration_factor: float = 1.0,
+        min_duration: int = 0,
+        film_cond_emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, PaddingMask]:
+        seqs, padding_mask, _ = self.variance_adaptor(
+            seqs, padding_mask, durations, duration_factor, min_duration, film_cond_emb
+        )
+
+        seqs += self.pos_emb_alpha * (self.embed_positions(seqs, padding_mask) - seqs)
+
+        return seqs, padding_mask
+
+
+class PretsselVocoder(Module):
+    """The expressivity-preserving vocoder"""
+
+    encoder_frontend: PretsselEncoderFrontend
+    encoder: FeedForwardTransformer
+    decoder_frontend: PretsselDecoderFrontend
+    decoder: FeedForwardTransformer
+    final_proj: Projection
+
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        encoder_frontend: PretsselEncoderFrontend,
+        encoder: FeedForwardTransformer,
+        decoder_frontend: PretsselDecoderFrontend,
+        decoder: FeedForwardTransformer,
+        final_proj: Projection,
+        pn_n_channels: int,
+        pn_kernel_size: int,
+        pn_layers: int,
+        pn_dropout: float,
+        upsample_rates: List[int],
+        upsample_kernel_sizes: List[int],
+        upsample_initial_channel: int,
+        resblock_kernel_sizes: List[int],
+        resblock_dilation_sizes: List[List[int]],
+        mel_dim: int = 80,
+        add_ups_out_pad: bool = True,
+        channels: int = 1,
+        dimension: int = 128,
+        n_filters: int = 32,
+        ratios: List[int] = [8, 5, 4, 2],
+        norm: Literal[
+            "none", "weight_norm", "spectral_norm", "time_group_norm"
+        ] = "none",
+        norm_params: Dict[str, Any] = {},
+        kernel_size: int = 7,
+        last_kernel_size: int = 7,
+        residual_kernel_size: int = 3,
+        causal: bool = False,
+        pad_mode: str = "constant",
+        true_skip: bool = True,
+        compress: int = 2,
+        lstm: int = 0,
+        disable_norm_outer_blocks: int = 0,
+        trim_right_ratio: float = 1.0,
+        gcmvn_mean: Optional[List[float]] = None,
+        gcmvn_std: Optional[List[float]] = None,
+        device: Optional[Device] = None,
+        dtype: Optional[DataType] = None,
+    ):
+        super().__init__()
+        self.encoder_frontend = encoder_frontend
+        self.encoder = encoder
+        self.decoder_frontend = decoder_frontend
+        self.decoder = decoder
+        self.final_proj = final_proj
+        mult = 1
+        stream_layers: List[Module] = [
+            StreamableConv1d(
+                channels,
+                mult * n_filters,
+                kernel_size,
+                norm="none" if disable_norm_outer_blocks >= 1 else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                activation=Tanh(),
+                device=device,
+                dtype=dtype,
+            )
+        ]
+        # Downsample from the raw audio scale
+        for i, ratio in enumerate(list(reversed(ratios))):
+            block_norm = "none" if disable_norm_outer_blocks >= i + 2 else norm
+            stream_layers.append(
+                StreamableResnetBlock(
+                    mult * n_filters,
+                    kernel_sizes=[residual_kernel_size, 1],
+                    dilations=[1, 1],
+                    norm=block_norm,
+                    norm_params=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    compress=compress,
+                    true_skip=true_skip,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            stream_layers.append(ELU(**ELU_PARAMS))
+            stream_layers.append(
+                StreamableConv1d(
+                    mult * n_filters,
+                    mult * n_filters * 2,
+                    kernel_size=ratio * 2,
+                    stride=ratio,
+                    norm=block_norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            mult *= 2
+
+        stream_layers.append(StreamableLSTM(mult * n_filters, num_layers=lstm))
+        stream_layers.append(ELU(**ELU_PARAMS))
+        n_blocks = len(ratios) + 2
+        stream_layers.append(
+            StreamableConv1d(
+                mult * n_filters,
+                dimension,
+                last_kernel_size,
+                norm="none" if disable_norm_outer_blocks == n_blocks else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+        )
+        stream_layers.append(
+            StreamableConv1d(
+                dimension,
+                mult * n_filters,
+                kernel_size,
+                norm="none" if disable_norm_outer_blocks == n_blocks else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+        )
+        stream_layers.append(
+            StreamableLSTM(
+                mult * n_filters, num_layers=lstm, device=device, dtype=dtype
+            )
+        )
+
+        # Upsample back to the raw audio scale
+        for i, ratio in enumerate(ratios):
+            block_norm = (
+                "none" if disable_norm_outer_blocks >= n_blocks - (i + 1) else norm
+            )
+            stream_layers.append(ELU(**ELU_PARAMS))
+            stream_layers.append(
+                StreamableConvTranspose1d(
+                    mult * n_filters,
+                    mult * n_filters // 2,
+                    kernel_size=ratio * 2,
+                    stride=ratio,
+                    norm=block_norm,
+                    norm_kwargs=norm_params,
+                    causal=causal,
+                    trim_right_ratio=trim_right_ratio,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            stream_layers.append(
+                StreamableResnetBlock(
+                    mult * n_filters // 2,
+                    kernel_sizes=[residual_kernel_size, 1],
+                    dilations=[1, 1],
+                    norm=block_norm,
+                    norm_params=norm_params,
+                    activation_params=ELU_PARAMS,
+                    causal=causal,
+                    pad_mode=pad_mode,
+                    compress=compress,
+                    true_skip=true_skip,
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+            mult //= 2
+
+        stream_layers.append(ELU(**ELU_PARAMS))
+        stream_layers.append(
+            StreamableConv1d(
+                n_filters,
+                channels,
+                last_kernel_size,
+                norm="none" if disable_norm_outer_blocks >= 1 else norm,
+                norm_kwargs=norm_params,
+                causal=causal,
+                pad_mode=pad_mode,
+                device=device,
+                dtype=dtype,
+            )
+        )
+        self.n_streams = len(stream_layers)
+        chunk_size = self.n_streams // 4
+        stream_idx = 0
+
+        self.pn_layers = pn_layers
+        self.layers = ModuleList()
+        assert pn_kernel_size % 2 == 1
+        for i in range(pn_layers):
+            cur_layers = (
+                [
+                    Conv1d(
+                        mel_dim if i == 0 else pn_n_channels,
+                        pn_n_channels if i < pn_layers - 1 else mel_dim,
+                        kernel_size=pn_kernel_size,
+                        padding="same",
+                        device=device,
+                        dtype=dtype,
+                    ),
+                    BatchNorm1d(
+                        pn_n_channels if i < pn_layers - 1 else mel_dim,
+                        device=device,
+                        dtype=dtype,
+                    ),
+                ]
+                + ([Tanh()] if i < pn_layers - 1 else [])
+                + [Dropout(pn_dropout)]
+            )
+            self.layers.append(Sequential(*cur_layers))
+        self.reset_parameters()
+        self.layers.extend(stream_layers[:chunk_size])
+        stream_idx += chunk_size
+        self.layers.append(
+            weight_norm(
+                Conv1d(
+                    mel_dim if mel_dim is not None else 80,
+                    upsample_initial_channel,
+                    7,
+                    1,
+                    padding="same",
+                    device=device,
+                    dtype=dtype,
+                )
+            )
+        )
+        self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])  # noqa
+        stream_idx += chunk_size
+
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_rates)
+        ups = ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+            out_pad = u % 2 if add_ups_out_pad else 0
+            ups.append(
+                weight_norm(
+                    ConvTranspose1d(
+                        upsample_initial_channel // (2**i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2 + out_pad,
+                        output_padding=out_pad,
+                        device=device,
+                        dtype=dtype,
+                    )
+                )
+            )
+        ups.apply(init_weights)
+        self.layers.extend(ups)
+        self.layers.extend(stream_layers[stream_idx : stream_idx + chunk_size])  # noqa
+        stream_idx += chunk_size
+
+        for i in range(self.num_upsamples):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+                self.layers.append(
+                    ResBlock(
+                        ch,
+                        k,
+                        d,
+                    ).to(device, dtype=dtype)
+                )
+        self.layers.extend(stream_layers[stream_idx:])
+
+        conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
+        conv_post.apply(init_weights)
+        self.layers.append(conv_post)
+        for u, k in zip(upsample_rates, upsample_kernel_sizes):
+            assert k == 2 * u, (k, u)
+
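+        # Per-mel-bin statistics applied when normalize_before=True; the zero
+        # placeholders are expected to be overwritten by the loaded checkpoint.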
+        mean = torch.zeros((mel_dim,), dtype=torch.float)
+        scale = torch.zeros((mel_dim,), dtype=torch.float)
+        self.register_buffer("mean", mean)
+        self.register_buffer("scale", scale)
+
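+        # Global CMVN statistics (per mel bin) consumed by gcmvn_denormalize().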
+        self.gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)
+        self.gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)
+
+    def reset_parameters(self) -> None:
+        for i in range(self.pn_layers):
+            init.xavier_uniform_(
+                self.layers[i][0].weight,
+                init.calculate_gain("tanh" if i < self.pn_layers - 1 else "linear"),
+            )
+
+    def gcmvn_denormalize(self, x: torch.Tensor) -> torch.Tensor:
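+        # Undo global CMVN: x_denorm = x * gcmvn_std + gcmvn_mean, broadcast over
+        # the batch and time dimensions.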
+        if self.gcmvn_mean is None or self.gcmvn_std is None:
+            raise ValueError("gcmvn_mean or gcmvn_std is not set")
+
+        assert (
+            x.ndim == 3
+            and x.shape[2] == self.gcmvn_mean.shape[0]
+            and x.shape[2] == self.gcmvn_std.shape[0]
+        )
+        gcmvn_mean = self.gcmvn_mean.to(x)
+        gcmvn_std = self.gcmvn_std.to(x)
+        x = x * gcmvn_std.view(1, 1, -1).expand_as(x)  # type: ignore[attr-defined]
+        return x + gcmvn_mean.view(1, 1, -1).expand_as(x)  # type: ignore[attr-defined,no-any-return]
+
+    def forward(
+        self,
+        seqs: torch.Tensor,
+        tgt_lang: str,
+        prosody_input_seqs: torch.Tensor,
+        padding_mask: Optional[PaddingMask] = None,
+        prosody_padding_mask: Optional[PaddingMask] = None,
+        durations: Optional[torch.Tensor] = None,
+        duration_factor: float = 1.0,
+        min_duration: int = 0,
+        normalize_before: bool = True,
+    ) -> torch.Tensor:
+        # Add a batch dimension for the pretssel inputs if they are unbatched
+        if seqs.ndim < 3:
+            seqs = seqs.unsqueeze(0)
+        if prosody_input_seqs.ndim < 3:
+            prosody_input_seqs = prosody_input_seqs.unsqueeze(0)
+        seqs, cond_embs = self.encoder_frontend(
+            seqs,
+            padding_mask,
+            prosody_input_seqs,
+            prosody_padding_mask,
+            tgt_lang,
+        )
+        seqs, padding_mask = self.encoder(seqs, padding_mask, cond_embs)
+        seqs, padding_mask = self.decoder_frontend(
+            seqs, padding_mask, durations, duration_factor, min_duration, cond_embs
+        )
+        seqs, padding_mask = self.decoder(seqs, padding_mask, cond_embs)
+        seqs = self.final_proj(seqs)
+
+        pn = seqs.transpose(1, 2)  # B x T x C -> B x C x T
+        for i in range(self.pn_layers):
+            pn = self.layers[i](pn)
+        pn = pn.transpose(1, 2)
+
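+        # Residual PostNet refinement, then map the mels back to their original
+        # scale via global CMVN denormalization.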
+        x = seqs + pn
+        x = self.gcmvn_denormalize(x).squeeze(0)
+        if normalize_before:
+            x = (x - self.mean) / self.scale
+
+        x = x.transpose(1, 0).unsqueeze(0)
+        chunk_size = self.n_streams // 4
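+        # The layer indices below follow the layout built in __init__: PostNet convs,
+        # stream chunk 1, mel pre-conv, stream chunk 2, upsamplers, stream chunk 3,
+        # ResBlocks, stream chunk 4, conv_post.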
+        x = self.layers[self.pn_layers + chunk_size](x)
+        for i in range(self.num_upsamples):
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            x = self.layers[i + self.pn_layers + 1 + 2 * chunk_size](x)
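+            # Multi-receptive-field fusion: average the outputs of the ResBlocks
+            # attached to this upsampling stage.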
+            xs = None
+            for j in range(self.num_kernels):
+                if xs is None:
+                    xs = self.layers[
+                        i * self.num_kernels
+                        + j
+                        + self.pn_layers
+                        + 3 * chunk_size
+                        + self.num_upsamples
+                        + 1
+                    ](x)
+                else:
+                    xs += self.layers[
+                        i * self.num_kernels
+                        + j
+                        + self.pn_layers
+                        + 3 * chunk_size
+                        + self.num_upsamples
+                        + 1
+                    ](x)
+            x = xs / self.num_kernels  # type: ignore
+        x = F.leaky_relu(x)
+        x = self.layers[
+            self.pn_layers
+            + self.n_streams
+            + self.num_upsamples * (1 + self.num_kernels)
+            + 1
+        ](x)
+        skip_output = x
+        h = skip_output
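+        # Watermark path: pass the HiFi-GAN output through the four stream chunks;
+        # the index offsets step over the mel pre-conv, the upsamplers and the
+        # ResBlocks that sit between the chunks in self.layers.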
+
+        for i1 in range(self.pn_layers, self.pn_layers + chunk_size):
+            h = self.layers[i1](h)
+        i1 += 2
+        for i2 in range(i1, i1 + chunk_size):
+            h = self.layers[i2](h)
+        i2 = i2 + self.num_upsamples + 1
+
+        for i3 in range(i2, i2 + chunk_size):
+            h = self.layers[i3](h)
+        i3 = i3 + (self.num_upsamples * self.num_kernels) + 1
+        for i4 in range(i3, i3 + chunk_size):
+            h = self.layers[i4](h)
+        h = h[:, :, : x.size(-1)]
+
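+        # The returned waveform is the tanh-squashed HiFi-GAN output plus the
+        # additive residual produced by the stream layers (the watermark).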
+        h += torch.tanh(skip_output).squeeze(0)
+        return h
+
+    def remove_weight_norm(self) -> None:
+        i = self.pn_layers + 1
+        for j in range(self.num_upsamples):
+            remove_weight_norm(self.layers[i + j])
+        for k in range(self.num_upsamples * self.num_kernels):
+            self.layers[i + j + k + 1].remove_weight_norm()
+        remove_weight_norm(self.layers[self.pn_layers])
+        remove_weight_norm(
+            self.layers[
+                self.pn_layers + 1 + self.num_upsamples * (1 + self.num_kernels)
+            ]
+        )

+ 1 - 1
src/seamless_communication/models/unity/builder.py

@@ -26,7 +26,7 @@ from fairseq2.nn.transformer import (
 from fairseq2.typing import DataType, Device, override
 from torch.nn import GELU, ReLU
 
-from seamless_communication.models.pretssel import (
+from seamless_communication.models.generator.ecapa_tdnn_builder import (
     EcapaTDNNBuilder,
     EcapaTDNNConfig,
     ecapa_tdnn_archs,

+ 6 - 2
src/seamless_communication/models/unity/loader.py

@@ -8,7 +8,7 @@ from typing import Any, Dict, List, Mapping, Tuple, Union
 
 import torch
 from fairseq2.assets import AssetStore, asset_store, download_manager
-from fairseq2.assets.card import AssetCard
+from fairseq2.assets.card import AssetCard, AssetCardFieldNotFoundError
 from fairseq2.models.nllb import NllbConfig
 from fairseq2.models.nllb.loader import NllbTokenizerLoader
 from fairseq2.models.utils import ConfigLoader, ModelLoader
@@ -459,7 +459,11 @@ class GcmvnStatsLoader:
         else:
             card = self.asset_store.retrieve_card(model_name_or_card)
 
-        gcmvn_stats: Dict[str, List[float]] = card.field("gcmvn_stats").as_(dict)
+        try:
+            gcmvn_stats: Dict[str, List[float]] = card.field("gcmvn_stats").as_(dict)
+        except AssetCardFieldNotFoundError:
+            model_override = card.field("model_config").as_(dict)
+            gcmvn_stats = model_override["gcmvn_stats"]
 
         return gcmvn_stats["mean"], gcmvn_stats["std"]
 

+ 1 - 1
src/seamless_communication/models/unity/model.py

@@ -19,7 +19,7 @@ from overrides import final as finaloverride
 from torch import Tensor
 from torch.nn import Module
 
-from seamless_communication.models.pretssel.ecapa_tdnn import ECAPA_TDNN
+from seamless_communication.models.generator.ecapa_tdnn import ECAPA_TDNN
 from seamless_communication.models.unity.fft_decoder import FeedForwardTransformer
 from seamless_communication.models.unity.nar_decoder_frontend import NARDecoderFrontend
 

+ 152 - 0
tests/integration/models/test_watermarked_vocoder.py

@@ -0,0 +1,152 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import sys
+from typing import cast, List, Final, Optional
+from pathlib import Path
+import torch
+from fairseq2.typing import Device
+from fairseq2.data import Collater, SequenceData
+from fairseq2.data.audio import AudioDecoderOutput
+from torch.nn import Module
+
+from seamless_communication.inference.pretssel_generator import PretsselGenerator
+from seamless_communication.models.unity.loader import load_gcmvn_stats
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+from tests.common import (
+    assert_close,
+    convert_to_collated_fbank,
+)
+
+
+N_MEL_BINS = 80
+
+# fmt: off
+REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
+# fmt: on
+
+
+def load_watermarking_model() -> Optional[Module]:
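+    # Load scripts/watermarking/watermarking.py dynamically, since it ships as a
+    # standalone script rather than an installed module.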
+    import importlib.util
+
+    # Run in CPU mode until pretssel's inconsistent behaviour is fixed
+    device = Device("cpu")
+    dtype = torch.float32
+    wm_py_file = Path(__file__).parents[3] / "scripts/watermarking/watermarking.py"
+    assert wm_py_file.is_file()
+    wm_spec = importlib.util.spec_from_file_location("watermark.f1", wm_py_file)
+    assert wm_spec, f"Module not found: {wm_py_file}"
+    wm_py_module = importlib.util.module_from_spec(wm_spec)
+    assert wm_py_module, f"Invalid Python module file: {wm_py_file}"
+    sys.modules["watermark.f1"] = wm_py_module
+    assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
+    wm_spec.loader.exec_module(wm_py_module)
+
+    return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
+
+
+def test_pretssel_vocoder_watermarking(
+    example_rate16k_audio: AudioDecoderOutput,
+) -> None:
+    """
+    Test that the watermarked pretssel vocoder generates the same audio as the
+    non-watermarked PretsselGenerator followed by a separately applied watermark.
+    """
+    audio = example_rate16k_audio
+
+    # Run in CPU mode until pretssel's inconsistent behaviour is fixed
+    device = Device("cpu")
+    dtype = torch.float32
+    audio["waveform"] = audio["waveform"].to(device, dtype=dtype)
+    feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
+    feat = feat.to(device, dtype=dtype)
+    # Run the watermarked vocoding
+    # TODO: Build a generator API for the watermarked vocoder
+    vocoder = load_pretssel_vocoder_model(
+        "vocoder_pretssel", device=device, dtype=dtype
+    )
+
+    units = torch.tensor(REF_FRA_UNITS, device=device, dtype=torch.int64)
+
+    # offset the units to account for the control symbols in the embedding table
+    units += 4
+
+    # eos_idx = 2 in the VocabularyInfo setting for base pretssel_vocoder
+    unit_eos_token = torch.tensor([2], device=device)
+    units = torch.cat([units, unit_eos_token], dim=0)
+    units, duration = torch.unique_consecutive(units, return_counts=True)
+
+    # adjust for the last eos token
+    duration[-1] = 0
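+    # Each unit is assumed to span two vocoder frames, hence the doubling below.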
+    duration *= 2
+
+    # bos_idx=0 in base VocabularyInfo
+    duration_collate = Collater(pad_value=0)
+    duration_seqs = duration_collate(duration)
+
+    with torch.no_grad():
+        vocoder.eval()
+        wav_wm = vocoder(
+            seqs=units,
+            tgt_lang="fra",
+            prosody_input_seqs=feat,
+            durations=duration_seqs["seqs"],
+            normalize_before=True,
+        )
+
+    # torchaudio.save("wm.wav", wav_wm.squeeze(0).float().cpu(), sample_rate=16000)
+
+    # Run the non-watermarked vocoder using pretssel generator
+    gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
+    gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)  # type: ignore[assignment]
+    gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)  # type: ignore[assignment]
+
+    generator = PretsselGenerator(
+        "seamless_expressivity",
+        "vocoder_mel_24khz",
+        "pretssel_v1",
+        gcmvn_mean=gcmvn_mean,  # type: ignore[arg-type]
+        gcmvn_std=gcmvn_std,  # type: ignore[arg-type]
+        device=device,
+        dtype=dtype,
+    )
+
+    # PretsselGenerator expects a batch of units
+    unit_list: List[List[int]] = [REF_FRA_UNITS]
+    prosody_input_seqs = SequenceData(
+        is_ragged=False,
+        seqs=feat.unsqueeze(0),  # add batch dim
+        seq_lens=torch.tensor([feat.size(0)]),
+    )
+    speech_output = generator.predict(
+        unit_list,
+        tgt_lang="fra",
+        prosody_encoder_input=prosody_input_seqs,
+    )
+    wav = speech_output.audio_wavs[0].unsqueeze(0)
+
+    # torchaudio.save("mel.wav", wav.float().cpu(), sample_rate=16000)
+
+    # Run the watermark model separately after the PretsselGenerator
+    watermarker = load_watermarking_model()
+    wm = watermarker.get_watermark(wav)  # type: ignore
+    wav_wm_hat = wav + wm
+
+    # Test that the watermark is detectable
+    detection = watermarker.detect_watermark(wav_wm)  # type: ignore
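+    # Channel 1 of the detector output is taken as the per-frame probability that
+    # the audio is watermarked (assumption inferred from the 0.5 threshold below).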
+    assert torch.all(detection[:, 1, :] > 0.5)
+
+    # Remove the batch dimension and compare the two outputs on the overlapping frames
+    wav_wm = wav_wm.squeeze(0)
+    wav_wm_hat = wav_wm_hat.squeeze(0)
+
+    nframes = min(wav_wm_hat.size(1), wav_wm.size(1))
+    assert_close(
+        wav_wm[:, :nframes],
+        wav_wm_hat[:, :nframes],
+        atol=0.0,
+        rtol=5.0,
+    )