PGX on TPUs seems to be slower than on CPUs.
On a TPU v3-8, PGX only achieves 1638 steps/sec on the game of chess.
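For context, a steps/sec figure like this is usually computed as (number of environments × number of batched steps) / wall-clock time. The sketch below shows one way to measure it in JAX so that compilation time and asynchronous dispatch do not distort the number. `toy_step` and `measure_steps_per_sec` are hypothetical names, and `toy_step` only stands in for the real environment step (for PGX, the step function of an env created with `pgx.make("chess")`).

```python
import time

import jax
import jax.numpy as jnp


def toy_step(state, action):
    # Hypothetical stand-in for a single-environment step function;
    # a real benchmark would use the environment's own step instead.
    return state + action


def measure_steps_per_sec(step_fn, num_envs, num_iters):
    """Return environment steps per second for a batched, jitted step_fn."""
    batched_step = jax.jit(jax.vmap(step_fn))
    state = jnp.zeros((num_envs,), dtype=jnp.int32)
    action = jnp.ones((num_envs,), dtype=jnp.int32)

    # Warm-up call: triggers compilation so it is excluded from the timing.
    state = batched_step(state, action)
    state.block_until_ready()

    start = time.perf_counter()
    for _ in range(num_iters):
        state = batched_step(state, action)
    # JAX dispatches asynchronously; wait for the device before stopping the clock.
    state.block_until_ready()
    elapsed = time.perf_counter() - start

    # Every iteration advances all num_envs environments by one step.
    return num_envs * num_iters / elapsed
```

Usage: `measure_steps_per_sec(toy_step, num_envs=512, num_iters=100)`. With small batches, per-call dispatch overhead can dominate on an accelerator, so it is worth repeating the measurement at several `num_envs` values.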
Minimal Reproducible Example
PGX CPU vs TPU Test (512 env) (with sharding)
PGX CPU vs TPU Test (64 env) (single device)
About 8192 environments seems to be the limit. With the environments sharded across the 8 devices, it takes about 1 hour and 27 minutes. With more than 8192 environments, JIT/AOT compilation runs into memory issues.
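A minimal sketch of sharding an environment batch across all devices with `jax.sharding`, assuming the batch size divides evenly by the device count (8192 environments on the 8 cores of a v3-8 is 1024 per core). `shard_env_batch` and `toy_step` are hypothetical; a real setup would pass the whole environment state pytree to `jax.device_put` instead of a single array.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def shard_env_batch(batch):
    """Split the leading (environment) axis of `batch` across all devices."""
    mesh = Mesh(np.array(jax.devices()), axis_names=("envs",))
    sharding = NamedSharding(mesh, P("envs"))
    return jax.device_put(batch, sharding)


@jax.jit
def toy_step(state, action):
    # Hypothetical stand-in for a batched environment step.
    return state + action


# The environment count must be divisible by the number of devices.
num_envs = 8 * jax.device_count()
state = shard_env_batch(jnp.zeros((num_envs,), dtype=jnp.int32))
action = shard_env_batch(jnp.ones((num_envs,), dtype=jnp.int32))
new_state = toy_step(state, action)
```

Under `jax.jit`, each device then processes its own slice of the batch, and the output typically keeps the input sharding.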