Explorar o código

load model on correct device (#196)

Pierre Andrews hai 1 ano
pai
achega
bf7697ec82
Modificáronse 1 ficheiros con 1 adicións e 1 borrados
  1. 1 1
      scripts/watermarking/watermarking.py

+ 1 - 1
scripts/watermarking/watermarking.py

@@ -151,7 +151,7 @@ def model_from_checkpoint(
     """
     config_path = Path(__file__).parent / config_file
     cfg = omegaconf.OmegaConf.load(config_path)
-    state: Dict[str, Any] = torch.load(cfg["checkpoint"])
+    state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location=device)
     if "model" in state and "xp.cfg" in state:
         cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
         omegaconf.OmegaConf.resolve(cfg)