watermarking.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. # The original implementation for the watermarking
  7. # This is not open-sourced and only kept here for future reference
  8. # mypy: ignore-errors
  9. import math
  10. from argparse import ArgumentParser, ArgumentTypeError
  11. from pathlib import Path
  12. from typing import Any, Dict, Union
  13. import audiocraft
  14. import omegaconf
  15. import torch
  16. import torch.nn as nn
  17. import torchaudio
  18. from audiocraft.modules.seanet import SEANetEncoder
  19. from audiocraft.utils.utils import dict_from_config
  20. from fairseq2.typing import DataType, Device
  21. class SEANetEncoderKeepDimension(SEANetEncoder):
  22. """
  23. similar architecture to the SEANet encoder but with an extra step that
  24. projects the output dimension to the same input dimension by repeating
  25. the sequential
  26. Args:
  27. SEANetEncoder (_type_): _description_
  28. """
  29. def __init__(self, output_hidden_dim: int = 8, *args, **kwargs): # type: ignore
  30. super().__init__(*args, **kwargs)
  31. self.output_hidden_dim = output_hidden_dim
  32. # Adding a reverse convolution layer
  33. self.reverse_convolution = nn.ConvTranspose1d(
  34. in_channels=self.dimension,
  35. out_channels=self.output_hidden_dim,
  36. kernel_size=math.prod(self.ratios),
  37. stride=math.prod(self.ratios),
  38. padding=0,
  39. )
  40. def forward(self, x: torch.Tensor) -> torch.Tensor:
  41. orig_nframes = x.shape[-1]
  42. x = self.model(x)
  43. x = self.reverse_convolution(x)
  44. # make sure dim didn't change
  45. x = x[:, :, :orig_nframes]
  46. return x
  47. class Watermarker(nn.Module):
  48. """
  49. Initialize the Watermarker model.
  50. Args:
  51. encoder (nn.Module): Watermark Encoder.
  52. decoder (nn.Module): Watermark Decoder.
  53. detector (nn.Module): Watermark Detector.
  54. """
  55. encoder: SEANetEncoder
  56. decoder: SEANetEncoder
  57. detector: SEANetEncoderKeepDimension
  58. def __init__(
  59. self,
  60. encoder: SEANetEncoder,
  61. decoder: SEANetEncoder,
  62. detector: SEANetEncoderKeepDimension,
  63. ):
  64. super().__init__()
  65. self.encoder = encoder
  66. self.decoder = decoder
  67. self.detector = detector
  68. def get_watermark(self, x: torch.Tensor) -> torch.Tensor:
  69. """
  70. Get the watermark from a batch of audio input.
  71. Args:
  72. x (torch.Tensor): Input audio tensor with dimensions [batch size, channels = 1, frames].
  73. Returns:
  74. torch.Tensor: Output watermark with the same dimensionality as the input.
  75. """
  76. hidden = self.encoder(x)
  77. # assert dim in = dim out
  78. watermark = self.decoder(hidden)[:, :, : x.size(-1)]
  79. return watermark
  80. def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
  81. """
  82. Detect the watermark in a batch of audio input.
  83. Args:
  84. x (torch.Tensor): Input audio tensor with dimensions
  85. [batch size, channels = 1, frames].
  86. Returns:
  87. torch.Tensor: Predictions of the classifier for watermark
  88. with dimensions [bsz, classes = 2, frames].
  89. For each frame, the detector outputs probabilities of
  90. non-watermarked class (class id 0) and
  91. the probability of "watermarked" class (class id 1).
  92. To do inference, you can use output[:, 1, :]
  93. to get probabilities of input audio being watermarked.
  94. """
  95. return self.detector(x)
  96. def model_from_checkpoint(
  97. config_file: Union[Path, str] = "seamlesswatermark.yaml",
  98. checkpoint: str = "",
  99. device: Union[torch.device, str] = "cpu",
  100. dtype: DataType = torch.float32,
  101. ) -> Watermarker:
  102. """Instantiate a Watermarker model from a given checkpoint path.
  103. Example usage:
  104. >>> from watermarking.watermarking import *
  105. >>> cfg = "seamlesswatermark.yaml"
  106. >>> url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
  107. >>> urllib.request.urlretrieve(url, "random.wav")
  108. >>> wav, _ = torchaudio.load("random.wav")
  109. >>> wav = wav.unsqueeze(0) # add bsz dimension
  110. >>> model = model_from_config(cfg, device = wav.device)
  111. # Other way is to load directly from the checkpoint
  112. >>> model = model_from_checkpoint(checkpoint_path, device = wav.device)
  113. >>> watermark = model.get_watermark(wav)
  114. >>> watermarked_audio = wav + watermark
  115. >>> detection = model.detect_watermark(watermarked_audio)
  116. >>> print(detection[:,1,:]) # print prob of watermarked class # should be > 0.5
  117. >>> detection = model.detect_watermark(wav)
  118. >>> print(detection[:,1,:]) # print prob of watermarked class # should be < 0.5
  119. Args:
  120. checkpoint_path (Path or str): Path to the checkpoint file.
  121. device (torch.device or str, optional): Device on which
  122. the model is loaded (default is "cpu").
  123. Returns:
  124. Watermarker: An instance of the Watermarker model loaded from the checkpoint.
  125. """
  126. config_path = Path(__file__).parent / config_file
  127. cfg = omegaconf.OmegaConf.load(config_path)
  128. if checkpoint and Path(checkpoint).is_file():
  129. ckpt = checkpoint
  130. else:
  131. ckpt = cfg["checkpoint"]
  132. state: Dict[str, Any] = torch.load(ckpt, map_location=device)
  133. if "model" in state and "xp.cfg" in state:
  134. cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
  135. omegaconf.OmegaConf.resolve(cfg)
  136. state = state["model"]
  137. watermarking_model = get_watermarking_model(cfg)
  138. watermarking_model.load_state_dict(state)
  139. watermarking_model = watermarking_model.to(device, dtype=dtype)
  140. watermarking_model.eval()
  141. return watermarking_model
  142. def get_watermarking_model(cfg: omegaconf.DictConfig) -> Watermarker:
  143. encoder, decoder = get_encodec_autoencoder(cfg)
  144. detector = get_detector(cfg)
  145. return Watermarker(encoder, decoder, detector)
  146. def get_encodec_autoencoder(cfg: omegaconf.DictConfig):
  147. kwargs = dict_from_config(getattr(cfg, "seanet"))
  148. if hasattr(cfg.seanet, "detector"):
  149. kwargs.pop("detector")
  150. encoder_override_kwargs = kwargs.pop("encoder")
  151. decoder_override_kwargs = kwargs.pop("decoder")
  152. encoder_kwargs = {**kwargs, **encoder_override_kwargs}
  153. decoder_kwargs = {**kwargs, **decoder_override_kwargs}
  154. encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
  155. decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
  156. return encoder, decoder
  157. def get_detector(cfg: omegaconf.DictConfig):
  158. kwargs = dict_from_config(getattr(cfg, "seanet"))
  159. encoder_override_kwargs = kwargs.pop("detector")
  160. kwargs.pop("decoder")
  161. kwargs.pop("encoder")
  162. encoder_kwargs = {**kwargs, **encoder_override_kwargs}
  163. # Some new checkpoints of watermarking was trained on a newer code, where
  164. # `output_hidden_dim` is renamed to `output_dim`
  165. if "output_dim" in encoder_kwargs:
  166. output_hidden_dim = encoder_kwargs.pop("output_dim")
  167. else:
  168. output_hidden_dim = 8
  169. encoder = SEANetEncoderKeepDimension(output_hidden_dim, **encoder_kwargs)
  170. last_layer = torch.nn.Conv1d(output_hidden_dim, 2, 1)
  171. softmax = torch.nn.Softmax(dim=1)
  172. detector = torch.nn.Sequential(encoder, last_layer, softmax)
  173. return detector
  174. def parse_device_arg(value: str) -> Device:
  175. try:
  176. return Device(value)
  177. except RuntimeError:
  178. raise ArgumentTypeError(f"'{value}' is not a valid device name.")
  179. if __name__ == "__main__":
  180. # Example usage: python watermarking.py --device cuda:0 detect [file.wav]
  181. parser = ArgumentParser(description="Handle the watermarking for audios")
  182. parser.add_argument(
  183. "--device",
  184. default="cpu",
  185. type=parse_device_arg,
  186. help="device on which to run tests (default: %(default)s)",
  187. )
  188. parser.add_argument(
  189. "--model-file",
  190. default="seamlesswatermark.yaml",
  191. type=str,
  192. help="path to a config or checkpoint file (default: %(default)s)",
  193. )
  194. parser.add_argument(
  195. "--checkpoint",
  196. default="",
  197. type=str,
  198. help="inline argument to override the value `checkpoint` specified in the file `model-file`",
  199. )
  200. sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
  201. detect_parser = sub_parser.add_parser("detect")
  202. wm_parser = sub_parser.add_parser("wm")
  203. parser.add_argument("file", type=str, help="Path to the .wav file")
  204. args = parser.parse_args()
  205. if args.sub_cmd == "detect":
  206. model = model_from_checkpoint(args.model_file, checkpoint=args.checkpoint, device=args.device)
  207. wav, _ = torchaudio.load(args.file)
  208. wav = wav.unsqueeze(0)
  209. wav = wav.to(args.device)
  210. detection = model.detect_watermark(wav)
  211. print(detection[:, 1, :])
  212. print(torch.count_nonzero(torch.gt(detection[:, 1, :], 0.5)))