test_watermarked_vocoder.py

# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import sys
from pathlib import Path
from typing import Final, List, Optional, cast

import torch
from torch.nn import Module

from fairseq2.data import Collater, SequenceData
from fairseq2.data.audio import AudioDecoderOutput
from fairseq2.typing import Device

from seamless_communication.inference.pretssel_generator import PretsselGenerator
from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
from seamless_communication.models.unity.loader import load_gcmvn_stats

from tests.common import (
    assert_close,
    convert_to_collated_fbank,
)

N_MEL_BINS = 80

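# Reference discrete speech units for a French ("fra") utterance; they drive
# both vocoder paths in the test below.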
# fmt: off
REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
# fmt: on


def load_watermarking_model() -> Optional[Module]:
    import importlib.util

    # Run in CPU mode until the pretssel inconsistent behaviour is fixed
    device = Device("cpu")
    dtype = torch.float32

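    # The watermarking model ships as a standalone script rather than an
    # installed package, so load it directly from its file path via importlib.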
    wm_py_file = Path(__file__).parents[3] / "scripts/watermarking/watermarking.py"
    assert wm_py_file.is_file()

    wm_spec = importlib.util.spec_from_file_location("watermark.f1", wm_py_file)
    assert wm_spec, f"Module not found: {wm_py_file}"

    wm_py_module = importlib.util.module_from_spec(wm_spec)
    assert wm_py_module, f"Invalid Python module file: {wm_py_file}"

    sys.modules["watermark.f1"] = wm_py_module

    assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
    wm_spec.loader.exec_module(wm_py_module)

    return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))


def test_pretssel_vocoder_watermarking(
    example_rate16k_audio: AudioDecoderOutput,
) -> None:
    """
    Test that the watermarked pretssel vocoder generates the same output
    as the non-watermarked path (PretsselGenerator followed by the standalone
    watermarking model).
    """
    audio = example_rate16k_audio

    # Run in CPU mode until the pretssel inconsistent behaviour is fixed
    device = Device("cpu")
    dtype = torch.float32

    audio["waveform"] = audio["waveform"].to(device, dtype=dtype)

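    # Extract collated mel-filterbank features from the example audio; they
    # condition the prosody encoder in both vocoder paths.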
    feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
    feat = feat.to(device, dtype=dtype)

    # Run the watermarked vocoding
    # TODO: Build a generator API for the watermarked vocoder
    vocoder = load_pretssel_vocoder_model(
        "vocoder_pretssel", device=device, dtype=dtype
    )

    units = torch.tensor(REF_FRA_UNITS, device=device, dtype=torch.int64)

    # offset the raw unit ids past the control symbols in the embedding vocabulary
    units += 4

    # eos_idx = 2 in the VocabularyInfo setting for the base pretssel_vocoder
    unit_eos_token = torch.tensor([2], device=device)
    units = torch.cat([units, unit_eos_token], dim=0)

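    # Collapse consecutive repeated units and take the run lengths as the
    # per-unit durations fed to the vocoder.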
    units, duration = torch.unique_consecutive(units, return_counts=True)

    # adjust for the last eos token
    duration[-1] = 0
    duration *= 2

    # bos_idx=0 in base VocabularyInfo
    duration_collate = Collater(pad_value=0)
    duration_seqs = duration_collate(duration)

    with torch.no_grad():
        vocoder.eval()
        wav_wm = vocoder(
            seqs=units,
            tgt_lang="fra",
            prosody_input_seqs=feat,
            durations=duration_seqs["seqs"],
            normalize_before=True,
        )
    # torchaudio.save("wm.wav", wav_wm.squeeze(0).float().cpu(), sample_rate=16000)

    # Run the non-watermarked vocoder using pretssel generator
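    # gcmvn stats: global mean/variance normalization applied to the prosody
    # (fbank) input inside the generator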
    gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
    gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)  # type: ignore[assignment]
    gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)  # type: ignore[assignment]

    generator = PretsselGenerator(
        "seamless_expressivity",
        "vocoder_mel_24khz",
        "pretssel_v1",
        gcmvn_mean=gcmvn_mean,  # type: ignore[arg-type]
        gcmvn_std=gcmvn_std,  # type: ignore[arg-type]
        device=device,
        dtype=dtype,
    )

    # PretsselGenerator expects a batch of units
    unit_list: List[List[int]] = [REF_FRA_UNITS]
    prosody_input_seqs = SequenceData(
        is_ragged=False,
        seqs=feat.unsqueeze(0),  # add batch dim
        seq_lens=torch.tensor([feat.size(0)]),
    )
    speech_output = generator.predict(
        unit_list,
        tgt_lang="fra",
        prosody_encoder_input=prosody_input_seqs,
    )
    wav = speech_output.audio_wavs[0].unsqueeze(0)
    # torchaudio.save("mel.wav", wav.float().cpu(), sample_rate=16000)

    # Run the watermark model separately after the PretsselGenerator
    watermarker = load_watermarking_model()
    wm = watermarker.get_watermark(wav)  # type: ignore
    wav_wm_hat = wav + wm

    # Test that the watermark is detectable in the vocoder's watermarked output
    detection = watermarker.detect_watermark(wav_wm)  # type: ignore
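    # channel 1 of the detector output is assumed to carry the per-frame
    # probability that a watermark is present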
    assert torch.all(detection[:, 1, :] > 0.5)

    # Remove the batch dim and compare parity on the overlapping frames
    wav_wm = wav_wm.squeeze(0)
    wav_wm_hat = wav_wm_hat.squeeze(0)
    nframes = min(wav_wm_hat.size(1), wav_wm.size(1))
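    # Loose relative tolerance: the built-in and post-hoc watermarking paths
    # are only expected to match approximately, not sample-for-sample.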
    assert_close(
        wav_wm[:, :nframes],
        wav_wm_hat[:, :nframes],
        atol=0.0,
        rtol=5.0,
    )