From 9e3dc40f2e26e1bca62340b91852c103b5b778e6 Mon Sep 17 00:00:00 2001 From: Jo-el van Bergen Date: Thu, 3 Aug 2023 16:46:24 -0700 Subject: [PATCH] Update ppo.save to include iteration in "current" save file content. Update ppo.save to include iteration in "current" save file content. Fixes crash caused by attempting to continue from iteration `-1` when loading a current save file. --- rocket_learn/ppo.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/rocket_learn/ppo.py b/rocket_learn/ppo.py index cad2d24..be634e6 100644 --- a/rocket_learn/ppo.py +++ b/rocket_learn/ppo.py @@ -171,7 +171,7 @@ def _iter(): iteration += 1 if save_dir: - self.save(os.path.join(save_dir, self.logger.project + "_" + "latest"), -1, save_jit) + self.save(os.path.join(save_dir, self.logger.project + "_" + "latest"), iteration, save_actor_jit=save_jit, is_latest=True) if iteration % iterations_per_save == 0: self.save(current_run_dir, iteration, save_jit) # noqa @@ -485,18 +485,19 @@ def load(self, load_location, continue_iterations=True): self.total_steps = checkpoint["total_steps"] print("Continuing training at iteration " + str(self.starting_iteration)) - def save(self, save_location, current_step, save_actor_jit=False): + def save(self, save_location, current_step, save_actor_jit=False, is_latest=False,): """ Save the model weights, optimizer values, and metadata :param save_location: where to save :param current_step: the current iteration when saved. Use to later continue training :param save_actor_jit: save the policy network as a torch jit file for rlbot use + :param is_latest: if this file is the "latest" checkpoint, used to decide if real checkpoint number should be used in filename """ - version_str = str(self.logger.project) + "_" + str(current_step) + version_str = str(self.logger.project) + "_" + (str(current_step) if not is_latest else "c") version_dir = save_location + "\\" + version_str - os.makedirs(version_dir, exist_ok=current_step == -1) + os.makedirs(version_dir, exist_ok=is_latest) torch.save({ 'epoch': current_step,