Browse Source

load model on correct device (#196)

Pierre Andrews 1 year ago
parent
commit
bf7697ec82
1 changed files with 1 additions and 1 deletions
  1. 1 1
      scripts/watermarking/watermarking.py

+ 1 - 1
scripts/watermarking/watermarking.py

@@ -151,7 +151,7 @@ def model_from_checkpoint(
     """
     config_path = Path(__file__).parent / config_file
     cfg = omegaconf.OmegaConf.load(config_path)
-    state: Dict[str, Any] = torch.load(cfg["checkpoint"])
+    state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location=device)
     if "model" in state and "xp.cfg" in state:
         cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
         omegaconf.OmegaConf.resolve(cfg)