|
@@ -10,7 +10,7 @@
|
|
|
import math
|
|
|
from argparse import ArgumentParser, ArgumentTypeError
|
|
|
from pathlib import Path
|
|
|
-from typing import Any, Dict, Optional, Union, cast
|
|
|
+from typing import Any, Dict, Union
|
|
|
|
|
|
import audiocraft
|
|
|
import omegaconf
|
|
@@ -115,6 +115,7 @@ class Watermarker(nn.Module):
|
|
|
|
|
|
def model_from_checkpoint(
|
|
|
config_file: Union[Path, str] = "seamlesswatermark.yaml",
|
|
|
+ checkpoint: str = "",
|
|
|
device: Union[torch.device, str] = "cpu",
|
|
|
dtype: DataType = torch.float32,
|
|
|
) -> Watermarker:
|
|
@@ -151,7 +152,11 @@ def model_from_checkpoint(
|
|
|
"""
|
|
|
config_path = Path(__file__).parent / config_file
|
|
|
cfg = omegaconf.OmegaConf.load(config_path)
|
|
|
- state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location=device)
|
|
|
+ if checkpoint and Path(checkpoint).is_file():
|
|
|
+ ckpt = checkpoint
|
|
|
+ else:
|
|
|
+ ckpt = cfg["checkpoint"]
|
|
|
+ state: Dict[str, Any] = torch.load(ckpt, map_location=device)
|
|
|
if "model" in state and "xp.cfg" in state:
|
|
|
cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
|
|
|
omegaconf.OmegaConf.resolve(cfg)
|
|
@@ -188,7 +193,13 @@ def get_detector(cfg: omegaconf.DictConfig):
|
|
|
kwargs.pop("decoder")
|
|
|
kwargs.pop("encoder")
|
|
|
encoder_kwargs = {**kwargs, **encoder_override_kwargs}
|
|
|
- output_hidden_dim = 8
|
|
|
+
|
|
|
+ # Some newer watermarking checkpoints were trained with newer code, where
|
|
|
+ # `output_hidden_dim` is renamed to `output_dim`
|
|
|
+ if "output_dim" in encoder_kwargs:
|
|
|
+ output_hidden_dim = encoder_kwargs.pop("output_dim")
|
|
|
+ else:
|
|
|
+ output_hidden_dim = 8
|
|
|
encoder = SEANetEncoderKeepDimension(output_hidden_dim, **encoder_kwargs)
|
|
|
|
|
|
last_layer = torch.nn.Conv1d(output_hidden_dim, 2, 1)
|
|
@@ -220,6 +231,12 @@ if __name__ == "__main__":
|
|
|
type=str,
|
|
|
help="path to a config or checkpoint file (default: %(default)s)",
|
|
|
)
|
|
|
+ parser.add_argument(
|
|
|
+ "--checkpoint",
|
|
|
+ default="",
|
|
|
+ type=str,
|
|
|
+ help="inline argument to override the value `checkpoint` specified in the file `model-file`",
|
|
|
+ )
|
|
|
sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
|
|
|
detect_parser = sub_parser.add_parser("detect")
|
|
|
wm_parser = sub_parser.add_parser("wm")
|
|
@@ -228,7 +245,7 @@ if __name__ == "__main__":
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
if args.sub_cmd == "detect":
|
|
|
- model = model_from_checkpoint(args.model_file, device=args.device)
|
|
|
+ model = model_from_checkpoint(args.model_file, checkpoint=args.checkpoint, device=args.device)
|
|
|
wav, _ = torchaudio.load(args.file)
|
|
|
wav = wav.unsqueeze(0)
|
|
|
wav = wav.to(args.device)
|