test_watermarked_vocoder.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. import sys
  7. from argparse import Namespace
  8. from pathlib import Path
  9. from typing import Final, List, Optional, cast
  10. import pytest
  11. import torch
  12. from fairseq2.data import SequenceData, VocabularyInfo
  13. from fairseq2.data.audio import AudioDecoderOutput
  14. from fairseq2.typing import Device
  15. from torch.nn import Module
  16. from seamless_communication.inference import Translator
  17. from seamless_communication.inference.pretssel_generator import PretsselGenerator
  18. from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
  19. PretsselGenerator as WatermarkedPretsselGenerator,
  20. )
  21. from seamless_communication.cli.expressivity.evaluate.pretssel_inference import (
  22. build_data_pipeline,
  23. )
  24. from seamless_communication.models.unity import load_gcmvn_stats
  25. from tests.common import assert_close, convert_to_collated_fbank
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably the number of mel filterbank bins -- confirm before removing.
N_MEL_BINS = 80

# Scale applied to the watermark signal before it is added to the
# non-watermarked waveform when building the reference ("hat") output.
WM_WEIGHT = 0.8

# Fixed unit sequence fed to both generators in
# test_pretssel_vocoder_watermarking (with target language "fra").
# Kept on a single line; formatting is disabled so it is not reflowed.
# fmt: off
REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
# fmt: on
  31. def load_watermarking_model() -> Optional[Module]:
  32. import importlib.util
  33. # Run in CPU mode until pretssel inconsistent behavious is fixed
  34. device = Device("cpu")
  35. dtype = torch.float32
  36. wm_py_file = Path(__file__).parents[3] / "scripts/watermarking/watermarking.py"
  37. assert wm_py_file.is_file()
  38. wm_spec = importlib.util.spec_from_file_location("watermark.f1", wm_py_file)
  39. assert wm_spec, f"Module not found: {wm_py_file}"
  40. wm_py_module = importlib.util.module_from_spec(wm_spec)
  41. assert wm_py_module, f"Invalid Python module file: {wm_py_file}"
  42. sys.modules["watermark.f1"] = wm_py_module
  43. assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
  44. wm_spec.loader.exec_module(wm_py_module)
  45. return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
  46. @pytest.mark.parametrize("sr", [16_000, 24_000])
  47. def test_pretssel_vocoder_watermarking(
  48. example_rate16k_audio: AudioDecoderOutput, sr: int
  49. ) -> None:
  50. """
  51. Test that the watermarked pretssel vocoder generates the same output
  52. as the non-watermarked (pretssel_generator)
  53. """
  54. # Run in CPU mode until pretssel inconsistent behavious is fixed
  55. device = Device("cpu")
  56. dtype = torch.float32
  57. audio = example_rate16k_audio
  58. audio["waveform"] = audio["waveform"].to(device, dtype=dtype)
  59. feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
  60. tgt_lang = "fra"
  61. feat = feat.to(device, dtype=dtype)
  62. gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
  63. gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype) # type: ignore[assignment]
  64. gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype) # type: ignore[assignment]
  65. if sr == 16_000:
  66. vocoder_model_name = "vocoder_mel"
  67. pretssel_vocoder_model_name = "vocoder_pretssel_16khz"
  68. else:
  69. vocoder_model_name = "vocoder_mel_24khz"
  70. pretssel_vocoder_model_name = "vocoder_pretssel"
  71. # non-watermarked vocoder using pretssel generator in inference
  72. generator = PretsselGenerator(
  73. "seamless_expressivity",
  74. vocoder_model_name,
  75. "pretssel_v1",
  76. gcmvn_mean=gcmvn_mean, # type: ignore[arg-type]
  77. gcmvn_std=gcmvn_std, # type: ignore[arg-type]
  78. device=device,
  79. dtype=dtype,
  80. )
  81. # watermarked vocoder using pretssel generator in the evaluation
  82. vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
  83. wm_generator = WatermarkedPretsselGenerator(
  84. pretssel_vocoder_model_name,
  85. vocab_info=vocab_info,
  86. device=device,
  87. dtype=dtype,
  88. )
  89. unit_list: List[List[int]] = [REF_FRA_UNITS]
  90. prosody_input_seqs = SequenceData(
  91. is_ragged=False,
  92. seqs=feat.unsqueeze(0), # add batch dim
  93. seq_lens=torch.tensor([feat.size(0)]),
  94. )
  95. # Run the non-watermark vocoder, followed by a watermarker
  96. speech_output = generator.predict(
  97. unit_list,
  98. tgt_lang=tgt_lang,
  99. prosody_encoder_input=prosody_input_seqs,
  100. )
  101. wav = speech_output.audio_wavs[0].unsqueeze(0)
  102. watermarker = load_watermarking_model()
  103. wm = watermarker.get_watermark(wav) # type: ignore
  104. wav_wm_hat = wav + WM_WEIGHT * wm
  105. # Run the watermarked vocoder
  106. wm_speech_output = wm_generator.predict(
  107. unit_list,
  108. tgt_lang=tgt_lang,
  109. prosody_encoder_input=prosody_input_seqs,
  110. )
  111. wav_wm = wm_speech_output.audio_wavs[0]
  112. # Test that the watermark is detectable
  113. detection = watermarker.detect_watermark(wav_wm) # type: ignore
  114. assert torch.all(detection[:, 1, :] > 0.5)
  115. # Remove the batch and compare parity on the overlapping frames
  116. wav_wm = wav_wm.squeeze(0)
  117. wav_wm_hat = wav_wm_hat.squeeze(0)
  118. nframes = min(wav_wm_hat.size(1), wav_wm.size(1))
  119. assert_close(
  120. wav_wm[:, :nframes],
  121. wav_wm_hat[:, :nframes],
  122. atol=0.0,
  123. rtol=5.0,
  124. )
  125. def test_e2e_watermark_audio() -> None:
  126. data_file = "/large_experiments/seamless/data/expressivity/fairseq_manifest/benchmark_20231025/test_examples_20231122.tsv"
  127. model_name = "seamless_expressivity"
  128. # Run in CPU mode until pretssel inconsistent behavious is fixed
  129. device = Device("cpu")
  130. dtype = torch.float32
  131. gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
  132. gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype) # type: ignore[assignment]
  133. gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype) # type: ignore[assignment]
  134. args = Namespace(data_file=data_file, audio_root_dir="", batch_size=4)
  135. pipeline = build_data_pipeline(
  136. args, device=device, dtype=dtype, gcmvn_mean=gcmvn_mean, gcmvn_std=gcmvn_std # type: ignore[arg-type]
  137. )
  138. translator = Translator(model_name, None, device=device, dtype=dtype)
  139. # no watermark
  140. generator = PretsselGenerator(
  141. "seamless_expressivity",
  142. "vocoder_mel_24khz",
  143. "pretssel_v1",
  144. gcmvn_mean=gcmvn_mean, # type: ignore[arg-type]
  145. gcmvn_std=gcmvn_std, # type: ignore[arg-type]
  146. device=device,
  147. dtype=dtype,
  148. )
  149. watermarker = load_watermarking_model()
  150. # watermark
  151. vocab_info = VocabularyInfo(size=10004, unk_idx=3, bos_idx=0, eos_idx=2, pad_idx=1)
  152. wm_generator = WatermarkedPretsselGenerator(
  153. "vocoder_pretssel",
  154. vocab_info=vocab_info,
  155. device=device,
  156. dtype=dtype,
  157. )
  158. sample_id = 0
  159. for batch in pipeline:
  160. feat = batch["audio"]["data"]["fbank"]
  161. prosody_encoder_input = batch["audio"]["data"]["gcmvn_fbank"]
  162. text_output, unit_out = translator.predict(
  163. feat,
  164. task_str="s2st",
  165. tgt_lang="spa",
  166. prosody_encoder_input=prosody_encoder_input,
  167. )
  168. assert unit_out, "empty translation output"
  169. speech_out = generator.predict(
  170. units=unit_out.units,
  171. tgt_lang="spa",
  172. prosody_encoder_input=prosody_encoder_input,
  173. )
  174. wm_speech_out = wm_generator.predict(
  175. units=unit_out.units,
  176. tgt_lang="spa",
  177. prosody_encoder_input=prosody_encoder_input,
  178. )
  179. for i in range(len(text_output)):
  180. wav_wm = wm_speech_out.audio_wavs[i].squeeze(0)
  181. wav = speech_out.audio_wavs[i].unsqueeze(0)
  182. wm = watermarker.get_watermark(wav) # type: ignore
  183. wav_wm_hat = wav + 0.8 * wm
  184. wav_wm_hat = wav_wm_hat.squeeze(0)
  185. assert_close(wav_wm, wav_wm_hat, atol=0.01, rtol=5.0)
  186. sample_id += 1