|
@@ -151,7 +151,7 @@ def model_from_checkpoint(
|
|
"""
|
|
"""
|
|
config_path = Path(__file__).parent / config_file
|
|
config_path = Path(__file__).parent / config_file
|
|
cfg = omegaconf.OmegaConf.load(config_path)
|
|
cfg = omegaconf.OmegaConf.load(config_path)
|
|
- state: Dict[str, Any] = torch.load(cfg["checkpoint"])
|
|
|
|
|
|
+ state: Dict[str, Any] = torch.load(cfg["checkpoint"], map_location=device)
|
|
if "model" in state and "xp.cfg" in state:
|
|
if "model" in state and "xp.cfg" in state:
|
|
cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
|
|
cfg = omegaconf.OmegaConf.create(state["xp.cfg"])
|
|
omegaconf.OmegaConf.resolve(cfg)
|
|
omegaconf.OmegaConf.resolve(cfg)
|