
Enable testing watermarked vocoder with local cards (#249)

* Update test and expand test suite params

* Address PJ's comments

* Update

* Remove test code in model card

---------

Co-authored-by: Tuan Tran <tuantran@devfair0436.h2.fair>
Co-authored-by: Can Balioglu <cbalioglu@users.noreply.github.com>
Tuan Tran, 1 year ago
Parent commit: 4afcf8e68a

+ 21 - 4
scripts/watermarking/watermarking.py

@@ -10,7 +10,7 @@
 import math
 from argparse import ArgumentParser, ArgumentTypeError
 from pathlib import Path
-from typing import Any, Dict, Optional, Union, cast
+from typing import Any, Dict, Union

 import audiocraft
 import omegaconf
@@ -115,6 +115,7 @@ class Watermarker(nn.Module):

 def model_from_checkpoint(
     config_file: Union[Path, str] = "seamlesswatermark.yaml",
+    checkpoint: str = "",
     device: Union[torch.device, str] = "cpu",
     dtype: DataType = torch.float32,
 ) -> Watermarker:
@@ -151,7 +152,11 @@
     """
     config_path = Path(__file__).parent / config_file
     cfg = omegaconf.OmegaConf.load(config_path)
-    state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location=device)
+    if checkpoint and Path(checkpoint).is_file():
+        ckpt = checkpoint
+    else:
+        ckpt = cfg["checkpoint"]
+    state: Dict[str, Any] = torch.load(ckpt, map_location=device)
     if "model" in state and "xp.cfg" in state:
         cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
         omegaconf.OmegaConf.resolve(cfg)
@@ -188,7 +193,13 @@ def get_detector(cfg: omegaconf.DictConfig):
     kwargs.pop("decoder")
     kwargs.pop("encoder")
     encoder_kwargs = {**kwargs, **encoder_override_kwargs}
-    output_hidden_dim = 8
+
+    # Some newer watermarking checkpoints were trained on newer code in which
+    # `output_hidden_dim` was renamed to `output_dim`.
+    if "output_dim" in encoder_kwargs:
+        output_hidden_dim = encoder_kwargs.pop("output_dim")
+    else:
+        output_hidden_dim = 8
     encoder = SEANetEncoderKeepDimension(output_hidden_dim, **encoder_kwargs)

     last_layer = torch.nn.Conv1d(output_hidden_dim, 2, 1)
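A side note on the fallback above: the same behavior can be expressed as a single `dict.pop` with a default. This is only an illustration; the keyword values below are hypothetical stand-ins for the kwargs built from the omegaconf config.

    # Newer configs carry "output_dim"; older ones do not, in which case the
    # historical default of 8 is kept.
    encoder_kwargs = {"output_dim": 32, "channels": 1}  # hypothetical values
    output_hidden_dim = encoder_kwargs.pop("output_dim", 8)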
@@ -220,6 +231,12 @@ if __name__ == "__main__":
         type=str,
         help="path to a config or checkpoint file (default: %(default)s)",
     )
+    parser.add_argument(
+        "--checkpoint",
+        default="",
+        type=str,
+        help="checkpoint path that overrides the `checkpoint` value specified in `model-file`",
+    )
     sub_parser = parser.add_subparsers(title="actions", dest="sub_cmd")
     detect_parser = sub_parser.add_parser("detect")
     wm_parser = sub_parser.add_parser("wm")
@@ -228,7 +245,7 @@ if __name__ == "__main__":
     args = parser.parse_args()

     if args.sub_cmd == "detect":
-        model = model_from_checkpoint(args.model_file, device=args.device)
+        model = model_from_checkpoint(args.model_file, checkpoint=args.checkpoint, device=args.device)
         wav, _ = torchaudio.load(args.file)
         wav = wav.unsqueeze(0)
         wav = wav.to(args.device)
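With this change, a locally downloaded checkpoint can be used without editing `seamlesswatermark.yaml`. A minimal sketch of the intended call, assuming the module is importable (the test suite actually loads it from its file path via importlib) and using a hypothetical local path:

    from watermarking import model_from_checkpoint  # import assumed for illustration only

    # Hypothetical local checkpoint; when the path is empty or does not point
    # to an existing file, model_from_checkpoint() falls back to the
    # cfg["checkpoint"] entry of seamlesswatermark.yaml.
    local_ckpt = "/checkpoints/seamless_wm_local.pt"

    model = model_from_checkpoint(
        config_file="seamlesswatermark.yaml",
        checkpoint=local_ckpt,
        device="cpu",
    )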

+ 0 - 2
src/seamless_communication/models/generator/loader.py

@@ -5,8 +5,6 @@
 # MIT_LICENSE file in the root directory of this source tree.
 # MIT_LICENSE file in the root directory of this source tree.
 
 
 
 
-from typing import Any, Mapping
-
 from fairseq2.assets import asset_store, download_manager
 from fairseq2.assets import asset_store, download_manager
 from fairseq2.models.utils import ConfigLoader, ModelLoader
 from fairseq2.models.utils import ConfigLoader, ModelLoader
 
 

+ 1 - 1
tests/common.py

@@ -56,7 +56,7 @@ def assert_unit_close(
     if percent_unit_tol > 0.0:
         num_unit_tol = int(percent_unit_tol * len(a))

-    num_unit_diff = (a != b).sum()
+    num_unit_diff = (a != b).sum()  # type: ignore
     assert (
         num_unit_diff <= num_unit_tol
     ), f"The difference is beyond tolerance, {num_unit_diff} units are different, tolerance is {num_unit_tol}"

+ 4 - 0
tests/conftest.py

@@ -37,6 +37,10 @@ def pytest_addoption(parser: pytest.Parser) -> None:
 def pytest_sessionstart(session: pytest.Session) -> None:
     tests.common.device = cast(Device, session.config.getoption("device"))

+    from fairseq2.assets import asset_store
+
+    asset_store.env_resolvers.append(lambda: "integ_test")
+

 @pytest.fixture(scope="module")
 def example_rate16k_audio() -> AudioDecoderOutput:
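Registering the `integ_test` environment for the whole test session means that any asset card published under that environment takes precedence over the default card of the same name. A rough sketch of the effect, with a hypothetical card name:

    from fairseq2.assets import asset_store

    # Added in conftest.py: every card lookup now also resolves the
    # "integ_test" environment.
    asset_store.env_resolvers.append(lambda: "integ_test")

    # Hypothetical lookup: if a local card variant for "integ_test" exists
    # (for example one pointing at a locally stored checkpoint), it overrides
    # the stock card; otherwise the default card is returned unchanged.
    card = asset_store.retrieve_card("vocoder_v2")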

+ 7 - 2
tests/integration/models/test_watermarked_vocoder.py

@@ -8,6 +8,7 @@ import sys
 from argparse import Namespace
 from pathlib import Path
 from typing import Final, List, Optional, cast
+import os
 import pytest

 import torch
@@ -51,7 +52,9 @@ def load_watermarking_model() -> Optional[Module]:
     assert wm_spec.loader, f"Module cannot be loaded from {wm_py_file}"
     wm_spec.loader.exec_module(wm_py_module)

-    return cast(Module, wm_py_module.model_from_checkpoint(device=device, dtype=dtype))
+    ckpt = os.getenv("SEAMLESS_WM_CKPT", "")
+
+    return cast(Module, wm_py_module.model_from_checkpoint(device=device, checkpoint=ckpt, dtype=dtype))


 @pytest.mark.parametrize("sr", [16_000, 24_000])
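The `SEAMLESS_WM_CKPT` environment variable lets the integration test point at a locally stored watermarking checkpoint; when it is unset, the path from `seamlesswatermark.yaml` is used as before. A hedged example of driving this from a small wrapper script (the checkpoint path is a placeholder):

    import os
    import subprocess

    # Placeholder path to a locally downloaded watermarking checkpoint.
    os.environ["SEAMLESS_WM_CKPT"] = "/checkpoints/seamless_wm_local.pt"

    # Run only the watermarked-vocoder integration tests with that checkpoint.
    subprocess.run(
        ["pytest", "tests/integration/models/test_watermarked_vocoder.py"],
        check=True,
    )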
@@ -133,7 +136,9 @@ def test_pretssel_vocoder_watermarking(

     # Test that the watermark is detectable
     detection = watermarker.detect_watermark(wav_wm)  # type: ignore
-    assert torch.all(detection[:, 1, :] > 0.5)
+
+    # 0.9 is the current lower bound of the watermark detection rate across all attacks
+    assert torch.count_nonzero(torch.gt(detection[:, 1, :], 0.5)) / detection.shape[-1] > 0.9

     # Remove the batch and compare parity on the overlapping frames
     wav_wm = wav_wm.squeeze(0)
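For clarity, the relaxed assertion checks the fraction of frames whose watermark probability exceeds 0.5 instead of requiring every frame to pass. A small self-contained illustration, assuming the detector output shape `(batch, 2, num_frames)` used in the test and a batch size of 1; the values are fabricated so the example passes:

    import torch

    # Fabricated detector output: channel 1 holds the per-frame watermark
    # probability.
    detection = torch.full((1, 2, 100), 0.95)

    flagged = torch.count_nonzero(torch.gt(detection[:, 1, :], 0.5))
    detection_rate = flagged / detection.shape[-1]  # valid for batch size 1

    # Mirrors the test: at least 90% of frames must be flagged as watermarked.
    assert detection_rate > 0.9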