Эх сурвалжийг харах

load model on correct device (#196)

Pierre Andrews 1 жил өмнө
parent
commit
bf7697ec82

+ 1 - 1
scripts/watermarking/watermarking.py

@@ -151,7 +151,7 @@ def model_from_checkpoint(
     """
     """
     config_path = Path(__file__).parent / config_file
     config_path = Path(__file__).parent / config_file
     cfg = omegaconf.OmegaConf.load(config_path)
     cfg = omegaconf.OmegaConf.load(config_path)
-    state: Dict[str, Any] = torch.load(cfg["checkpoint"])
+    state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location=device)
     if "model" in state and "xp.cfg" in state:
     if "model" in state and "xp.cfg" in state:
         cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
         cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
         omegaconf.OmegaConf.resolve(cfg)
         omegaconf.OmegaConf.resolve(cfg)