test_watermarked_vocoder.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import sys
  7. from argparse import Namespace
  8. from pathlib import Path
  9. from typing import Final, List, Optional, cast
  10. import os
  11. import pytest
  12. import torch
  13. from fairseq2.data import SequenceData, VocabularyInfo
  14. from fairseq2.data.audio import AudioDecoderOutput
  15. from fairseq2.typing import Device
  16. from torch.nn import Module
  17. from seamless_communication.inference import Translator
  18. from seamless_communication.inference.pretssel_generator import PretsselGenerator
  19. from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
  20. PretsselGenerator as WatermarkedPretsselGenerator,
  21. )
  22. from seamless_communication.cli.expressivity.evaluate.pretssel_inference import (
  23. build_data_pipeline,
  24. )
  25. from seamless_communication.models.unity import load_gcmvn_stats
  26. from tests.common import assert_close, convert_to_collated_fbank
  27. N_MEL_BINS = 80
  28. WM_WEIGHT = 0.8
  29. # fmt: off
  30. REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
  31. # fmt: on
  32. def load_watermarking_model() -> Optional[Module]:
  33. import importlib.util
  34. # Run in CPU mode until pretssel inconsistent behavious is fixed
  35. device = Device("cpu")
  36. dtype = torch.float32
  37. wm_py_file = Path(__file__).parents[3] / "scripts/watermarking/watermarking.py"
  38. assert wm_py_file.is_file()
  39. wm_spec = importlib.util.spec_from_file_location("watermark.f1", wm_py_file)
  40. assert wm_spec, f"Module not found: {wm_py_file}"
  41. wm_py_module = importlib.util.module_from_spec(wm_spec)
  42. assert wm_py_module, f"Invalid Python module file: {wm_py_file}"
  43. sys.modules["watermark.f1"] = wm_py_module
  44. assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
  45. wm_spec.loader.exec_module(wm_py_module)
  46. ckpt = os.getenv("SEAMLESS_WM_CKPT", "")
  47. return cast(Module, wm_py_module.model_from_checkpoint(device=device, checkpoint=ckpt, dtype=dtype))
  48. @pytest.mark.parametrize("sr", [16_000, 24_000])
  49. def test_pretssel_vocoder_watermarking(
  50. example_rate16k_audio: AudioDecoderOutput, sr: int
  51. ) -> None:
  52. """
  53. Test that the watermarked pretssel vocoder generates the same output
  54. as the non-watermarked (pretssel_generator)
  55. """
  56. # Run in CPU mode until pretssel inconsistent behavious is fixed
  57. device = Device("cpu")
  58. dtype = torch.float32
  59. audio = example_rate16k_audio
  60. audio["waveform"] = audio["waveform"].to(device, dtype=dtype)
  61. feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
  62. tgt_lang = "fra"
  63. feat = feat.to(device, dtype=dtype)
  64. gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
  65. gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype) # type: ignore[assignment]
  66. gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype) # type: ignore[assignment]
  67. if sr == 16_000:
  68. vocoder_model_name = "vocoder_mel"
  69. pretssel_vocoder_model_name = "vocoder_pretssel_16khz"
  70. else:
  71. vocoder_model_name = "vocoder_mel_24khz"
  72. pretssel_vocoder_model_name = "vocoder_pretssel"
  73. # non-watermarked vocoder using pretssel generator in inference
  74. generator = PretsselGenerator(
  75. "seamless_expressivity",
  76. vocoder_model_name,
  77. "pretssel_v1",
  78. gcmvn_mean=gcmvn_mean, # type: ignore[arg-type]
  79. gcmvn_std=gcmvn_std, # type: ignore[arg-type]
  80. device=device,
  81. dtype=dtype,
  82. )
  83. # watermarked vocoder using pretssel generator in the evaluation
  84. vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
  85. wm_generator = WatermarkedPretsselGenerator(
  86. pretssel_vocoder_model_name,
  87. vocab_info=vocab_info,
  88. device=device,
  89. dtype=dtype,
  90. )
  91. unit_list: List[List[int]] = [REF_FRA_UNITS]
  92. prosody_input_seqs = SequenceData(
  93. is_ragged=False,
  94. seqs=feat.unsqueeze(0), # add batch dim
  95. seq_lens=torch.tensor([feat.size(0)]),
  96. )
  97. # Run the non-watermark vocoder, followed by a watermarker
  98. speech_output = generator.predict(
  99. unit_list,
  100. tgt_lang=tgt_lang,
  101. prosody_encoder_input=prosody_input_seqs,
  102. )
  103. wav = speech_output.audio_wavs[0].unsqueeze(0)
  104. watermarker = load_watermarking_model()
  105. wm = watermarker.get_watermark(wav) # type: ignore
  106. wav_wm_hat = wav + WM_WEIGHT * wm
  107. # Run the watermarked vocoder
  108. wm_speech_output = wm_generator.predict(
  109. unit_list,
  110. tgt_lang=tgt_lang,
  111. prosody_encoder_input=prosody_input_seqs,
  112. )
  113. wav_wm = wm_speech_output.audio_wavs[0]
  114. # Test that the watermark is detectable
  115. detection = watermarker.detect_watermark(wav_wm) # type: ignore
  116. # 0.9 is the current lower bound of Watermarking w.r.t all attacks
  117. assert torch.count_nonzero(torch.gt(detection[:, 1, :], 0.5)) / detection.shape[-1] > 0.9
  118. # Remove the batch and compare parity on the overlapping frames
  119. wav_wm = wav_wm.squeeze(0)
  120. wav_wm_hat = wav_wm_hat.squeeze(0)
  121. nframes = min(wav_wm_hat.size(1), wav_wm.size(1))
  122. assert_close(
  123. wav_wm[:, :nframes],
  124. wav_wm_hat[:, :nframes],
  125. atol=0.0,
  126. rtol=5.0,
  127. )
  128. @pytest.mark.skip(reason="Skip this test since it's extremely slow.")
  129. def test_e2e_watermark_audio() -> None:
  130. data_file = "/large_experiments/seamless/data/expressivity/fairseq_manifest/benchmark_20231025/test_examples_20231122.tsv"
  131. model_name = "seamless_expressivity"
  132. # Run in CPU mode until pretssel inconsistent behavious is fixed
  133. device = Device("cpu")
  134. dtype = torch.float32
  135. gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
  136. gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype) # type: ignore[assignment]
  137. gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype) # type: ignore[assignment]
  138. args = Namespace(data_file=data_file, audio_root_dir="", batch_size=4)
  139. pipeline = build_data_pipeline(
  140. args, device=device, dtype=dtype, gcmvn_mean=gcmvn_mean, gcmvn_std=gcmvn_std # type: ignore[arg-type]
  141. )
  142. translator = Translator(model_name, None, device=device, dtype=dtype)
  143. # no watermark
  144. generator = PretsselGenerator(
  145. "seamless_expressivity",
  146. "vocoder_mel_24khz",
  147. "pretssel_v1",
  148. gcmvn_mean=gcmvn_mean, # type: ignore[arg-type]
  149. gcmvn_std=gcmvn_std, # type: ignore[arg-type]
  150. device=device,
  151. dtype=dtype,
  152. )
  153. watermarker = load_watermarking_model()
  154. # watermark
  155. vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
  156. wm_generator = WatermarkedPretsselGenerator(
  157. "vocoder_pretssel",
  158. vocab_info=vocab_info,
  159. device=device,
  160. dtype=dtype,
  161. )
  162. sample_id = 0
  163. for batch in pipeline:
  164. feat = batch["audio"]["data"]["fbank"]
  165. prosody_encoder_input = batch["audio"]["data"]["gcmvn_fbank"]
  166. text_output, unit_out = translator.predict(
  167. feat,
  168. task_str="s2st",
  169. tgt_lang="spa",
  170. prosody_encoder_input=prosody_encoder_input,
  171. )
  172. assert unit_out, "empty translation output"
  173. speech_out = generator.predict(
  174. units=unit_out.units,
  175. tgt_lang="spa",
  176. prosody_encoder_input=prosody_encoder_input,
  177. )
  178. wm_speech_out = wm_generator.predict(
  179. units=unit_out.units,
  180. tgt_lang="spa",
  181. prosody_encoder_input=prosody_encoder_input,
  182. )
  183. for i in range(len(text_output)):
  184. wav_wm = wm_speech_out.audio_wavs[i].squeeze(0)
  185. wav = speech_out.audio_wavs[i].unsqueeze(0)
  186. wm = watermarker.get_watermark(wav) # type: ignore
  187. wav_wm_hat = wav + 0.8 * wm
  188. wav_wm_hat = wav_wm_hat.squeeze(0)
  189. assert_close(wav_wm, wav_wm_hat, atol=0.01, rtol=5.0)
  190. sample_id += 1