diff --git a/requirements.txt b/requirements.txt index a9dbe249..b8118459 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,6 @@ loguru pyyaml tqdm numpy -pillow \ No newline at end of file +pillow +torch +torchvision diff --git a/utils/train.py b/utils/train.py index 67f6a2f8..e755771c 100644 --- a/utils/train.py +++ b/utils/train.py @@ -56,9 +56,9 @@ def __init__(self, project_name: str): newer_checkpoint = None for checkpoint in history_checkpoints: checkpoint_name = checkpoint.split(".")[0].split("_") - if int(checkpoint_name[3]) > history_step: + if int(checkpoint_name[-1]) > history_step: newer_checkpoint = checkpoint - history_step = int(checkpoint_name[3]) + history_step = int(checkpoint_name[-1]) param, self.state_dict, self.optimizer= Net.load_checkpoint( os.path.join(self.checkpoints_path, newer_checkpoint), self.device) self.epoch, self.step, self.lr = param['epoch'], param['step'], param['lr']