Hi,
I want to use the SAC algorithm in a MultiDiscrete action space. In a Discrete action space, the actor loss is calculated as follows:
```python
action, (action_probabilities, log_action_probabilities), _ = self.produce_action_and_action_info(state_batch)
qf1_pi = self.critic_local(state_batch)
qf2_pi = self.critic_local_2(state_batch)
min_qf_pi = torch.min(qf1_pi, qf2_pi)
inside_term = self.alpha * log_action_probabilities - min_qf_pi
policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()
```
In a MultiDiscrete action space, the shapes of `log_action_probabilities`, `action_probabilities`, `qf1_pi`, and `qf2_pi` are all `[batch_size, num_action_dims, num_actions_per_dim]`. Can you give me some hints on how to calculate `policy_loss` in a MultiDiscrete action space? Should I apply `sum(-1)` twice, so that the per-sample loss has shape `[batch_size]` before the final mean?
(The actor-loss snippet above is from `Deep-Reinforcement-Learning-Algorithms-with-PyTorch/agents/actor_critic_agents/SAC_Discrete.py`, lines 83 to 88 at commit `4835bac`.)
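For concreteness, the double summation the question describes can be sketched as below. This is a hedged sketch, not the repository's implementation: it assumes a factorised MultiDiscrete policy (each action dimension is an independent categorical, so the log-probability of a joint action is a sum over dimensions) and critics that output per-dimension Q-values with the same `[batch_size, num_action_dims, num_actions_per_dim]` shape. The function name `multidiscrete_policy_loss` is hypothetical.

```python
import torch


def multidiscrete_policy_loss(action_probabilities, log_action_probabilities,
                              qf1_pi, qf2_pi, alpha):
    """Actor loss sketch for a factorised MultiDiscrete policy.

    All tensor arguments have shape
    [batch_size, num_action_dims, num_actions_per_dim].
    """
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    inside_term = alpha * log_action_probabilities - min_qf_pi
    # Expectation over the actions within each dimension
    # -> shape [batch_size, num_action_dims]
    per_dim = (action_probabilities * inside_term).sum(dim=-1)
    # Sum over action dimensions (joint log-prob is a sum under independence)
    # -> shape [batch_size]
    per_sample = per_dim.sum(dim=-1)
    # Average over the batch -> scalar
    return per_sample.mean()


# Toy tensors just to exercise the shapes.
batch, dims, n = 4, 3, 5
logits = torch.randn(batch, dims, n)
probs = torch.softmax(logits, dim=-1)
log_probs = torch.log_softmax(logits, dim=-1)
q1 = torch.randn(batch, dims, n)
q2 = torch.randn(batch, dims, n)
loss = multidiscrete_policy_loss(probs, log_probs, q1, q2, alpha=0.2)
print(loss.dim())  # 0 (a scalar), as required for .backward()
```

So yes, under these assumptions `sum(-1)` is applied twice: once for the expectation within each dimension and once across dimensions, leaving `[batch_size]` before the final `.mean()`.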