Przeglądaj źródła

Enable testing watermarked vocoder with local cards (#249)

* update test and expand test suite params

* PJ's comments

* Update

* remove test code in model card

---------

Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Co-authored-by: Can Balioglu <cbalioglu@users.noreply.github.com>
Tuan Tran 1 rok temu
rodzic
commit
4afcf8e68a

+ 21 - 4
scripts/watermarking/watermarking.py

@@ -10,7 +10,7 @@
 import math
 from argparse import ArgumentParser, ArgumentTypeError
 from pathlib import Path
-from typing import Any, Dict, Optional, Union, cast
+from typing import Any, Dict, Union
 
 import audiocraft
 import omegaconf
@@ -115,6 +115,7 @@ class Watermarker(nn.Module):
 
 def model_from_checkpoint(
     config_file: Union[Path, str] = "seamlesswatermark.yaml",
+    checkpoint: str = "",
     device: Union[torch.device, str] = "cpu",
     dtype: DataType = torch.float32,
 ) -> Watermarker:
@@ -151,7 +152,11 @@ def model_from_checkpoint(
     """
     config_path = Path(__file__).parent / config_file
     cfg = omegaconf.OmegaConf.load(config_path)
-    state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location=device)
+    if checkpoint and Path(checkpoint).is_file():
+        ckpt = checkpoint
+    else:
+        ckpt = cfg["checkpoint"]
+    state: Dict[str, Any] = torch.load(ckpt, map_location=device)
     if "model" in state and "xp.cfg" in state:
         cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
         omegaconf.OmegaConf.resolve(cfg)
@@ -188,7 +193,13 @@ def get_detector(cfg: omegaconf.DictConfig):
     kwargs.pop("decoder")
     kwargs.pop("encoder")
     encoder_kwargs = {**kwargs, **encoder_override_kwargs}
-    output_hidden_dim = 8
+
+    # Some new checkpoints of watermarking was trained on a newer code, where
+    # `output_hidden_dim` is renamed to `output_dim`
+    if "output_dim" in encoder_kwargs:
+        output_hidden_dim = encoder_kwargs.pop("output_dim")
+    else:
+        output_hidden_dim = 8
     encoder = SEANetEncoderKeepDimension(output_hidden_dim, **encoder_kwargs)
 
     last_layer = torch.nn.Conv1d(output_hidden_dim, 2, 1)
@@ -220,6 +231,12 @@ if __name__ == "__main__":
         type=str,
         help="path to a config or checkpoint file (default: %(default)s)",
     )
+    parser.add_argument(
+        "--checkpoint",
+        default="",
+        type=str,
+        help="inline argument to override the value `checkpoint` specified in the file `model-file`",
+    )
     sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
     detect_parser = sub_parser.add_parser("detect")
     wm_parser = sub_parser.add_parser("wm")
@@ -228,7 +245,7 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     if args.sub_cmd == "detect":
-        model = model_from_checkpoint(args.model_file, device=args.device)
+        model = model_from_checkpoint(args.model_file, checkpoint=args.checkpoint, device=args.device)
         wav, _ = torchaudio.load(args.file)
         wav = wav.unsqueeze(0)
         wav = wav.to(args.device)

+ 0 - 2
src/seamless_communication/models/generator/loader.py

@@ -5,8 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 
 
-from typing import Any, Mapping
-
 from fairseq2.assets import asset_store, download_manager
 from fairseq2.models.utils import ConfigLoader, ModelLoader
 

+ 1 - 1
tests/common.py

@@ -56,7 +56,7 @@ def assert_unit_close(
     if percent_unit_tol > 0.0:
         num_unit_tol = int(percent_unit_tol * len(a))
 
-    num_unit_diff = (a != b).sum()
+    num_unit_diff = (a != b).sum()  # type: ignore
     assert (
         num_unit_diff <= num_unit_tol
     ), f"The difference is beyond tolerance, {num_unit_diff} units are different, tolerance is {num_unit_tol}"

+ 4 - 0
tests/conftest.py

@@ -37,6 +37,10 @@ def pytest_addoption(parser: pytest.Parser) -> None:
 def pytest_sessionstart(session: pytest.Session) -> None:
     tests.common.device = cast(Device, session.config.getoption("device"))
 
+    from fairseq2.assets import asset_store
+
+    asset_store.env_resolvers.append(lambda: "integ_test")
+
 
 @pytest.fixture(scope="module")
 def example_rate16k_audio() -> AudioDecoderOutput:

+ 7 - 2
tests/integration/models/test_watermarked_vocoder.py

@@ -8,6 +8,7 @@ import sys
 from argparse import Namespace
 from pathlib import Path
 from typing import Final, List, Optional, cast
+import os
 import pytest
 
 import torch
@@ -51,7 +52,9 @@ def load_watermarking_model() -> Optional[Module]:
     assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
     wm_spec.loader.exec_module(wm_py_module)
 
-    return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
+    ckpt = os.getenv("SEAMLESS_WM_CKPT", "")
+
+    return cast(Module, wm_py_module.model_from_checkpoint(device=device, checkpoint=ckpt, dtype=dtype))
 
 
 @pytest.mark.parametrize("sr", [16_000, 24_000])
@@ -133,7 +136,9 @@ def test_pretssel_vocoder_watermarking(
 
     # Test that the watermark is detectable
     detection = watermarker.detect_watermark(wav_wm)  # type: ignore
-    assert torch.all(detection[:, 1, :] > 0.5)
+
+    # 0.9 is the current lower bound of Watermarking w.r.t all attacks
+    assert torch.count_nonzero(torch.gt(detection[:, 1, :], 0.5)) / detection.shape[-1] > 0.9
 
     # Remove the batch and compare parity on the overlapping frames
     wav_wm = wav_wm.squeeze(0)