Skip to content

Commit a981dca

Browse files
committed
Fix bugs
1 parent 253c87d commit a981dca

2 files changed

Lines changed: 12 additions & 2 deletions

File tree

arch_eval/core/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import torch
1010

11+
from arch_eval.core.exceptions import ConfigurationError
1112

1213
class TaskType(str, Enum):
1314
REGRESSION = "regression"

arch_eval/core/trainer.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,17 @@ def _save_checkpoint(self, epoch: int, metrics: Dict[str, float]):
544544
self.logger.info(f"Checkpoint saved to {path}")
545545
self.plugin_manager.execute_hook("on_checkpoint", self, path, is_best)
546546

547-
def load_checkpoint(self, path: str, load_optimizer: bool = True, load_scheduler: bool = True):
548-
ckpt = torch.load(path, map_location=self.device)
547+
def load_checkpoint(self, path: str, load_optimizer: bool = True, load_scheduler: bool = True, weights_only: bool = False):
548+
"""Load a checkpoint.
549+
550+
Args:
551+
path: Path to the checkpoint file.
552+
load_optimizer: Whether to load optimizer states.
553+
load_scheduler: Whether to load scheduler states.
554+
weights_only: If True, restricts unpickling to safe types.
555+
For checkpoints saved by this library, set to False (default) because they contain custom classes.
556+
"""
557+
ckpt = torch.load(path, map_location=self.device, weights_only=weights_only)
549558
self.model.load_state_dict(ckpt["model_state_dict"])
550559
if load_optimizer:
551560
for opt, state in zip(self.optimizers, ckpt["optimizer_state_dicts"]):

0 commit comments

Comments
 (0)