From c54195325a29f870ab7801f495cc7a405dd12dc1 Mon Sep 17 00:00:00 2001 From: Taro Date: Wed, 10 Jun 2026 20:46:57 +0200 Subject: [PATCH] [AnimalShogi] precompute the CAN_MOVE table once instead of per _can_move call _can_move rebuilt the full (5, 12, 12) CAN_MOVE table via a triple-nested vmap on every call, only to index a single entry. The table depends solely on movement geometry, so lift it to a module-level constant (as the Zobrist tables already are). _can_move runs ~1700x per _legal_action_mask (132 move checks + 132x12 in _is_checked), so the rebuild dominated. ~1.57x faster _legal_action_mask, ~1.2x faster compile. Output identical: legal_action_mask matches the previous implementation across 250 states; tests/test_animal_shogi.py passes. --- pgx/animal_shogi.py | 52 ++++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/pgx/animal_shogi.py b/pgx/animal_shogi.py index 0cbe788a7..19ab477d1 100644 --- a/pgx/animal_shogi.py +++ b/pgx/animal_shogi.py @@ -341,32 +341,36 @@ def _flip(state): ) -def _can_move(piece, from_, to): - def can_move(piece, from_, to): - """Can move from to ?""" - x0, y0 = from_ // 4, from_ % 4 - x1, y1 = to // 4, to % 4 - dx = x1 - x0 - dy = y1 - y0 - is_neighbour = ((dx != 0) | (dy != 0)) & (jnp.abs(dx) <= 1) & (jnp.abs(dy) <= 1) - return jax.lax.switch( - piece, - [ - lambda: (dx == 0) & (dy == -1), # PAWN - lambda: is_neighbour & ((dx == dy) | (dx == -dy)), # BISHOP - lambda: is_neighbour & ((dx == 0) | (dy == 0)), # ROOK - lambda: is_neighbour, # KING - lambda: is_neighbour & ((dx == 0) | (dy != +1)), # GOLD - ], - ) +def _can_move_geometry(piece, from_, to): + """Can move from to ? (movement geometry only)""" + x0, y0 = from_ // 4, from_ % 4 + x1, y1 = to // 4, to % 4 + dx = x1 - x0 + dy = y1 - y0 + is_neighbour = ((dx != 0) | (dy != 0)) & (jnp.abs(dx) <= 1) & (jnp.abs(dy) <= 1) + return jax.lax.switch( + piece, + [ + lambda: (dx == 0) & (dy == -1), # PAWN + lambda: is_neighbour & ((dx == dy) | (dx == -dy)), # BISHOP + lambda: is_neighbour & ((dx == 0) | (dy == 0)), # ROOK + lambda: is_neighbour, # KING + lambda: is_neighbour & ((dx == 0) | (dy != +1)), # GOLD + ], + ) - # fmt: off - # CAN_MOVE[piece, from_, to] = Can move from to ? - CAN_MOVE = jax.vmap(jax.vmap(jax.vmap( - can_move, (None, None, 0)), (None, 0, None)), (0, None, None) - )(jnp.arange(5), jnp.arange(12), jnp.arange(12)) - # fmt: on +# CAN_MOVE[piece, from_, to] = Can move from to ? It depends only on the +# movement geometry, so it is precomputed once here rather than rebuilt on every _can_move call +# (which runs ~1700x per _legal_action_mask). +# fmt: off +CAN_MOVE = jax.vmap(jax.vmap(jax.vmap( + _can_move_geometry, (None, None, 0)), (None, 0, None)), (0, None, None) +)(jnp.arange(5), jnp.arange(12), jnp.arange(12)) +# fmt: on + + +def _can_move(piece, from_, to): return CAN_MOVE[piece, from_, to]