123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152 |
- # Copyright (c) Meta Platforms, Inc. and affiliates
- # All rights reserved.
- #
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- import sys
- from typing import cast, List, Final, Optional
- from anyio import Path
- import torch
- from fairseq2.typing import Device
- from fairseq2.data import Collater, SequenceData
- from fairseq2.data.audio import AudioDecoderOutput
- from torch.nn import Module
- from seamless_communication.inference.pretssel_generator import PretsselGenerator
- from seamless_communication.models.unity.loader import load_gcmvn_stats
- from seamless_communication.models.generator.loader import load_pretssel_vocoder_model
- from tests.common import (
- assert_close,
- convert_to_collated_fbank,
- )
# Number of mel filterbank bins for the prosody input features.
# NOTE(review): not referenced in this file's visible code — presumably kept
# to document the expected feature dimensionality; confirm before removing.
N_MEL_BINS = 80

# Fixed reference sequence of discrete speech units (French) fed to both the
# watermarked vocoder and the plain PretsselGenerator so their outputs can be
# compared on identical input.
# fmt: off
REF_FRA_UNITS: Final = [8976, 6589, 6589, 5736, 7542, 6515, 1240, 8335, 2381, 1076, 1076, 3380, 4085, 8207, 7957, 4446, 2641, 2544, 5552, 5529, 6319, 2779, 2890, 2890, 3229, 3303, 9751, 1979, 664, 1859, 1302, 528, 1303, 9543, 5770, 3532, 1286, 1286, 1727, 9287, 5248, 5586, 594, 3385, 2613, 1717, 7529, 7634, 931, 1602, 4512, 850, 2748, 5056, 1086, 2320, 2320, 9320, 3223, 5592, 1122, 419, 24, 4126, 5200, 2712, 9549, 8676, 8676, 3443, 7598, 7598, 2200, 2745, 1215, 118, 3840, 2703, 1616, 8788, 1240, 3349, 4890, 2756, 166, 9574, 9773, 5887, 2516, 9332, 6092, 3377, 4334, 3127, 3127, 3127, 944, 3089, 5947, 6572, 6572, 7561, 4358, 4358, 4358, 8124, 5549, 9275, 82, 8830, 8830, 5949, 22, 6729, 6878, 3817, 1871, 6092, 1441, 3127, 3928, 8254, 7984, 1116, 2796, 1806, 3710, 797, 9269, 576, 576, 2020, 137, 6624, 3815, 8690, 3634, 6036, 3530, 8719, 3458, 138, 8745, 5233, 2235, 8580, 8580, 6831, 2709, 7136, 9693, 3437, 3437, 3238, 4368, 2321, 2321, 391, 391, 4976, 8622, 6722, 3864, 9113, 9113, 7222, 7222, 7937, 999, 1286, 1286, 7789, 9396, 9603, 6690, 5233, 2235, 618, 8830, 6954, 3668, 4302, 596, 1934, 2886, 2704, 9097, 4161, 458, 4147, 9245, 9245, 3127, 3127, 944, 9676, 9676, 3468, 270, 270, 4608, 5549, 4182, 102, 8568, 1286, 1286, 5087, 817, 4153, 207, 207, 3763, 6415, 5188, 6010, 554, 753, 9953, 5104, 3828, 1879, 995, 9683, 6932, 3644, 2683, 9335, 183, 5525, 7023, 9568, 6222, 6315, 676, 3443, 6971, 2084, 999, 1286, 1286, 9620, 9620, 1048, 5577, 9328, 4963, 1364, 8328, 4573, 4573, 7917, 7917, 560, 2020, 4923, 137, 9542, 5832, 9775, 4780, 9400, 2745, 2745, 8984, 628, 8834, 6932, 3817, 8312, 5393, 458, 4147, 9191, 2225, 2759, 8980, 2351, 193, 1476, 9347, 3063, 2076, 3641, 1614, 9832, 3554, 8197, 5589, 5589, 7306, 184, 1708, 2954, 2954, 3485, 3485, 7665, 8909, 5405, 3590, 3590, 3446, 6442, 6442, 2802, 5549, 3791]
# fmt: on
def load_watermarking_model() -> Optional[Module]:
    """Dynamically load the watermarking model from the repo's scripts tree.

    The watermarking code lives outside the installed package
    (``scripts/watermarking/watermarking.py``), so it is imported directly
    from its file path, registered in ``sys.modules`` under the name
    ``"watermark.f1"``, and its ``model_from_checkpoint`` entry point is
    called to build the model.

    Returns:
        The watermarking ``Module`` on CPU in float32.
    """
    import importlib.util

    # BUGFIX: use pathlib, not the module-level `anyio.Path`. anyio's
    # `Path.is_file()` is a coroutine function, so asserting on its (truthy)
    # coroutine object would always pass without ever checking that the
    # file exists.
    from pathlib import Path

    # Run in CPU mode until pretssel inconsistent behaviour is fixed
    device = Device("cpu")
    dtype = torch.float32

    wm_py_file = Path(__file__).parents[3] / "scripts/watermarking/watermarking.py"
    assert wm_py_file.is_file(), f"File not found: {wm_py_file}"

    wm_spec = importlib.util.spec_from_file_location("watermark.f1", wm_py_file)
    assert wm_spec, f"Module not found: {wm_py_file}"

    wm_py_module = importlib.util.module_from_spec(wm_spec)
    assert wm_py_module, f"Invalid Python module file: {wm_py_file}"

    # Register before executing so the module is importable by name while
    # its top-level code runs.
    sys.modules["watermark.f1"] = wm_py_module
    assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
    wm_spec.loader.exec_module(wm_py_module)

    return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
