test_watermarked_vocoder.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import sys
  7. from pathlib import Path
  8. from typing import Final, List, Optional, cast
  9. import torch
  10. from fairseq2.data import Collater, SequenceData
  11. from fairseq2.data.audio import AudioDecoderOutput
  12. from fairseq2.typing import Device
  13. from torch.nn import Module
  14. from seamless_communication.inference.pretssel_generator import PretsselGenerator
  15. from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
  16. from seamless_communication.models.unity.loader import load_gcmvn_stats
  17. from tests.common import assert_close, convert_to_collated_fbank
# Presumably the number of mel filterbank bins. NOTE(review): not referenced by
# the visible code in this module — confirm it is still needed.
N_MEL_BINS = 80

# Reference French discrete unit ids used as input to both vocoding paths of the
# watermarking test in this module. They are stored before the control-symbol
# offset: the watermarked-vocoder path adds 4 to each id, while
# PretsselGenerator receives the list unchanged.
# fmt: off
REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
# fmt: on
  22. def load_watermarking_model() -> Optional[Module]:
  23. import importlib.util
  24. # Run in CPU mode until pretssel inconsistent behavious is fixed
  25. device = Device("cpu")
  26. dtype = torch.float32
  27. wm_py_file = Path(__file__).parents[3] / "scripts/watermarking/watermarking.py"
  28. assert wm_py_file.is_file()
  29. wm_spec = importlib.util.spec_from_file_location("watermark.f1", wm_py_file)
  30. assert wm_spec, f"Module not found: {wm_py_file}"
  31. wm_py_module = importlib.util.module_from_spec(wm_spec)
  32. assert wm_py_module, f"Invalid Python module file: {wm_py_file}"
  33. sys.modules["watermark.f1"] = wm_py_module
  34. assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
  35. wm_spec.loader.exec_module(wm_py_module)
  36. return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
  37. def test_pretssel_vocoder_watermarking(
  38. example_rate16k_audio: AudioDecoderOutput,
  39. ) -> None:
  40. """
  41. Test that the watermarked pretssel vocoder generates the same output
  42. as the non-watermarked (pretssel_generator)
  43. """
  44. audio = example_rate16k_audio
  45. # Run in CPU mode until pretssel inconsistent behavious is fixed
  46. device = Device("cpu")
  47. dtype = torch.float32
  48. audio["waveform"] = audio["waveform"].to(device, dtype=dtype)
  49. feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
  50. feat = feat.to(device, dtype=dtype)
  51. # Run the watermarked vocoding
  52. # TODO: Build a generator API for the watermarked vocoder
  53. vocoder = load_pretssel_vocoder_model(
  54. "vocoder_pretssel", device=device, dtype=dtype
  55. )
  56. units = torch.tensor(REF_FRA_UNITS, device=device, dtype=torch.int64)
  57. # adjust the control symbols for the embedding
  58. units += 4
  59. # eos_idx = 2 in the VocabularyInfo setting for base pretssel_vocoder
  60. unit_eos_token = torch.tensor([2], device=device)
  61. units = torch.cat([units, unit_eos_token], dim=0)
  62. units, duration = torch.unique_consecutive(units, return_counts=True)
  63. # adjust for the last eos token
  64. duration[-1] = 0
  65. duration *= 2
  66. # bos_idx=0 in base VocabularyInfo
  67. duration_collate = Collater(pad_value=0)
  68. duration_seqs = duration_collate(duration)
  69. with torch.no_grad():
  70. vocoder.eval()
  71. wav_wm = vocoder(
  72. seqs=units,
  73. tgt_lang="fra",
  74. prosody_input_seqs=feat,
  75. durations=duration_seqs["seqs"],
  76. normalize_before=True,
  77. )
  78. # torchaudio.save("wm.wav", wav_wm.squeeze(0).float().cpu(), sample_rate=16000)
  79. # Run the non-watermarked vocoder using pretssel generator
  80. gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
  81. gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype) # type: ignore[assignment]
  82. gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype) # type: ignore[assignment]
  83. generator = PretsselGenerator(
  84. "seamless_expressivity",
  85. "vocoder_mel_24khz",
  86. "pretssel_v1",
  87. gcmvn_mean=gcmvn_mean, # type: ignore[arg-type]
  88. gcmvn_std=gcmvn_std, # type: ignore[arg-type]
  89. device=device,
  90. dtype=dtype,
  91. )
  92. # PretsselGenerator expects a batch of units
  93. unit_list: List[List[int]] = [REF_FRA_UNITS]
  94. prosody_input_seqs = SequenceData(
  95. is_ragged=False,
  96. seqs=feat.unsqueeze(0), # add batch dim
  97. seq_lens=torch.tensor([feat.size(0)]),
  98. )
  99. speech_output = generator.predict(
  100. unit_list,
  101. tgt_lang="fra",
  102. prosody_encoder_input=prosody_input_seqs,
  103. )
  104. wav = speech_output.audio_wavs[0].unsqueeze(0)
  105. # torchaudio.save("mel.wav", wav.float().cpu(), sample_rate=16000)
  106. # Run the watermark model separately after the PretsselGenerator
  107. watermarker = load_watermarking_model()
  108. wm = watermarker.get_watermark(wav) # type: ignore
  109. wav_wm_hat = wav + wm
  110. # Test that the watermark is detecte-able
  111. detection = watermarker.detect_watermark(wav_wm) # type: ignore
  112. assert torch.all(detection[:, 1, :] > 0.5)
  113. # Remove the batch and compare parity on the overlapping frames
  114. wav_wm = wav_wm.squeeze(0)
  115. wav_wm_hat = wav_wm_hat.squeeze(0)
  116. nframes = min(wav_wm_hat.size(1), wav_wm.size(1))
  117. assert_close(
  118. wav_wm[:, :nframes],
  119. wav_wm_hat[:, :nframes],
  120. atol=0.0,
  121. rtol=5.0,
  122. )