diff --git a/pgx/animal_shogi.py b/pgx/animal_shogi.py index 0cbe788a7..19ab477d1 100644 --- a/pgx/animal_shogi.py +++ b/pgx/animal_shogi.py @@ -341,32 +341,36 @@ def _flip(state): ) -def _can_move(piece, from_, to): - def can_move(piece, from_, to): - """Can move from to ?""" - x0, y0 = from_ // 4, from_ % 4 - x1, y1 = to // 4, to % 4 - dx = x1 - x0 - dy = y1 - y0 - is_neighbour = ((dx != 0) | (dy != 0)) & (jnp.abs(dx) <= 1) & (jnp.abs(dy) <= 1) - return jax.lax.switch( - piece, - [ - lambda: (dx == 0) & (dy == -1), # PAWN - lambda: is_neighbour & ((dx == dy) | (dx == -dy)), # BISHOP - lambda: is_neighbour & ((dx == 0) | (dy == 0)), # ROOK - lambda: is_neighbour, # KING - lambda: is_neighbour & ((dx == 0) | (dy != +1)), # GOLD - ], - ) +def _can_move_geometry(piece, from_, to): + """Can move from to ? (movement geometry only)""" + x0, y0 = from_ // 4, from_ % 4 + x1, y1 = to // 4, to % 4 + dx = x1 - x0 + dy = y1 - y0 + is_neighbour = ((dx != 0) | (dy != 0)) & (jnp.abs(dx) <= 1) & (jnp.abs(dy) <= 1) + return jax.lax.switch( + piece, + [ + lambda: (dx == 0) & (dy == -1), # PAWN + lambda: is_neighbour & ((dx == dy) | (dx == -dy)), # BISHOP + lambda: is_neighbour & ((dx == 0) | (dy == 0)), # ROOK + lambda: is_neighbour, # KING + lambda: is_neighbour & ((dx == 0) | (dy != +1)), # GOLD + ], + ) - # fmt: off - # CAN_MOVE[piece, from_, to] = Can move from to ? - CAN_MOVE = jax.vmap(jax.vmap(jax.vmap( - can_move, (None, None, 0)), (None, 0, None)), (0, None, None) - )(jnp.arange(5), jnp.arange(12), jnp.arange(12)) - # fmt: on +# CAN_MOVE[piece, from_, to] = Can move from to ? It depends only on the +# movement geometry, so it is precomputed once here rather than rebuilt on every _can_move call +# (which runs ~1700x per _legal_action_mask). +# fmt: off +CAN_MOVE = jax.vmap(jax.vmap(jax.vmap( + _can_move_geometry, (None, None, 0)), (None, 0, None)), (0, None, None) +)(jnp.arange(5), jnp.arange(12), jnp.arange(12)) +# fmt: on + + +def _can_move(piece, from_, to): return CAN_MOVE[piece, from_, to]