Problem Description
In examples/alphazero/train.py, we compute value_mask as follows:

    value_mask = jnp.cumsum(data.terminated[::-1, :], axis=0)[::-1, :] >= 1

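To make the effect of this expression concrete, here is a small sketch on a toy `terminated` array. The array is made up for illustration, and I am assuming the layout is [time, batch], which is what the slicing along axis 0 suggests.

```python
import jax.numpy as jnp

# Toy rollout: 5 time steps (axis 0) x 2 environments (axis 1).
# Env 0 terminates at step 2; env 1 never terminates within the rollout.
terminated = jnp.array(
    [
        [False, False],
        [False, False],
        [True, False],
        [False, False],
        [False, False],
    ]
)

# Reverse along time, accumulate, reverse back:
# the count at [t, b] is the number of terminations at steps >= t in env b.
value_mask = jnp.cumsum(terminated[::-1, :], axis=0)[::-1, :] >= 1

print(value_mask[:, 0])  # steps 0..2 kept, steps 3..4 dropped
print(value_mask[:, 1])  # nothing kept: no termination was observed
```

In this toy case the mask for env 0 is True for steps 0 to 2 and False for steps 3 and 4, and it is False everywhere for env 1. In other words, a step is kept only if some termination is observed at or after it within the rollout.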
The purpose is to avoid updating the critic network on incomplete trajectories, as is evident from the masking of the value loss:

    value_loss = jnp.mean(value_loss * samples.mask)  # mask if the episode is truncated

However, the critic and actor networks share a torso of residual blocks, as defined in network.py. While we mask the value loss, we do not mask the policy loss for samples from incomplete trajectories:

    policy_loss = optax.softmax_cross_entropy(logits, samples.policy_tgt)

Therefore, samples from incomplete trajectories still update the shared torso through the policy loss, so they inadvertently influence both the policy and the value outputs. This seems to go against the intended effect of defining value_mask.
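The following is a minimal sketch of that effect, not the actual network from network.py. The tiny two-headed model, the parameter names and the squared-error value loss are all stand-ins I made up for illustration, and the cross-entropy is written out by hand rather than calling optax.

```python
import jax
import jax.numpy as jnp


def forward(params, obs):
    # Hypothetical stand-in for the shared torso followed by two heads.
    hidden = jnp.tanh(obs @ params["torso"])
    logits = hidden @ params["policy_head"]
    value = jnp.tanh(hidden @ params["value_head"]).squeeze(-1)
    return logits, value


def loss_fn(params, obs, policy_tgt, value_tgt, mask):
    logits, value = forward(params, obs)
    # Softmax cross-entropy written out by hand; NOT masked, as in the issue.
    policy_loss = -jnp.sum(policy_tgt * jax.nn.log_softmax(logits, axis=-1), axis=-1)
    # Squared error as a stand-in value loss; masked, as in the issue.
    value_loss = (value - value_tgt) ** 2
    return jnp.mean(policy_loss) + jnp.mean(value_loss * mask)


k_torso, k_policy, k_value, k_obs = jax.random.split(jax.random.PRNGKey(0), 4)
params = {
    "torso": jax.random.normal(k_torso, (4, 8)),
    "policy_head": jax.random.normal(k_policy, (8, 3)),
    "value_head": jax.random.normal(k_value, (8, 1)),
}
obs = jax.random.normal(k_obs, (2, 4))
policy_tgt = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
value_tgt = jnp.array([1.0, -1.0])
mask = jnp.zeros(2)  # pretend every sample comes from a truncated episode

grads = jax.grad(loss_fn)(params, obs, policy_tgt, value_tgt, mask)
torso_grad_norm = float(jnp.abs(grads["torso"]).sum())
value_head_grad_norm = float(jnp.abs(grads["value_head"]).sum())
print(torso_grad_norm, value_head_grad_norm)
```

Even with every sample masked out of the value loss, the torso still receives a nonzero gradient from the policy loss, while the value head receives none. Since the value output is computed from the torso features, it changes anyway.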
Possible Solutions
- Mask out samples from truncated trajectories in the policy loss as well.
- Bootstrap the value target for truncated trajectories.
I am not sure which of these or another solution is used by the original AlphaZero papers.
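For discussion, here is a rough sketch of what the two options could look like. None of this is taken from train.py: the function names, the argument layout, the squared-error value loss and the convention that `discount` is 0 at terminal steps are all assumptions on my side.

```python
import jax
import jax.numpy as jnp


def masked_losses(logits, value, policy_tgt, value_tgt, mask):
    """Option 1 (hypothetical helper): apply the same mask to both losses."""
    mask = mask.astype(jnp.float32)
    policy_loss = -jnp.sum(policy_tgt * jax.nn.log_softmax(logits, axis=-1), axis=-1)
    value_loss = (value - value_tgt) ** 2  # stand-in for the actual value loss
    return jnp.mean(policy_loss * mask), jnp.mean(value_loss * mask)


def bootstrapped_value_targets(reward, discount, bootstrap_value):
    """Option 2 (hypothetical helper): seed the backward return with a value estimate.

    reward, discount: [T, B]; discount is assumed to be 0 at terminal steps.
    bootstrap_value:  [B], critic estimate for the state after the last step.
    """

    def step(carry, x):
        r, d = x
        v = r + d * carry
        return v, v

    # reverse=True walks from the last step to the first; outputs keep time order.
    _, targets = jax.lax.scan(step, bootstrap_value, (reward, discount), reverse=True)
    return targets


# Option 1: the masked-out second sample no longer affects the policy loss.
policy_tgt = jnp.array([[1.0, 0.0], [0.0, 1.0]])
value = jnp.array([0.2, -0.3])
value_tgt = jnp.array([1.0, -1.0])
mask = jnp.array([True, False])
loss_a, _ = masked_losses(jnp.array([[1.0, 0.0], [0.0, 1.0]]), value, policy_tgt, value_tgt, mask)
loss_b, _ = masked_losses(jnp.array([[1.0, 0.0], [5.0, -5.0]]), value, policy_tgt, value_tgt, mask)

# Option 2: one environment, terminal reward at step 1, rollout cut after step 2.
reward = jnp.array([[0.0], [1.0], [0.0]])
discount = jnp.array([[1.0], [0.0], [1.0]])
targets = bootstrapped_value_targets(reward, discount, jnp.array([0.5]))
```

In the toy example for option 2, the targets for steps 0 and 1 come from the observed terminal reward, while the target for step 2 falls back on the bootstrap value. With option 1, note that `jnp.mean` still divides by the full batch size, so batches with many truncated samples get a smaller effective learning signal; dividing by the number of kept samples would be an alternative.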
Code references (pgx/examples/alphazero/train.py at commit 87278d2): value_mask is computed on line 179, the policy loss on line 207, and the masked value loss on line 211.