diff --git a/agents/ppo_agent.py b/agents/ppo_agent.py index 4d1f318..ffe655e 100644 --- a/agents/ppo_agent.py +++ b/agents/ppo_agent.py @@ -2,6 +2,7 @@ PPO Agent with CNN architecture optimized for Breakout """ +import os import torch import torch.nn as nn import torch.nn.functional as F @@ -237,7 +238,7 @@ def save(self, path: str, global_step: int): if self.rnd_module is not None: state['rnd_module_state_dict'] = self.rnd_module.state_dict() state['rnd_optimizer_state_dict'] = self.rnd_optimizer.state_dict() - + os.makedirs(os.path.dirname(path), exist_ok=True) torch.save(state, path) def load(self, path: str):