
Update watermarked vocoder code and checkpoint (#156)

* add watermark checkpoint argument

* update checkpoint loading logic in watermarking

* update chkpt

* update default wm ckpt

* revert the linting in test_expressivity

* update checkpoint

* 1) update pretssel vocoder and pretssel inference with batch prosody encoder; 2) down-tune watermark

* make watermarked PretsselGenerator a separate API

* address Yilin's and Pierre's comments

---------

Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Tuan Tran committed 1 year ago
commit 300d14efc5

+ 14 - 7
scripts/watermarking/compile_chkpt.py

@@ -128,7 +128,9 @@ def wm_key_map() -> Mapping[Any, Any]:
     }
 
 
-def combine_chkpts(pretssel_file: str, vocoder_file: str, out_path: str) -> None:
+def combine_chkpts(
+    pretssel_file: str, vocoder_file: str, wm_file: str, out_path: str
+) -> None:
     """Combine the pretssel and melhifigan into one model"""
     pretssel_chkpt = load_checkpoint(pretssel_file)
     pretssel_chkpt = convert_fairseq_checkpoint(pretssel_chkpt, pretssel_key_map())
@@ -136,10 +138,10 @@ def combine_chkpts(pretssel_file: str, vocoder_file: str, out_path: str) -> None
     vocoder_chkpt = load_checkpoint(vocoder_file)
     vocoder_chkpt = convert_fairseq_checkpoint(vocoder_chkpt, vocoder_key_map())
 
-    wm_ckpt = load_checkpoint(
-        "/large_experiments/seamless/nllb/watermarking/checkpoints/ckpt_e9d0008c.th",
-    )
-    # wm_ckpt is not a fairseq2 checkpoint so we have to handle it differently
+    wm_ckpt = load_checkpoint(wm_file)
+    # some wm checkpoints are not fairseq2 checkpoints, so we have to handle them differently
+    if "model" in wm_ckpt:
+        wm_ckpt = wm_ckpt["model"]
     wm_ckpt = convert_model_state_dict(wm_ckpt, wm_key_map())
 
     # Merge the state dicts
@@ -170,7 +172,6 @@ def combine_chkpts(pretssel_file: str, vocoder_file: str, out_path: str) -> None
         if key in state_dict:
             del state_dict[key]
 
-    out_path = "/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt"
     model_mapping_metafile = Path(out_path).with_suffix(".arch")
     with open(model_mapping_metafile, "w", encoding="utf-8") as o:
         o.write(vocoder_key_map.__doc__)  # type: ignore
@@ -196,6 +197,12 @@ if __name__ == "__main__":
         type=str,
         help="Path to the mel-vocoder checkpoint",
     )
+    parser.add_argument(
+        "--wm",
+        default="/checkpoint/hadyelsahar/experiments/audiocraft/outputs/xps/BA6f05be46/checkpoint.th",
+        type=str,
+        help="Path to the watermark checkpoint",
+    )
     parser.add_argument(
         "--output",
         default="/large_experiments/seamless/workstream/expressivity/oss/checkpoints/pretssel_melhifigan_wm-final.pt",
@@ -205,4 +212,4 @@ if __name__ == "__main__":
     )
     # fmt: on
     args = parser.parse_args()
-    combine_chkpts(args.pretssel, args.vocoder, args.output)
+    combine_chkpts(args.pretssel, args.vocoder, args.wm, args.output)
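
For context, a minimal sketch (not part of the patch) of the checkpoint-unwrapping
pattern that the new --wm argument relies on; the helper name is hypothetical:

    import torch

    def load_wm_state_dict(wm_file: str) -> dict:
        # audiocraft-style watermark checkpoints wrap the weights under a "model"
        # key, while fairseq2-style checkpoints keep the state dict at the top level
        ckpt = torch.load(wm_file, map_location="cpu")
        if "model" in ckpt:
            ckpt = ckpt["model"]
        return ckpt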

+ 1 - 4
scripts/watermarking/seamlesswatermark.yaml

@@ -7,10 +7,7 @@
 
 name: seamlesswatermark
 model_type: seanet
-checkpoint: "/large_experiments/seamless/nllb/watermarking/checkpoints/ckpt_e9d0008c.th"
-watermarker_model:
-  channels: 1
-  sample_rate: 16000
+checkpoint: "/checkpoint/hadyelsahar/experiments/audiocraft/outputs/xps/6f05be46/checkpoint.th"
 seanet:
   activation: ELU
   activation_params:
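
Since the watermarker_model block (channels / sample_rate) is gone, the config now
only carries the checkpoint path plus the seanet/detector settings; a hedged sketch
of reading it, assuming the script is run from scripts/watermarking:

    import omegaconf

    cfg = omegaconf.OmegaConf.load("seamlesswatermark.yaml")
    # only the checkpoint path and the seanet/detector settings are left to read
    print(cfg["checkpoint"])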

+ 22 - 21
scripts/watermarking/watermarking.py

@@ -10,7 +10,7 @@
 import math
 from argparse import ArgumentParser, ArgumentTypeError
 from pathlib import Path
-from typing import Any, Dict, Union, cast
+from typing import Any, Dict, Optional, Union, cast
 
 import audiocraft
 import omegaconf
@@ -61,12 +61,8 @@ class Watermarker(nn.Module):
         encoder (nn.Module): Watermark Encoder.
         decoder (nn.Module): Watermark Decoder.
         detector (nn.Module): Watermark Detector.
-        sample_rate (int): Audio sample rate.
-        channels (int): Number of audio channels.
     """
 
-    sample_rate: int = 0
-    channels: int = 0
     encoder: SEANetEncoder
     decoder: SEANetEncoder
     detector: SEANetEncoderKeepDimension
@@ -76,15 +72,11 @@ class Watermarker(nn.Module):
         encoder: SEANetEncoder,
         decoder: SEANetEncoder,
         detector: SEANetEncoderKeepDimension,
-        sample_rate: int,
-        channels: int,
     ):
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
         self.detector = detector
-        self.sample_rate = sample_rate
-        self.channels = channels
 
     def get_watermark(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -122,8 +114,7 @@ class Watermarker(nn.Module):
 
 
 def model_from_checkpoint(
-    checkpoint_path: Union[Path, str] = Path(__file__).parent
-    / "seamlesswatermark.yaml",
+    config_file: Union[Path, str] = "seamlesswatermark.yaml",
     device: Union[torch.device, str] = "cpu",
     dtype: DataType = torch.float32,
 ) -> Watermarker:
@@ -137,8 +128,9 @@ def model_from_checkpoint(
     >>> wav, _ = torchaudio.load("random.wav")
     >>> wav = wav.unsqueeze(0)  # add bsz dimension
 
-    # code starts here
-    >>> model = model_from_checkpoint(cfg, device = wav.device)
+    >>> model = model_from_config(cfg, device = wav.device)
+    # Another way is to load directly from the checkpoint
+    >>> model = model_from_checkpoint(checkpoint_path, device = wav.device)
 
     >>> watermark = model.get_watermark(wav)
 
@@ -157,8 +149,13 @@ def model_from_checkpoint(
     Returns:
         Watermarker: An instance of the Watermarker model loaded from the checkpoint.
     """
-    cfg = omegaconf.OmegaConf.load(checkpoint_path)
+    config_path = Path(__file__).parent / config_file
+    cfg = omegaconf.OmegaConf.load(config_path)
     state: Dict[str, Any] = torch.load(cfg["checkpoint"])
+    if "model" in state and "xp.cfg" in state:
+        cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
+        omegaconf.OmegaConf.resolve(cfg)
+        state = state["model"]
     watermarking_model = get_watermarking_model(cfg)
     watermarking_model.load_state_dict(state)
     watermarking_model = watermarking_model.to(device, dtype=dtype)
@@ -167,10 +164,9 @@ def model_from_checkpoint(
 
 
 def get_watermarking_model(cfg: omegaconf.DictConfig) -> Watermarker:
-    kwargs = dict_from_config(getattr(cfg, "watermarker_model"))
     encoder, decoder = get_encodec_autoencoder(cfg)
     detector = get_detector(cfg)
-    return Watermarker(encoder, decoder, detector, **kwargs)
+    return Watermarker(encoder, decoder, detector)
 
 
 def get_encodec_autoencoder(cfg: omegaconf.DictConfig):
@@ -209,10 +205,8 @@ def parse_device_arg(value: str) -> Device:
 
 
 if __name__ == "__main__":
-    """
-    Example usage:
-    python watermarking.py --device cuda:0 detect [file.wav]
-    """
+    # Example usage: python watermarking.py --device cuda:0 detect [file.wav]
+
     parser = ArgumentParser(description="Handle the watermarking for audios")
     parser.add_argument(
         "--device",
@@ -220,6 +214,12 @@ if __name__ == "__main__":
         type=parse_device_arg,
         help="device on which to run tests (default: %(default)s)",
     )
+    parser.add_argument(
+        "--model-file",
+        default="seamlesswatermark.yaml",
+        type=str,
+        help="path to a config or checkpoint file (default: %(default)s)",
+    )
     sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
     detect_parser = sub_parser.add_parser("detect")
     wm_parser = sub_parser.add_parser("wm")
@@ -228,9 +228,10 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     if args.sub_cmd == "detect":
-        model = model_from_checkpoint(device=args.device)
+        model = model_from_checkpoint(args.model_file, device=args.device)
         wav, _ = torchaudio.load(args.file)
         wav = wav.unsqueeze(0)
         wav = wav.to(args.device)
         detection = model.detect_watermark(wav)
         print(detection[:, 1, :])
+        print(torch.count_nonzero(torch.gt(detection[:, 1, :], 0.5)))
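
A hedged usage sketch of the updated detection entry point, roughly what
"python watermarking.py --model-file seamlesswatermark.yaml detect file.wav" now
does; the audio file name is a placeholder:

    import torch
    import torchaudio

    from watermarking import model_from_checkpoint  # assumes scripts/watermarking is on the path

    model = model_from_checkpoint("seamlesswatermark.yaml", device="cpu")
    wav, _ = torchaudio.load("file.wav")  # (channels, samples)
    wav = wav.unsqueeze(0)                # add a batch dimension
    detection = model.detect_watermark(wav)
    # channel 1 carries the per-frame watermark probability; count frames above 0.5
    print(torch.count_nonzero(torch.gt(detection[:, 1, :], 0.5)))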

+ 1 - 1
src/seamless_communication/cli/expressivity/evaluate/evaluate.py

@@ -196,7 +196,7 @@ def run_eval(
     output_path.mkdir(parents=True, exist_ok=True)
 
     if ctx.output_modality == Modality.SPEECH:
-        waveforms_dir = output_path / f"waveform"
+        waveforms_dir = output_path / "waveform"
         waveforms_dir.mkdir(parents=True, exist_ok=True)
 
     hyps = []

+ 14 - 99
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference.py

@@ -4,33 +4,34 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Optional
 import argparse
 import contextlib
 import logging
 from argparse import Namespace
-from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from tqdm import tqdm
 
 import torch
 import torchaudio
-from fairseq2.assets import asset_store
-from fairseq2.assets.card import AssetCard
-from fairseq2.data import Collater, DataPipeline, FileMapper, SequenceData
+from fairseq2.data import (
+    Collater,
+    DataPipeline,
+    FileMapper,
+)
 from fairseq2.data.audio import (
     AudioDecoder,
     WaveformToFbankConverter,
     WaveformToFbankOutput,
 )
-from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
-from fairseq2.generation import SequenceGeneratorOptions
-from fairseq2.nn.padding import get_seqs_and_padding_mask
+from fairseq2.data.text import StrSplitter, read_text
 from fairseq2.typing import DataType, Device
 from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
 from torch import Tensor
-from torch.nn import Module
-from tqdm import tqdm
 
+from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
+    PretsselGenerator,
+)
 from seamless_communication.cli.m4t.evaluate.evaluate import (
     adjust_output_for_corrupted_inputs,
     count_lines,
@@ -40,11 +41,8 @@ from seamless_communication.cli.m4t.predict import (
     set_generation_opts,
 )
 from seamless_communication.inference import BatchedSpeechOutput, Translator
-from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
 from seamless_communication.models.unity import (
-    UnitTokenizer,
     load_gcmvn_stats,
-    load_unity_text_tokenizer,
     load_unity_unit_tokenizer,
 )
 
@@ -56,87 +54,8 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-class PretsselGenerator(Module):
-    def __init__(
-        self,
-        pretssel_name_or_card: str,
-        unit_tokenizer: UnitTokenizer,
-        device: Device,
-        dtype: DataType = torch.float16,
-    ):
-        super().__init__()
-        # Load the model.
-        if device == torch.device("cpu"):
-            dtype = torch.float32
-
-        self.device = device
-        self.dtype = dtype
-
-        self.pretssel_model = load_pretssel_vocoder_model(
-            pretssel_name_or_card,
-            device=device,
-            dtype=dtype,
-        )
-        self.pretssel_model.eval()
-
-        vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
-        self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
-
-        self.unit_tokenizer = unit_tokenizer
-        self.unit_collate = Collater(pad_value=unit_tokenizer.vocab_info.pad_idx)
-        self.duration_collate = Collater(pad_value=0)
-
-    @torch.inference_mode()
-    def predict(
-        self,
-        units: List[List[int]],
-        tgt_lang: str,
-        prosody_encoder_input: SequenceData,
-    ) -> BatchedSpeechOutput:
-        audio_wavs = []
-        unit_eos_token = torch.tensor(
-            [self.unit_tokenizer.vocab_info.eos_idx],
-            device=self.device,
-        )
-
-        prosody_input_seqs = prosody_encoder_input["seqs"]
-        prosody_input_lens = prosody_encoder_input["seq_lens"]
-
-        for i, u in enumerate(units):
-            unit = torch.tensor(u).to(unit_eos_token)
-
-            # adjust the control symbols for the embedding
-            unit += 4
-            unit = torch.cat([unit, unit_eos_token], dim=0)
-
-            unit, duration = torch.unique_consecutive(unit, return_counts=True)
-
-            # adjust for the last eos token
-            duration[-1] = 0
-
-            duration *= 2
-
-            prosody_input_seq = prosody_input_seqs[i][: prosody_input_lens[i]]
-
-            audio_wav = self.pretssel_model(
-                unit,
-                tgt_lang,
-                prosody_input_seq,
-                durations=duration.unsqueeze(0),
-            )
-
-            audio_wavs.append(audio_wav)
-
-        return BatchedSpeechOutput(
-            units=units,
-            audio_wavs=audio_wavs,
-            sample_rate=self.output_sample_rate,
-        )
-
-
 def build_data_pipeline(
     args: Namespace,
-    text_tokenizer: TextTokenizer,
     device: Device,
     dtype: DataType,
     gcmvn_mean: Tensor,
@@ -232,22 +151,18 @@ def main() -> None:
         device = torch.device("cpu")
         dtype = torch.float32
 
-    text_tokenizer = load_unity_text_tokenizer(args.model_name)
     unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
 
     _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
     gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
     gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
 
-    pipeline = build_data_pipeline(
-        args, text_tokenizer, device, dtype, gcmvn_mean, gcmvn_std
-    )
+    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
 
     translator = Translator(
         args.model_name,
         vocoder_name_or_card=None,
         device=device,
-        text_tokenizer=text_tokenizer,
         dtype=dtype,
     )
 
@@ -261,7 +176,7 @@ def main() -> None:
 
     pretssel_generator = PretsselGenerator(
         args.vocoder_name,
-        unit_tokenizer=unit_tokenizer,
+        vocab_info=unit_tokenizer.vocab_info,
         device=device,
         dtype=dtype,
     )
@@ -272,7 +187,7 @@ def main() -> None:
     output_path = args.output_path / args.data_file.stem
     output_path.mkdir(parents=True, exist_ok=True)
 
-    waveforms_dir = output_path / f"waveform"
+    waveforms_dir = output_path / "waveform"
     waveforms_dir.mkdir(parents=True, exist_ok=True)
 
     hyps = []

+ 100 - 0
src/seamless_communication/cli/expressivity/evaluate/pretssel_inference_helper.py

@@ -0,0 +1,100 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List
+
+import torch
+from torch.nn import Module
+
+from fairseq2.typing import DataType, Device
+
+from fairseq2.assets import asset_store
+from fairseq2.data import (
+    Collater,
+    SequenceData,
+    VocabularyInfo,
+)
+from fairseq2.nn.padding import get_seqs_and_padding_mask
+
+from seamless_communication.inference import BatchedSpeechOutput
+from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
+
+
+class PretsselGenerator(Module):
+    def __init__(
+        self,
+        pretssel_name_or_card: str,
+        vocab_info: VocabularyInfo,
+        device: Device,
+        dtype: DataType = torch.float16,
+    ):
+        super().__init__()
+        # Load the model.
+        if device == torch.device("cpu"):
+            dtype = torch.float32
+
+        self.device = device
+        self.dtype = dtype
+
+        self.pretssel_model = load_pretssel_vocoder_model(
+            pretssel_name_or_card,
+            device=device,
+            dtype=dtype,
+        )
+        self.pretssel_model.eval()
+
+        vocoder_model_card = asset_store.retrieve_card(pretssel_name_or_card)
+        self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
+
+        self.vocab_info = vocab_info
+        self.unit_collate = Collater(pad_value=vocab_info.pad_idx)
+        self.duration_collate = Collater(pad_value=0)
+        self.unit_eos_token = torch.tensor([vocab_info.eos_idx], device=device)
+
+    @torch.inference_mode()
+    def predict(
+        self,
+        units: List[List[int]],
+        tgt_lang: str,
+        prosody_encoder_input: SequenceData,
+    ) -> BatchedSpeechOutput:
+
+        units_batch, durations = [], []
+        for u in units:
+            unit = torch.tensor(u).to(self.unit_eos_token)
+
+            # adjust the control symbols for the embedding
+            unit += 4
+            unit = torch.cat([unit, self.unit_eos_token], dim=0)
+
+            unit, duration = torch.unique_consecutive(unit, return_counts=True)
+
+            # adjust for the last eos token
+            duration[-1] = 0
+
+            units_batch.append(unit)
+            durations.append(duration * 2)
+
+        speech_units = self.unit_collate(units_batch)
+        durations = self.duration_collate(durations)["seqs"]
+
+        units_tensor, unit_padding_mask = get_seqs_and_padding_mask(speech_units)
+        prosody_input_seqs, prosody_padding_mask = get_seqs_and_padding_mask(
+            prosody_encoder_input
+        )
+
+        audio_wavs = self.pretssel_model(
+            units_tensor,
+            tgt_lang,
+            prosody_input_seqs,
+            padding_mask=unit_padding_mask,
+            prosody_padding_mask=prosody_padding_mask,
+            durations=durations,
+        )
+        return BatchedSpeechOutput(
+            units=units,
+            audio_wavs=audio_wavs,
+            sample_rate=self.output_sample_rate,
+        )
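
A hedged usage sketch of the new batched generator, mirroring the integration test
further below; the fbank tensor and unit IDs are placeholders, and the
VocabularyInfo values are copied from that test:

    import torch
    from fairseq2.data import SequenceData, VocabularyInfo
    from fairseq2.typing import Device

    from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
        PretsselGenerator,
    )

    vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
    generator = PretsselGenerator(
        "vocoder_pretssel", vocab_info=vocab_info, device=Device("cpu"), dtype=torch.float32
    )

    feat = torch.randn(200, 80)         # placeholder fbank features, (frames, n_mel_bins)
    units = [[8976, 6589, 6589, 5736]]  # placeholder unit IDs, one list per sample
    prosody_input = SequenceData(
        is_ragged=False, seqs=feat.unsqueeze(0), seq_lens=torch.tensor([feat.size(0)])
    )
    out = generator.predict(units, tgt_lang="fra", prosody_encoder_input=prosody_input)
    wav_wm = out.audio_wavs[0]          # watermarked waveform for the first sample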

+ 4 - 3
src/seamless_communication/inference/pretssel_generator.py

@@ -3,14 +3,14 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import torch
 import torch.nn as nn
 from fairseq2.assets import asset_store
 from fairseq2.assets.card import AssetCard
 from fairseq2.data import Collater, SequenceData
-from fairseq2.nn.padding import PaddingMask, get_seqs_and_padding_mask
+from fairseq2.nn.padding import get_seqs_and_padding_mask
 from fairseq2.typing import DataType, Device
 from torch import Tensor
 
@@ -53,7 +53,8 @@ class PretsselGenerator(nn.Module):
             dtype=dtype,
         )
         self.pretssel_model.eval()
-
+        if isinstance(vocoder_name_or_card, AssetCard):
+            vocoder_name_or_card = vocoder_name_or_card.name
         vocoder_model_card = asset_store.retrieve_card(vocoder_name_or_card)
         self.output_sample_rate = vocoder_model_card.field("sample_rate").as_(int)
 

+ 63 - 59
src/seamless_communication/models/generator/vocoder.py

@@ -484,9 +484,9 @@ class PretsselVocoder(Module):
         duration_factor: float = 1.0,
         min_duration: int = 0,
         normalize_before: bool = True,
-    ) -> torch.Tensor:
+    ) -> List[torch.Tensor]:
         # Here we are adding batch dimension for the pretssel
-        if seqs.ndim < 3:
+        if seqs.ndim < 2:
             seqs = seqs.unsqueeze(0)
         if prosody_input_seqs.ndim < 3:
             prosody_input_seqs = prosody_input_seqs.unsqueeze(0)
@@ -510,63 +510,67 @@ class PretsselVocoder(Module):
         pn = pn.transpose(1, 2)
 
         x = seqs + pn
-        x = self.gcmvn_denormalize(x).squeeze(0)
-        if normalize_before:
-            x = (x - self.mean) / self.scale
-
-        x = x.transpose(1, 0).unsqueeze(0)
-        chunk_size = self.n_streams // 4
-        x = self.layers[self.pn_layers + chunk_size](x)
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.layers[i + self.pn_layers + 1 + 2 * chunk_size](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.layers[
-                        i * self.num_kernels
-                        + j
-                        + self.pn_layers
-                        + 3 * chunk_size
-                        + self.num_upsamples
-                        + 1
-                    ](x)
-                else:
-                    xs += self.layers[
-                        i * self.num_kernels
-                        + j
-                        + self.pn_layers
-                        + 3 * chunk_size
-                        + self.num_upsamples
-                        + 1
-                    ](x)
-            x = xs / self.num_kernels  # type: ignore
-        x = F.leaky_relu(x)
-        x = self.layers[
-            self.pn_layers
-            + self.n_streams
-            + self.num_upsamples * (1 + self.num_kernels)
-            + 1
-        ](x)
-        skip_output = x
-        h = skip_output
-
-        for i1 in range(self.pn_layers, self.pn_layers + chunk_size):
-            h = self.layers[i1](h)
-        i1 += 2
-        for i2 in range(i1, i1 + chunk_size):
-            h = self.layers[i2](h)
-        i2 = i2 + self.num_upsamples + 1
-
-        for i3 in range(i2, i2 + chunk_size):
-            h = self.layers[i3](h)
-        i3 = i3 + (self.num_upsamples * self.num_kernels) + 1
-        for i4 in range(i3, i3 + chunk_size):
-            h = self.layers[i4](h)
-        h = h[:, :, : x.size(-1)]
-
-        h += torch.tanh(skip_output).squeeze(0)
-        return h
+        x = self.gcmvn_denormalize(x)
+
+        wavs = []
+        for idx, _x in enumerate(x):
+            _x = _x[: durations[idx].sum()]  # type: ignore[index]
+            if normalize_before:
+                _x = (_x - self.mean) / self.scale
+
+            _x = _x.transpose(1, 0).unsqueeze(0)
+            chunk_size = self.n_streams // 4
+            _x = self.layers[self.pn_layers + chunk_size](_x)
+            for i in range(self.num_upsamples):
+                _x = F.leaky_relu(_x, LRELU_SLOPE)
+                _x = self.layers[i + self.pn_layers + 1 + 2 * chunk_size](_x)
+                xs = None
+                for j in range(self.num_kernels):
+                    if xs is None:
+                        xs = self.layers[
+                            i * self.num_kernels
+                            + j
+                            + self.pn_layers
+                            + 3 * chunk_size
+                            + self.num_upsamples
+                            + 1
+                        ](_x)
+                    else:
+                        xs += self.layers[
+                            i * self.num_kernels
+                            + j
+                            + self.pn_layers
+                            + 3 * chunk_size
+                            + self.num_upsamples
+                            + 1
+                        ](_x)
+                _x = xs / self.num_kernels  # type: ignore
+            _x = F.leaky_relu(_x)
+            _x = self.layers[
+                self.pn_layers
+                + self.n_streams
+                + self.num_upsamples * (1 + self.num_kernels)
+                + 1
+            ](_x)
+            skip_output = _x
+            h = skip_output
+
+            for i1 in range(self.pn_layers, self.pn_layers + chunk_size):
+                h = self.layers[i1](h)
+            i1 += 2
+            for i2 in range(i1, i1 + chunk_size):
+                h = self.layers[i2](h)
+            i2 = i2 + self.num_upsamples + 1
+
+            for i3 in range(i2, i2 + chunk_size):
+                h = self.layers[i3](h)
+            i3 = i3 + (self.num_upsamples * self.num_kernels) + 1
+            for i4 in range(i3, i3 + chunk_size):
+                h = self.layers[i4](h)
+            h = h[:, :, : _x.size(-1)]
+
+            wavs.append(0.8 * h + torch.tanh(skip_output).squeeze(0))
+        return wavs
 
     def remove_weight_norm(self) -> None:
         i = self.pn_layers + 1

+ 1 - 3
src/seamless_communication/models/pretssel/pretssel_model.py

@@ -4,11 +4,9 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, Optional, Tuple
 
 import torch
-from fairseq2.models.utils.arch_registry import ArchitectureRegistry
 from fairseq2.nn.embedding import Embedding, StandardEmbedding
 from fairseq2.nn.padding import PaddingMask
 from fairseq2.nn.position_encoder import PositionEncoder

+ 124 - 50
tests/integration/models/test_watermarked_vocoder.py

@@ -5,21 +5,30 @@
 # LICENSE file in the root directory of this source tree.
 
 import sys
+from argparse import Namespace
 from pathlib import Path
 from typing import Final, List, Optional, cast
+import pytest
 
 import torch
-from fairseq2.data import Collater, SequenceData
+from fairseq2.data import SequenceData, VocabularyInfo
 from fairseq2.data.audio import AudioDecoderOutput
 from fairseq2.typing import Device
 from torch.nn import Module
 
+from seamless_communication.inference import Translator
 from seamless_communication.inference.pretssel_generator import PretsselGenerator
-from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
-from seamless_communication.models.unity.loader import load_gcmvn_stats
+from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
+    PretsselGenerator as WatermarkedPretsselGenerator,
+)
+from seamless_communication.cli.expressivity.evaluate.pretssel_inference import (
+    build_data_pipeline,
+)
+from seamless_communication.models.unity import load_gcmvn_stats
 from tests.common import assert_close, convert_to_collated_fbank
 
 N_MEL_BINS = 80
+WM_WEIGHT = 0.8
 
 # fmt: off
 REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
@@ -45,65 +54,40 @@ def load_watermarking_model() -> Optional[Module]:
     return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
 
 
+@pytest.mark.parametrize("sr", [16_000, 24_000])
 def test_pretssel_vocoder_watermarking(
-    example_rate16k_audio: AudioDecoderOutput,
+    example_rate16k_audio: AudioDecoderOutput, sr: int
 ) -> None:
     """
     Test that the watermarked pretssel vocoder generates the same output
     as the non-watermarked (pretssel_generator)
     """
-    audio = example_rate16k_audio
-
     # Run in CPU mode until pretssel inconsistent behaviour is fixed
     device = Device("cpu")
     dtype = torch.float32
+
+    audio = example_rate16k_audio
     audio["waveform"] = audio["waveform"].to(device, dtype=dtype)
     feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
-    feat = feat.to(device, dtype=dtype)
-    # Run the watermarked vocoding
-    # TODO: Build a generator API for the watermarked vocoder
-    vocoder = load_pretssel_vocoder_model(
-        "vocoder_pretssel", device=device, dtype=dtype
-    )
-
-    units = torch.tensor(REF_FRA_UNITS, device=device, dtype=torch.int64)
-
-    # adjust the control symbols for the embedding
-    units += 4
+    tgt_lang = "fra"
 
-    # eos_idx = 2 in the VocabularyInfo setting for base pretssel_vocoder
-    unit_eos_token = torch.tensor([2], device=device)
-    units = torch.cat([units, unit_eos_token], dim=0)
-    units, duration = torch.unique_consecutive(units, return_counts=True)
-
-    # adjust for the last eos token
-    duration[-1] = 0
-    duration *= 2
-
-    # bos_idx=0 in base VocabularyInfo
-    duration_collate = Collater(pad_value=0)
-    duration_seqs = duration_collate(duration)
-
-    with torch.no_grad():
-        vocoder.eval()
-        wav_wm = vocoder(
-            seqs=units,
-            tgt_lang="fra",
-            prosody_input_seqs=feat,
-            durations=duration_seqs["seqs"],
-            normalize_before=True,
-        )
-
-    # torchaudio.save("wm.wav", wav_wm.squeeze(0).float().cpu(), sample_rate=16000)
+    feat = feat.to(device, dtype=dtype)
 
-    # Run the non-watermarked vocoder using pretssel generator
     gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
     gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)  # type: ignore[assignment]
     gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)  # type: ignore[assignment]
 
+    if sr == 16_000:
+        vocoder_model_name = "vocoder_mel"
+        pretssel_vocoder_model_name = "vocoder_pretssel_16khz"
+    else:
+        vocoder_model_name = "vocoder_mel_24khz"
+        pretssel_vocoder_model_name = "vocoder_pretssel"
+
+    # non-watermarked vocoder using pretssel generator in inference
     generator = PretsselGenerator(
         "seamless_expressivity",
-        "vocoder_mel_24khz",
+        vocoder_model_name,
         "pretssel_v1",
         gcmvn_mean=gcmvn_mean,  # type: ignore[arg-type]
         gcmvn_std=gcmvn_std,  # type: ignore[arg-type]
@@ -111,28 +95,43 @@ def test_pretssel_vocoder_watermarking(
         dtype=dtype,
     )
 
-    # PretsselGenerator expects a batch of units
+    # watermarked vocoder using pretssel generator in the evaluation
+    vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
+    wm_generator = WatermarkedPretsselGenerator(
+        pretssel_vocoder_model_name,
+        vocab_info=vocab_info,
+        device=device,
+        dtype=dtype,
+    )
+
     unit_list: List[List[int]] = [REF_FRA_UNITS]
     prosody_input_seqs = SequenceData(
         is_ragged=False,
         seqs=feat.unsqueeze(0),  # add batch dim
         seq_lens=torch.tensor([feat.size(0)]),
     )
+
+    # Run the non-watermark vocoder, followed by a watermarker
     speech_output = generator.predict(
         unit_list,
-        tgt_lang="fra",
+        tgt_lang=tgt_lang,
         prosody_encoder_input=prosody_input_seqs,
     )
     wav = speech_output.audio_wavs[0].unsqueeze(0)
 
-    # torchaudio.save("mel.wav", wav.float().cpu(), sample_rate=16000)
-
-    # Run the watermark model separately after the PretsselGenerator
     watermarker = load_watermarking_model()
     wm = watermarker.get_watermark(wav)  # type: ignore
-    wav_wm_hat = wav + wm
+    wav_wm_hat = wav + WM_WEIGHT * wm
 
-    # Test that the watermark is detecte-able
+    # Run the watermarked vocoder
+    wm_speech_output = wm_generator.predict(
+        unit_list,
+        tgt_lang=tgt_lang,
+        prosody_encoder_input=prosody_input_seqs,
+    )
+    wav_wm = wm_speech_output.audio_wavs[0]
+
+    # Test that the watermark is detectable
     detection = watermarker.detect_watermark(wav_wm)  # type: ignore
     assert torch.all(detection[:, 1, :] > 0.5)
 
@@ -147,3 +146,78 @@ def test_pretssel_vocoder_watermarking(
         atol=0.0,
         rtol=5.0,
     )
+
+
+def test_e2e_watermark_audio() -> None:
+    data_file = "/large_experiments/seamless/data/expressivity/fairseq_manifest/benchmark_20231025/test_examples_20231122.tsv"
+    model_name = "seamless_expressivity"
+
+    # Run in CPU mode until pretssel inconsistent behaviour is fixed
+    device = Device("cpu")
+    dtype = torch.float32
+
+    gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
+    gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)  # type: ignore[assignment]
+    gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)  # type: ignore[assignment]
+
+    args = Namespace(data_file=data_file, audio_root_dir="", batch_size=4)
+    pipeline = build_data_pipeline(
+        args, device=device, dtype=dtype, gcmvn_mean=gcmvn_mean, gcmvn_std=gcmvn_std  # type: ignore[arg-type]
+    )
+    translator = Translator(model_name, None, device=device, dtype=dtype)
+
+    # no watermark
+    generator = PretsselGenerator(
+        "seamless_expressivity",
+        "vocoder_mel_24khz",
+        "pretssel_v1",
+        gcmvn_mean=gcmvn_mean,  # type: ignore[arg-type]
+        gcmvn_std=gcmvn_std,  # type: ignore[arg-type]
+        device=device,
+        dtype=dtype,
+    )
+    watermarker = load_watermarking_model()
+
+    # watermark
+    vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
+
+    wm_generator = WatermarkedPretsselGenerator(
+        "vocoder_pretssel",
+        vocab_info=vocab_info,
+        device=device,
+        dtype=dtype,
+    )
+
+    sample_id = 0
+    for batch in pipeline:
+        feat = batch["audio"]["data"]["fbank"]
+        prosody_encoder_input = batch["audio"]["data"]["gcmvn_fbank"]
+
+        text_output, unit_out = translator.predict(
+            feat,
+            task_str="s2st",
+            tgt_lang="spa",
+            prosody_encoder_input=prosody_encoder_input,
+        )
+        assert unit_out, "empty translation output"
+
+        speech_out = generator.predict(
+            units=unit_out.units,
+            tgt_lang="spa",
+            prosody_encoder_input=prosody_encoder_input,
+        )
+
+        wm_speech_out = wm_generator.predict(
+            units=unit_out.units,
+            tgt_lang="spa",
+            prosody_encoder_input=prosody_encoder_input,
+        )
+
+        for i in range(len(text_output)):
+            wav_wm = wm_speech_out.audio_wavs[i].squeeze(0)
+            wav = speech_out.audio_wavs[i].unsqueeze(0)
+            wm = watermarker.get_watermark(wav)  # type: ignore
+            wav_wm_hat = wav + 0.8 * wm
+            wav_wm_hat = wav_wm_hat.squeeze(0)
+            assert_close(wav_wm, wav_wm_hat, atol=0.01, rtol=5.0)
+            sample_id += 1