def test_pretssel_vocoder_watermarking(
    example_rate16k_audio: AudioDecoderOutput,
) -> None:
    """
    Test that the watermarked pretssel vocoder generates the same output
    as the non-watermarked (pretssel_generator).

    The same reference units (REF_FRA_UNITS) and prosody features are
    vocoded twice: once through the watermarked vocoder directly, and once
    through the plain PretsselGenerator with the watermark applied
    afterwards. The two waveforms are compared over their overlapping
    frames, and the watermark itself must be detectable.
    """
    audio = example_rate16k_audio

    # Run in CPU mode until pretssel inconsistent behaviour is fixed
    device = Device("cpu")
    dtype = torch.float32

    audio["waveform"] = audio["waveform"].to(device, dtype=dtype)
    # Filterbank features of the example audio; take the first (only)
    # sequence out of the collated batch.
    feat = convert_to_collated_fbank(audio, dtype=dtype)["seqs"][0]
    feat = feat.to(device, dtype=dtype)

    # --- Path 1: run the watermarked vocoding ---
    # TODO: Build a generator API for the watermarked vocoder
    vocoder = load_pretssel_vocoder_model(
        "vocoder_pretssel", device=device, dtype=dtype
    )
    units = torch.tensor(REF_FRA_UNITS, device=device, dtype=torch.int64)
    # adjust the control symbols for the embedding (offset raw unit ids past
    # the special tokens)
    units += 4
    # eos_idx = 2 in the VocabularyInfo setting for base pretssel_vocoder
    unit_eos_token = torch.tensor([2], device=device)
    units = torch.cat([units, unit_eos_token], dim=0)
    # Collapse consecutive repeats into unique units + per-unit counts,
    # which the vocoder consumes as durations.
    units, duration = torch.unique_consecutive(units, return_counts=True)
    # adjust for the last eos token (it carries no duration)
    duration[-1] = 0
    # NOTE(review): presumably scales unit counts to the vocoder's frame
    # rate — confirm against the vocoder's duration convention.
    duration *= 2
    # bos_idx=0 in base VocabularyInfo
    duration_collate = Collater(pad_value=0)
    duration_seqs = duration_collate(duration)
    with torch.no_grad():
        vocoder.eval()
        wav_wm = vocoder(
            seqs=units,
            tgt_lang="fra",
            prosody_input_seqs=feat,
            durations=duration_seqs["seqs"],
            normalize_before=True,
        )
    # torchaudio.save("wm.wav", wav_wm.squeeze(0).float().cpu(), sample_rate=16000)

    # --- Path 2: run the non-watermarked vocoder using pretssel generator ---
    gcmvn_mean, gcmvn_std = load_gcmvn_stats("pretssel_v1")
    gcmvn_mean = torch.tensor(gcmvn_mean, device=device, dtype=dtype)  # type: ignore[assignment]
    gcmvn_std = torch.tensor(gcmvn_std, device=device, dtype=dtype)  # type: ignore[assignment]
    generator = PretsselGenerator(
        "seamless_expressivity",
        "vocoder_mel_24khz",
        "pretssel_v1",
        gcmvn_mean=gcmvn_mean,  # type: ignore[arg-type]
        gcmvn_std=gcmvn_std,  # type: ignore[arg-type]
        device=device,
        dtype=dtype,
    )
    # PretsselGenerator expects a batch of units
    unit_list: List[List[int]] = [REF_FRA_UNITS]
    prosody_input_seqs = SequenceData(
        is_ragged=False,
        seqs=feat.unsqueeze(0),  # add batch dim
        seq_lens=torch.tensor([feat.size(0)]),
    )
    speech_output = generator.predict(
        unit_list,
        tgt_lang="fra",
        prosody_encoder_input=prosody_input_seqs,
    )
    wav = speech_output.audio_wavs[0].unsqueeze(0)
    # torchaudio.save("mel.wav", wav.float().cpu(), sample_rate=16000)

    # Run the watermark model separately after the PretsselGenerator
    watermarker = load_watermarking_model()
    wm = watermarker.get_watermark(wav)  # type: ignore
    wav_wm_hat = wav + wm

    # Test that the watermark is detectable
    detection = watermarker.detect_watermark(wav_wm)  # type: ignore
    assert torch.all(detection[:, 1, :] > 0.5)

    # Remove the batch dim and compare parity on the overlapping frames
    wav_wm = wav_wm.squeeze(0)
    wav_wm_hat = wav_wm_hat.squeeze(0)
    nframes = min(wav_wm_hat.size(1), wav_wm.size(1))
    assert_close(
        wav_wm[:, :nframes],
        wav_wm_hat[:, :nframes],
        atol=0.0,
        rtol=5.0,
    )
|