watermarking.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the
  5. # LICENSE file in the root directory of this source tree.
  6. # The original implementation of the watermarking model.
  7. # It is not open-sourced and is kept here only for future reference.
  8. # mypy: ignore-errors
  9. import math
  10. from argparse import ArgumentParser, ArgumentTypeError
  11. from pathlib import Path
  12. from typing import Any, Dict, Union, cast
  13. import audiocraft
  14. import omegaconf
  15. import torch
  16. import torch.nn as nn
  17. import torchaudio
  18. from audiocraft.modules.seanet import SEANetEncoder
  19. from audiocraft.utils.utils import dict_from_config
  20. from fairseq2.typing import DataType, Device
  class SEANetEncoderKeepDimension(SEANetEncoder):
      """SEANet encoder followed by a transposed convolution back towards input length.

      The inherited SEANet stack (``self.model``) is applied first. A
      ``ConvTranspose1d`` then maps its ``self.dimension`` channels to
      ``output_hidden_dim`` channels. Its kernel size and stride both equal
      ``math.prod(self.ratios)``, presumably the encoder's total temporal
      downsampling factor. Finally the result is cut to the number of frames
      of the input.

      Args:
          output_hidden_dim (int): Channels produced by the final transposed
              convolution (default: 8).
          *args, **kwargs: Forwarded unchanged to ``SEANetEncoder.__init__``.
      """

      def __init__(self, output_hidden_dim: int = 8, *args, **kwargs):  # type: ignore
          super().__init__(*args, **kwargs)
          self.output_hidden_dim = output_hidden_dim
          # Kernel size == stride: every latent step expands into exactly
          # `total_stride` output frames, with no overlap between steps.
          total_stride = math.prod(self.ratios)
          self.reverse_convolution = nn.ConvTranspose1d(
              in_channels=self.dimension,
              out_channels=output_hidden_dim,
              kernel_size=total_stride,
              stride=total_stride,
              padding=0,
          )

      def forward(self, x: torch.Tensor) -> torch.Tensor:
          n_frames = x.shape[-1]
          latent = self.model(x)
          upsampled = self.reverse_convolution(latent)
          # Drop trailing frames so the output is never longer than the input.
          return upsampled[:, :, :n_frames]
  class Watermarker(nn.Module):
      """Container for the watermark generator (encoder + decoder) and detector.

      Args:
          encoder (nn.Module): Watermark encoder.
          decoder (nn.Module): Watermark decoder.
          detector (nn.Module): Watermark detector.
          sample_rate (int): Audio sample rate.
          channels (int): Number of audio channels.
      """

      sample_rate: int = 0
      channels: int = 0
      # Annotated as plain nn.Module: in this file the decoder is built as a
      # SEANetDecoder (get_encodec_autoencoder) and the detector as an
      # nn.Sequential (get_detector), so the narrower annotations were wrong.
      encoder: nn.Module
      decoder: nn.Module
      detector: nn.Module

      def __init__(
          self,
          encoder: nn.Module,
          decoder: nn.Module,
          detector: nn.Module,
          sample_rate: int,
          channels: int,
      ):
          super().__init__()
          # Assignment order registers the submodules and therefore fixes
          # the state_dict key order; keep it stable.
          self.encoder = encoder
          self.decoder = decoder
          self.detector = detector
          self.sample_rate = sample_rate
          self.channels = channels

      def get_watermark(self, x: torch.Tensor) -> torch.Tensor:
          """Compute an additive watermark for a batch of audio.

          Args:
              x (torch.Tensor): Input audio with dimensions
                  [batch size, channels = 1, frames].

          Returns:
              torch.Tensor: Watermark cut to at most ``x.size(-1)`` frames;
              ``x + watermark`` is the watermarked audio.
          """
          hidden = self.encoder(x)
          # Trim the decoder output so it has no more frames than the input.
          watermark = self.decoder(hidden)[:, :, : x.size(-1)]
          return watermark

      def detect_watermark(self, x: torch.Tensor) -> torch.Tensor:
          """Run the detector on a batch of audio.

          Args:
              x (torch.Tensor): Input audio with dimensions
                  [batch size, channels = 1, frames].

          Returns:
              torch.Tensor: Detector output with dimensions
              [bsz, classes = 2, frames]. With the detector built by
              ``get_detector`` (softmax over dim 1), each frame holds the
              probability of the "non-watermarked" class (id 0) and of the
              "watermarked" class (id 1). Use ``output[:, 1, :]`` for the
              per-frame probability that the input is watermarked.
          """
          return self.detector(x)
  def model_from_checkpoint(
      checkpoint_path: Union[Path, str] = Path(__file__).parent
      / "seamlesswatermark.yaml",
      device: Union[torch.device, str] = "cpu",
      dtype: DataType = torch.float32,
  ) -> Watermarker:
      """Instantiate a Watermarker model from a YAML config pointing at a checkpoint.

      The config loaded from ``checkpoint_path`` must provide a ``checkpoint``
      entry (path of the saved state dict). It must also provide the
      ``seanet`` and ``watermarker_model`` sections read by
      ``get_watermarking_model``.

      Example usage (the audio should match ``model.sample_rate``; resample
      it first otherwise):
          >>> import urllib.request
          >>> import torchaudio
          >>> cfg = "seamlesswatermark.yaml"
          >>> url = "https://keithito.com/LJ-Speech-Dataset/LJ037-0171.wav"
          >>> urllib.request.urlretrieve(url, "random.wav")
          >>> wav, _ = torchaudio.load("random.wav")
          >>> wav = wav.unsqueeze(0)  # add batch dimension
          >>> model = model_from_checkpoint(cfg, device=wav.device)
          >>> watermark = model.get_watermark(wav)
          >>> watermarked_audio = wav + watermark
          >>> detection = model.detect_watermark(watermarked_audio)
          >>> print(detection[:, 1, :])  # prob. of watermarked class, should be > 0.5
          >>> detection = model.detect_watermark(wav)
          >>> print(detection[:, 1, :])  # prob. of watermarked class, should be < 0.5

      Args:
          checkpoint_path (Path or str): Path to the YAML config file.
          device (torch.device or str, optional): Device on which the model
              is placed (default is "cpu").
          dtype (DataType, optional): Floating-point dtype the model
              parameters are cast to (default is torch.float32).

      Returns:
          Watermarker: The model with weights loaded, moved to ``device`` /
          ``dtype``, and switched to eval mode.
      """
      cfg = omegaconf.OmegaConf.load(checkpoint_path)
      # Deserialize onto the CPU first. A checkpoint saved from a GPU would
      # otherwise require that GPU to be present. The model is moved to
      # `device` only after the weights are loaded.
      # NOTE(review): torch.load unpickles arbitrary objects; only load
      # checkpoints from trusted sources.
      state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location="cpu")
      watermarking_model = get_watermarking_model(cfg)
      watermarking_model.load_state_dict(state)
      watermarking_model = watermarking_model.to(device, dtype=dtype)
      watermarking_model.eval()
      return watermarking_model
  def get_watermarking_model(cfg: omegaconf.DictConfig) -> Watermarker:
      """Assemble a Watermarker (without loading any weights) from ``cfg``.

      The ``watermarker_model`` section supplies the Watermarker's own keyword
      arguments (presumably ``sample_rate`` and ``channels``). The sub-networks
      come from ``get_encodec_autoencoder`` and ``get_detector``.
      """
      model_kwargs = dict_from_config(cfg.watermarker_model)
      # Construction order (generator first, then detector) is kept as-is.
      encoder, decoder = get_encodec_autoencoder(cfg)
      return Watermarker(encoder, decoder, get_detector(cfg), **model_kwargs)
  def get_encodec_autoencoder(cfg: omegaconf.DictConfig):
      """Build the watermark generator's SEANet encoder and decoder from ``cfg.seanet``.

      ``cfg.seanet`` holds keyword arguments shared by all SEANet networks,
      plus the per-network override sections ``encoder`` and ``decoder`` and
      an optional ``detector`` section. The shared arguments are merged with
      the ``encoder`` / ``decoder`` overrides (overrides win). The
      ``detector`` section is ignored here.

      Returns:
          tuple: ``(audiocraft.modules.SEANetEncoder, audiocraft.modules.SEANetDecoder)``.
      """
      kwargs = dict_from_config(getattr(cfg, "seanet"))
      # Discard the detector section from the plain dict we actually forward,
      # tolerating its absence. Probing the OmegaConf node with hasattr()
      # instead inspects a different object, and missing-key attribute
      # access may behave differently across OmegaConf versions.
      kwargs.pop("detector", None)
      encoder_override_kwargs = kwargs.pop("encoder")
      decoder_override_kwargs = kwargs.pop("decoder")
      encoder_kwargs = {**kwargs, **encoder_override_kwargs}
      decoder_kwargs = {**kwargs, **decoder_override_kwargs}
      encoder = audiocraft.modules.SEANetEncoder(**encoder_kwargs)
      decoder = audiocraft.modules.SEANetDecoder(**decoder_kwargs)
      return encoder, decoder
  def get_detector(cfg: omegaconf.DictConfig, output_hidden_dim: int = 8) -> nn.Sequential:
      """Build the watermark detector from ``cfg.seanet``.

      The shared ``cfg.seanet`` keyword arguments are merged with its
      ``detector`` override section (overrides win). The ``encoder`` and
      ``decoder`` sections are discarded. The network is
      ``SEANetEncoderKeepDimension -> 1x1 Conv1d to 2 channels -> softmax over
      dim 1``. This gives per-frame probabilities for class 0 (not
      watermarked) and class 1 (watermarked).

      Args:
          cfg: Config with a ``seanet`` section containing ``encoder``,
              ``decoder`` and ``detector`` sub-sections.
          output_hidden_dim: Channels produced by the keep-dimension encoder
              and consumed by the final 1x1 convolution. Defaults to 8 (the
              previously hard-coded value); it must match the checkpoint
              that will be loaded.

      Returns:
          nn.Sequential: ``(encoder, last_layer, softmax)``. The positional
          order determines the state-dict key prefixes, so keep it stable.
      """
      kwargs = dict_from_config(getattr(cfg, "seanet"))
      encoder_override_kwargs = kwargs.pop("detector")
      # The encoder/decoder sections configure the generator
      # (see get_encodec_autoencoder), not the detector.
      kwargs.pop("decoder")
      kwargs.pop("encoder")
      encoder_kwargs = {**kwargs, **encoder_override_kwargs}
      encoder = SEANetEncoderKeepDimension(
          output_hidden_dim=output_hidden_dim, **encoder_kwargs
      )
      last_layer = torch.nn.Conv1d(output_hidden_dim, 2, 1)
      softmax = torch.nn.Softmax(dim=1)
      detector = torch.nn.Sequential(encoder, last_layer, softmax)
      return detector
  def parse_device_arg(value: str) -> Device:
      """Argparse ``type=`` converter that turns a device string into a ``Device``.

      Args:
          value: Device name as typed on the command line, e.g. ``"cpu"`` or
              ``"cuda:0"``.

      Returns:
          Device: The parsed device.

      Raises:
          ArgumentTypeError: If ``value`` is not a valid device name; argparse
              reports this as a usage error. The underlying RuntimeError is
              chained as ``__cause__`` for debugging.
      """
      try:
          return Device(value)
      except RuntimeError as err:
          raise ArgumentTypeError(f"'{value}' is not a valid device name.") from err
  if __name__ == "__main__":
      # Example usage:
      #   python watermarking.py --device cuda:0 detect [file.wav]
      parser = ArgumentParser(description="Handle the watermarking for audios")
      parser.add_argument(
          "--device",
          default="cpu",
          type=parse_device_arg,
          help="device on which to run tests (default: %(default)s)",
      )
      sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
      sub_parser.required = True
      # The audio path is an argument of each action, not of the top-level
      # parser, so `--help` for an action lists it and action parsing is
      # unambiguous.
      for action_name in ("detect", "wm"):
          action_parser = sub_parser.add_parser(action_name)
          action_parser.add_argument("file", type=str, help="Path to the .wav file")
      args = parser.parse_args()

      if args.sub_cmd == "wm":
          # Previously accepted but silently did nothing; fail loudly instead.
          parser.error("the 'wm' action is not implemented yet")

      if args.sub_cmd == "detect":
          model = model_from_checkpoint(device=args.device)
          wav, sample_rate = torchaudio.load(args.file)
          # The detector was configured for a fixed sample rate; resample
          # mismatched input instead of silently analysing it at the wrong rate.
          if sample_rate != model.sample_rate:
              wav = torchaudio.functional.resample(wav, sample_rate, model.sample_rate)
          # A mono model cannot consume multi-channel audio; downmix it.
          if model.channels == 1 and wav.size(0) > 1:
              wav = wav.mean(dim=0, keepdim=True)
          wav = wav.unsqueeze(0).to(args.device)  # add batch dimension
          # Inference only: skip autograd bookkeeping.
          with torch.no_grad():
              detection = model.detect_watermark(wav)
          print(detection[:, 1, :])