diff --git a/04_dqn.py b/04_dqn.py
index 5faef4d..7e91650 100644
--- a/04_dqn.py
+++ b/04_dqn.py
@@ -36,8 +36,6 @@ class DQN:
     def __init__(self, dim_state=None, num_action=None, discount=0.9):
         self.discount = discount
         self.Q = QNet(dim_state, num_action)
-        self.target_Q = QNet(dim_state, num_action)
-        self.target_Q.load_state_dict(self.Q.state_dict())
 
     def get_action(self, state):
         qvals = self.Q(state)
@@ -46,21 +44,13 @@ def get_action(self, state):
     def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
         # Compute the values for s_batch, a_batch.
         qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
-        # Use the target Q network to compute the values for next_s_batch.
-        next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
+        # Use the online Q network to compute the values for next_s_batch.
+        next_qvals, _ = self.Q(next_s_batch).detach().max(dim=1)
         # Compute the loss with MSE.
         loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
         return loss
 
 
-def soft_update(target, source, tau=0.01):
-    """
-    update target by target = tau * source + (1 - tau) * target.
-    """
-    for target_param, param in zip(target.parameters(), source.parameters()):
-        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
-
-
 @dataclass
 class ReplayBuffer:
     maxsize: int
@@ -126,7 +116,7 @@ def train(args, env, agent):
         if np.random.rand() < epsilon or i < args.warmup_steps:
             action = env.action_space.sample()
         else:
-            action = agent.get_action(torch.from_numpy(state))
+            action = agent.get_action(torch.from_numpy(state).to(args.device))
         action = action.item()
         next_state, reward, terminated, truncated, _ = env.step(action)
         done = terminated or truncated
@@ -155,11 +145,11 @@ def train(args, env, agent):
 
         if i > args.warmup_steps:
             bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
-            bs = torch.tensor(bs, dtype=torch.float32)
-            ba = torch.tensor(ba, dtype=torch.long)
-            br = torch.tensor(br, dtype=torch.float32)
-            bd = torch.tensor(bd, dtype=torch.float32)
-            bns = torch.tensor(bns, dtype=torch.float32)
+            bs = torch.tensor(bs, dtype=torch.float32).to(args.device)
+            ba = torch.tensor(ba, dtype=torch.long).to(args.device)
+            br = torch.tensor(br, dtype=torch.float32).to(args.device)
+            bd = torch.tensor(bd, dtype=torch.float32).to(args.device)
+            bns = torch.tensor(bns, dtype=torch.float32).to(args.device)
 
             loss = agent.compute_loss(bs, ba, br, bd, bns)
             loss.backward()
@@ -168,8 +158,6 @@ def train(args, env, agent):
 
             log["loss"].append(loss.item())
 
-            soft_update(agent.target_Q, agent.Q)
-
     # 3. Plot curves.
     plt.plot(log["loss"])
     plt.yscale("log")
@@ -191,7 +179,7 @@ def eval(args, env, agent):
     state, _ = env.reset()
     for i in range(5000):
         episode_length += 1
-        action = agent.get_action(torch.from_numpy(state)).item()
+        action = agent.get_action(torch.from_numpy(state).to(args.device)).item()
         next_state, reward, terminated, truncated, _ = env.step(action)
         done = terminated or truncated
         env.render()
@@ -229,7 +217,6 @@ def main():
     set_seed(args)
     agent = DQN(dim_state=args.dim_state, num_action=args.num_action, discount=args.discount)
     agent.Q.to(args.device)
-    agent.target_Q.to(args.device)
 
     if args.do_train:
         train(args, env, agent)
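With the target network stripped out of 04_dqn.py, compute_loss() now bootstraps from the same online Q network it is training. A minimal sketch of that target computation, assuming only torch; the batch values below are made up purely for illustration:

import torch
import torch.nn.functional as F

discount = 0.9
qvals = torch.tensor([0.5, 0.2, 0.7], requires_grad=True)   # Q(s, a) for the actions taken
next_q = torch.tensor([[0.4, 0.6], [0.1, 0.3], [0.8, 0.2]])  # Q(s', .) from the same online network
r = torch.tensor([1.0, 0.0, 1.0])                            # rewards
d = torch.tensor([0.0, 1.0, 0.0])                            # done flags

# detach() keeps the bootstrapped target out of the gradient graph, as in compute_loss().
next_qvals, _ = next_q.detach().max(dim=1)
loss = F.mse_loss(r + discount * next_qvals * (1 - d), qvals)
loss.backward()   # gradients flow only through qvals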
diff --git a/06_doubledqn.py b/06_doubledqn.py
index d27d49c..0831e70 100644
--- a/06_doubledqn.py
+++ b/06_doubledqn.py
@@ -32,19 +32,21 @@ def forward(self, obs):
 
 class DoubleDQN:
     def __init__(self, dim_obs=None, num_act=None, discount=0.9):
         self.discount = discount
-        self.model = QNet(dim_obs, num_act)
-        self.target_model = QNet(dim_obs, num_act)
-        self.target_model.load_state_dict(self.model.state_dict())
+        self.Q = QNet(dim_obs, num_act)
+        self.target_Q = QNet(dim_obs, num_act)
+        self.target_Q.load_state_dict(self.Q.state_dict())
 
     def get_action(self, obs):
-        qvals = self.model(obs)
+        qvals = self.Q(obs)
         return qvals.argmax()
 
     def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
         # Compute current Q value based on current states and actions.
-        qvals = self.model(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
+        qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
+        # Select next actions with the online Q network.
+        next_a_batch = self.Q(next_s_batch).argmax(dim=1)
         # Keep the next-state values out of the gradient computation to avoid divergence.
-        next_qvals, _ = self.target_model(next_s_batch).detach().max(dim=1)
+        next_qvals = self.target_Q(next_s_batch).gather(1, next_a_batch.unsqueeze(1)).squeeze().detach()
         loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
         return loss
 
@@ -96,7 +98,7 @@ def set_seed(args):
 
 def train(args, env, agent):
     replay_buffer = ReplayBuffer(100_000)
-    optimizer = torch.optim.Adam(agent.model.parameters(), lr=args.lr)
+    optimizer = torch.optim.Adam(agent.Q.parameters(), lr=args.lr)
     optimizer.zero_grad()
 
     epsilon = 1
@@ -107,16 +109,16 @@ def train(args, env, agent):
     log_ep_rewards = []
     log_losses = [0]
-    agent.model.train()
-    agent.target_model.train()
-    agent.model.zero_grad()
-    agent.target_model.zero_grad()
+    agent.Q.train()
+    agent.target_Q.train()
+    agent.Q.zero_grad()
+    agent.target_Q.zero_grad()
 
     state, _ = env.reset()
     for i in range(args.max_steps):
         if np.random.rand() < epsilon or i < args.warmup_steps:
             action = env.action_space.sample()
         else:
-            action = agent.get_action(torch.from_numpy(state))
+            action = agent.get_action(torch.from_numpy(state).to(args.device))
         action = action.item()
         next_state, reward, terminated, truncated, _ = env.step(action)
         done = terminated or truncated
@@ -140,8 +142,8 @@ def train(args, env, agent):
             print(f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log_losses[-1]:.1e}, epsilon={epsilon:.3f}")
 
             if episode_length < 180 and episode_reward > max_episode_reward:
-                save_path = os.path.join(args.output_dir, "model.bin")
-                torch.save(agent.model.state_dict(), save_path)
+                save_path = os.path.join(args.output_dir, "Q.bin")
+                torch.save(agent.Q.state_dict(), save_path)
                 max_episode_reward = episode_reward
 
             episode_reward = 0
@@ -150,11 +152,11 @@ def train(args, env, agent):
 
         if i > args.warmup_steps:
             bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
-            bs = torch.tensor(bs, dtype=torch.float32)
-            ba = torch.tensor(ba, dtype=torch.long)
-            br = torch.tensor(br, dtype=torch.float32)
-            bd = torch.tensor(bd, dtype=torch.float32)
-            bns = torch.tensor(bns, dtype=torch.float32)
+            bs = torch.tensor(bs, dtype=torch.float32).to(args.device)
+            ba = torch.tensor(ba, dtype=torch.long).to(args.device)
+            br = torch.tensor(br, dtype=torch.float32).to(args.device)
+            bd = torch.tensor(bd, dtype=torch.float32).to(args.device)
+            bns = torch.tensor(bns, dtype=torch.float32).to(args.device)
 
             loss = agent.compute_loss(bs, ba, br, bd, bns)
             loss.backward()
@@ -164,7 +166,7 @@ def train(args, env, agent):
             log_losses.append(loss.item())
 
             # Update the target network.
-            for target_param, param in zip(agent.target_model.parameters(), agent.model.parameters()):
+            for target_param, param in zip(agent.target_Q.parameters(), agent.Q.parameters()):
                 target_param.data.copy_(args.lr_target * param.data + (1 - args.lr_target) * target_param.data)
 
     plt.plot(log_losses)
@@ -179,16 +181,16 @@ def train(args, env, agent):
 
 def eval(args, env, agent):
-    model_path = os.path.join(args.output_dir, "model.bin")
-    agent.model.load_state_dict(torch.load(model_path))
+    model_path = os.path.join(args.output_dir, "Q.bin")
+    agent.Q.load_state_dict(torch.load(model_path))
 
     episode_length = 0
     episode_reward = 0
-    agent.model.eval()
+    agent.Q.eval()
 
     state, _ = env.reset()
     for i in range(5000):
         episode_length += 1
-        action = agent.get_action(torch.from_numpy(state)).item()
+        action = agent.get_action(torch.from_numpy(state).to(args.device)).item()
         next_state, reward, terminated, truncated, _ = env.step(action)
         done = terminated or truncated
         episode_reward += reward
@@ -225,7 +227,8 @@ def main():
     env = gym.make(args.env)
     set_seed(args)
     agent = DoubleDQN(dim_obs=args.dim_obs, num_act=args.num_act, discount=args.discount)
-    agent.model.to(args.device)
+    agent.Q.to(args.device)
+    agent.target_Q.to(args.device)
 
     if args.do_train:
         train(args, env, agent)
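The key change in 06_doubledqn.py is that action selection and action evaluation are split across the two networks. A small self-contained sketch of that split; the Sequential networks and batch size here are placeholders, not the repo's QNet:

import torch
import torch.nn as nn

torch.manual_seed(0)
Q = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))         # online network
target_Q = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 2))  # target network
target_Q.load_state_dict(Q.state_dict())

next_s = torch.randn(8, 4)                        # a batch of 8 next states
next_a = Q(next_s).argmax(dim=1)                  # selection: online network
next_qvals = target_Q(next_s).gather(1, next_a.unsqueeze(1)).squeeze(1).detach()  # evaluation: target network
# Vanilla DQN would instead take target_Q(next_s).max(dim=1).values,
# selecting and evaluating with the same network.
print(next_qvals.shape)                           # torch.Size([8])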
diff --git a/06_dueling_network.py b/06_dueling_network.py
new file mode 100644
index 0000000..e623d85
--- /dev/null
+++ b/06_dueling_network.py
@@ -0,0 +1,264 @@
+"""Implementation of the dueling network algorithm from Section 6.3.
+"""
+import argparse
+from collections import defaultdict
+from itertools import chain
+import os
+import random
+from dataclasses import dataclass, field
+import gym
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DNet(nn.Module):
+    """Optimal advantage function.
+    Input: feature
+    Output: num_action values
+    """
+
+    def __init__(self, dim_state, num_action):
+        super().__init__()
+        self.fc1 = nn.Linear(dim_state, 64)
+        self.fc2 = nn.Linear(64, 32)
+        self.fc3 = nn.Linear(32, num_action)
+
+    def forward(self, state):
+        x = F.relu(self.fc1(state))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+class VNet(nn.Module):
+    """Optimal state-value function.
+    Input: feature
+    Output: a single state value
+    """
+
+    def __init__(self, dim_state):
+        super().__init__()
+        self.fc1 = nn.Linear(dim_state, 64)
+        self.fc2 = nn.Linear(64, 32)
+        self.fc3 = nn.Linear(32, 1)
+
+    def forward(self, state):
+        x = F.relu(self.fc1(state))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+class DuelingNetwork:
+    def __init__(self, dim_state=None, num_action=None, discount=0.9):
+        self.discount = discount
+        self.V = VNet(dim_state)
+        self.D = DNet(dim_state, num_action)
+
+    def get_Q(self, state):
+        dvals = self.D(state)
+        vval = self.V(state)
+        # Q(s, a) = V(s) + D(s, a) - mean_a D(s, a); take the mean per state, not over the batch.
+        return dvals + vval - dvals.mean(dim=-1, keepdim=True)
+
+    def get_action(self, state):
+        qvals = self.get_Q(state)
+        return qvals.argmax()
+
+    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
+        # Compute the values for s_batch, a_batch.
+        dvals = self.D(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
+        # Compute the Q value.
+        vval = self.V(s_batch).squeeze()
+        qvals = dvals + vval - torch.mean(self.D(s_batch), 1)
+        # Use the dueling network to compute the values for next_s_batch.
+        next_qvals, _ = self.get_Q(next_s_batch).detach().max(dim=1)
+        # Compute the loss with MSE.
+        loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
+        return loss
+
+
+@dataclass
+class ReplayBuffer:
+    maxsize: int
+    size: int = 0
+    state: list = field(default_factory=list)
+    action: list = field(default_factory=list)
+    next_state: list = field(default_factory=list)
+    reward: list = field(default_factory=list)
+    done: list = field(default_factory=list)
+
+    def push(self, state, action, reward, done, next_state):
+        if self.size < self.maxsize:
+            self.state.append(state)
+            self.action.append(action)
+            self.reward.append(reward)
+            self.done.append(done)
+            self.next_state.append(next_state)
+        else:
+            position = self.size % self.maxsize
+            self.state[position] = state
+            self.action[position] = action
+            self.reward[position] = reward
+            self.done[position] = done
+            self.next_state[position] = next_state
+        self.size += 1
+
+    def sample(self, n):
+        total_number = self.size if self.size < self.maxsize else self.maxsize
+        indices = np.random.randint(total_number, size=n)
+        state = [self.state[i] for i in indices]
+        action = [self.action[i] for i in indices]
+        reward = [self.reward[i] for i in indices]
+        done = [self.done[i] for i in indices]
+        next_state = [self.next_state[i] for i in indices]
+        return state, action, reward, done, next_state
+
+
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if not args.no_cuda:
+        torch.cuda.manual_seed(args.seed)
+
+
+def train(args, env, agent):
+    replay_buffer = ReplayBuffer(10_000)
+    optimizer = torch.optim.Adam(chain(agent.V.parameters(), agent.D.parameters()), lr=args.lr)
+    optimizer.zero_grad()
+
+    epsilon = 1
+    epsilon_max = 1
+    epsilon_min = 0.1
+    episode_reward = 0
+    episode_length = 0
+    max_episode_reward = -float("inf")
+    log = defaultdict(list)
+    log["loss"].append(0)
+
+    agent.V.train()
+    agent.D.train()
+    state, _ = env.reset(seed=args.seed)
+    for i in range(args.max_steps):
+        if np.random.rand() < epsilon or i < args.warmup_steps:
+            action = env.action_space.sample()
+        else:
+            action = agent.get_action(torch.from_numpy(state).to(args.device))
+        action = action.item()
+        next_state, reward, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+        episode_reward += reward
+        episode_length += 1
+
+        replay_buffer.push(state, action, reward, done, next_state)
+        state = next_state
+
+        if done is True:
+            log["episode_reward"].append(episode_reward)
+            log["episode_length"].append(episode_length)
+
+            print(f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log['loss'][-1]:.1e}, epsilon={epsilon:.3f}")
+
+            # Save the model if the score improves.
+            if episode_reward > max_episode_reward:
+                save_path = os.path.join(args.output_dir, "V_model.bin")
+                torch.save(agent.V.state_dict(), save_path)
+                save_path = os.path.join(args.output_dir, "D_model.bin")
+                torch.save(agent.D.state_dict(), save_path)
+                max_episode_reward = episode_reward
+
+            episode_reward = 0
+            episode_length = 0
+            epsilon = max(epsilon - (epsilon_max - epsilon_min) * args.epsilon_decay, 1e-1)
+            state, _ = env.reset()
+
+        if i > args.warmup_steps:
+            bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
+            bs = torch.tensor(bs, dtype=torch.float32).to(args.device)
+            ba = torch.tensor(ba, dtype=torch.long).to(args.device)
+            br = torch.tensor(br, dtype=torch.float32).to(args.device)
+            bd = torch.tensor(bd, dtype=torch.float32).to(args.device)
+            bns = torch.tensor(bns, dtype=torch.float32).to(args.device)
+
+            loss = agent.compute_loss(bs, ba, br, bd, bns)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+            log["loss"].append(loss.item())
+
+    # 3. Plot curves.
+    plt.plot(log["loss"])
+    plt.yscale("log")
+    plt.savefig(f"{args.output_dir}/loss.png", bbox_inches="tight")
+    plt.close()
+
+    plt.plot(np.cumsum(log["episode_length"]), log["episode_reward"])
+    plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
+    plt.close()
+
+
+def eval(args, env, agent):
+    v_model_path = os.path.join(args.output_dir, "V_model.bin")
+    agent.V.load_state_dict(torch.load(v_model_path))
+
+    d_model_path = os.path.join(args.output_dir, "D_model.bin")
+    agent.D.load_state_dict(torch.load(d_model_path))
+
+    episode_length = 0
+    episode_reward = 0
+    state, _ = env.reset()
+    for i in range(5000):
+        episode_length += 1
+        action = agent.get_action(torch.from_numpy(state).to(args.device)).item()
+        next_state, reward, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+        env.render()
+        episode_reward += reward
+
+        state = next_state
+        if done is True:
+            print(f"episode reward={episode_reward}, episode length={episode_length}")
+            state, _ = env.reset()
+            episode_length = 0
+            episode_reward = 0
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
+    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
+    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
+    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
+    parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")
+    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
+    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
+    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
+    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
+    parser.add_argument("--warmup_steps", default=10_000, type=int, help="Warmup steps without training.")
+    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
+    parser.add_argument("--epsilon_decay", default=1 / 1000, type=float, help="Epsilon-greedy algorithm decay coefficient.")
+    parser.add_argument("--do_train", action="store_true", help="Train policy.")
+    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
+    args = parser.parse_args()
+
+    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+
+    env = gym.make(args.env)
+    set_seed(args)
+    agent = DuelingNetwork(dim_state=args.dim_state, num_action=args.num_action, discount=args.discount)
+    agent.V.to(args.device)
+    agent.D.to(args.device)
+
+    if args.do_train:
+        train(args, env, agent)
+
+    if args.do_eval:
+        eval(args, env, agent)
+
+
+if __name__ == "__main__":
+    main()
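For the dueling aggregation used by get_Q() and compute_loss() above, a quick sanity check written for a batch, where the advantage mean is taken per state; shapes are illustrative only:

import torch

batch_size, num_action = 5, 2
vval = torch.randn(batch_size, 1)            # V(s)
dvals = torch.randn(batch_size, num_action)  # D(s, .)

qvals = vval + dvals - dvals.mean(dim=-1, keepdim=True)

# Subtracting the per-state mean pins down the advantage stream,
# so the mean of Q over actions recovers V exactly.
assert torch.allclose(qvals.mean(dim=-1, keepdim=True), vval, atol=1e-6)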
diff --git a/06_target_network.py b/06_target_network.py
new file mode 100644
index 0000000..12008f4
--- /dev/null
+++ b/06_target_network.py
@@ -0,0 +1,242 @@
+"""Implementation of the target network algorithm from Section 6.2.4.
+"""
+import argparse
+from collections import defaultdict
+import os
+import random
+from dataclasses import dataclass, field
+import gym
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class QNet(nn.Module):
+    """QNet.
+    Input: feature
+    Output: num_action values
+    """
+
+    def __init__(self, dim_state, num_action):
+        super().__init__()
+        self.fc1 = nn.Linear(dim_state, 64)
+        self.fc2 = nn.Linear(64, 32)
+        self.fc3 = nn.Linear(32, num_action)
+
+    def forward(self, state):
+        x = F.relu(self.fc1(state))
+        x = F.relu(self.fc2(x))
+        x = self.fc3(x)
+        return x
+
+
+class TargetDQN:
+    def __init__(self, dim_state=None, num_action=None, discount=0.9):
+        self.discount = discount
+        self.Q = QNet(dim_state, num_action)
+        self.target_Q = QNet(dim_state, num_action)
+        self.target_Q.load_state_dict(self.Q.state_dict())
+
+    def get_action(self, state):
+        qvals = self.Q(state)
+        return qvals.argmax()
+
+    def compute_loss(self, s_batch, a_batch, r_batch, d_batch, next_s_batch):
+        # Compute the values for s_batch, a_batch.
+        qvals = self.Q(s_batch).gather(1, a_batch.unsqueeze(1)).squeeze()
+        # Use the target Q network to compute the values for next_s_batch.
+        next_qvals, _ = self.target_Q(next_s_batch).detach().max(dim=1)
+        # Compute the loss with MSE.
+        loss = F.mse_loss(r_batch + self.discount * next_qvals * (1 - d_batch), qvals)
+        return loss
+
+
+def soft_update(target, source, tau=0.01):
+    """
+    Update target by target = tau * source + (1 - tau) * target.
+    """
+    for target_param, param in zip(target.parameters(), source.parameters()):
+        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)
+
+
+@dataclass
+class ReplayBuffer:
+    maxsize: int
+    size: int = 0
+    state: list = field(default_factory=list)
+    action: list = field(default_factory=list)
+    next_state: list = field(default_factory=list)
+    reward: list = field(default_factory=list)
+    done: list = field(default_factory=list)
+
+    def push(self, state, action, reward, done, next_state):
+        if self.size < self.maxsize:
+            self.state.append(state)
+            self.action.append(action)
+            self.reward.append(reward)
+            self.done.append(done)
+            self.next_state.append(next_state)
+        else:
+            position = self.size % self.maxsize
+            self.state[position] = state
+            self.action[position] = action
+            self.reward[position] = reward
+            self.done[position] = done
+            self.next_state[position] = next_state
+        self.size += 1
+
+    def sample(self, n):
+        total_number = self.size if self.size < self.maxsize else self.maxsize
+        indices = np.random.randint(total_number, size=n)
+        state = [self.state[i] for i in indices]
+        action = [self.action[i] for i in indices]
+        reward = [self.reward[i] for i in indices]
+        done = [self.done[i] for i in indices]
+        next_state = [self.next_state[i] for i in indices]
+        return state, action, reward, done, next_state
+
+
+def set_seed(args):
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    if not args.no_cuda:
+        torch.cuda.manual_seed(args.seed)
+
+
+def train(args, env, agent):
+    replay_buffer = ReplayBuffer(10_000)
+    optimizer = torch.optim.Adam(agent.Q.parameters(), lr=args.lr)
+    optimizer.zero_grad()
+
+    epsilon = 1
+    epsilon_max = 1
+    epsilon_min = 0.1
+    episode_reward = 0
+    episode_length = 0
+    max_episode_reward = -float("inf")
+    log = defaultdict(list)
+    log["loss"].append(0)
+
+    agent.Q.train()
+    state, _ = env.reset(seed=args.seed)
+    for i in range(args.max_steps):
+        if np.random.rand() < epsilon or i < args.warmup_steps:
+            action = env.action_space.sample()
+        else:
+            action = agent.get_action(torch.from_numpy(state).to(args.device))
+        action = action.item()
+        next_state, reward, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+        episode_reward += reward
+        episode_length += 1
+
+        replay_buffer.push(state, action, reward, done, next_state)
+        state = next_state
+
+        if done is True:
+            log["episode_reward"].append(episode_reward)
+            log["episode_length"].append(episode_length)
+
+            print(f"i={i}, reward={episode_reward:.0f}, length={episode_length}, max_reward={max_episode_reward}, loss={log['loss'][-1]:.1e}, epsilon={epsilon:.3f}")
+
+            # Save the model if the score improves.
+            if episode_reward > max_episode_reward:
+                save_path = os.path.join(args.output_dir, "model.bin")
+                torch.save(agent.Q.state_dict(), save_path)
+                max_episode_reward = episode_reward
+
+            episode_reward = 0
+            episode_length = 0
+            epsilon = max(epsilon - (epsilon_max - epsilon_min) * args.epsilon_decay, 1e-1)
+            state, _ = env.reset()
+
+        if i > args.warmup_steps:
+            bs, ba, br, bd, bns = replay_buffer.sample(n=args.batch_size)
+            bs = torch.tensor(bs, dtype=torch.float32).to(args.device)
+            ba = torch.tensor(ba, dtype=torch.long).to(args.device)
+            br = torch.tensor(br, dtype=torch.float32).to(args.device)
+            bd = torch.tensor(bd, dtype=torch.float32).to(args.device)
+            bns = torch.tensor(bns, dtype=torch.float32).to(args.device)
+
+            loss = agent.compute_loss(bs, ba, br, bd, bns)
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+            log["loss"].append(loss.item())
+
+            soft_update(agent.target_Q, agent.Q)
+
+    # 3. Plot curves.
+    plt.plot(log["loss"])
+    plt.yscale("log")
+    plt.savefig(f"{args.output_dir}/loss.png", bbox_inches="tight")
+    plt.close()
+
+    plt.plot(np.cumsum(log["episode_length"]), log["episode_reward"])
+    plt.savefig(f"{args.output_dir}/episode_reward.png", bbox_inches="tight")
+    plt.close()
+
+
+def eval(args, env, agent):
+    model_path = os.path.join(args.output_dir, "model.bin")
+    agent.Q.load_state_dict(torch.load(model_path))
+
+    episode_length = 0
+    episode_reward = 0
+    state, _ = env.reset()
+    for i in range(5000):
+        episode_length += 1
+        action = agent.get_action(torch.from_numpy(state).to(args.device)).item()
+        next_state, reward, terminated, truncated, _ = env.step(action)
+        done = terminated or truncated
+        env.render()
+        episode_reward += reward
+
+        state = next_state
+        if done is True:
+            print(f"episode reward={episode_reward}, episode length={episode_length}")
+            state, _ = env.reset()
+            episode_length = 0
+            episode_reward = 0
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--env", default="CartPole-v1", type=str, help="Environment name.")
+    parser.add_argument("--dim_state", default=4, type=int, help="Dimension of state.")
+    parser.add_argument("--num_action", default=2, type=int, help="Number of action.")
+    parser.add_argument("--discount", default=0.99, type=float, help="Discount coefficient.")
+    parser.add_argument("--max_steps", default=100_000, type=int, help="Maximum steps for interaction.")
+    parser.add_argument("--lr", default=1e-3, type=float, help="Learning rate.")
+    parser.add_argument("--batch_size", default=32, type=int, help="Batch size.")
+    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
+    parser.add_argument("--seed", default=42, type=int, help="Random seed.")
+    parser.add_argument("--warmup_steps", default=10_000, type=int, help="Warmup steps without training.")
+    parser.add_argument("--output_dir", default="output", type=str, help="Output directory.")
+    parser.add_argument("--epsilon_decay", default=1 / 1000, type=float, help="Epsilon-greedy algorithm decay coefficient.")
+    parser.add_argument("--do_train", action="store_true", help="Train policy.")
+    parser.add_argument("--do_eval", action="store_true", help="Evaluate policy.")
+    args = parser.parse_args()
+
+    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
+
+    env = gym.make(args.env)
+    set_seed(args)
+    agent = TargetDQN(dim_state=args.dim_state, num_action=args.num_action, discount=args.discount)
+    agent.Q.to(args.device)
+    agent.target_Q.to(args.device)
+
+    if args.do_train:
+        train(args, env, agent)
+
+    if args.do_eval:
+        eval(args, env, agent)
+
+
+if __name__ == "__main__":
+    main()
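Finally, a short sketch of what the soft_update() in 06_target_network.py does: each target parameter moves a fraction tau toward the corresponding online parameter (the Linear layers below are placeholders, not the repo's QNet):

import torch
import torch.nn as nn

def soft_update(target, source, tau=0.01):
    # target <- tau * source + (1 - tau) * target, parameter by parameter.
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)

Q = nn.Linear(4, 2)
target_Q = nn.Linear(4, 2)
before = target_Q.weight.data.clone()

soft_update(target_Q, Q, tau=0.01)
expected = 0.99 * before + 0.01 * Q.weight.data
assert torch.allclose(target_Q.weight.data, expected)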