Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 28 additions & 24 deletions pgx/animal_shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,32 +341,36 @@ def _flip(state):
)


def _can_move(piece, from_, to):
def can_move(piece, from_, to):
"""Can <piece> move from <from_> to <to>?"""
x0, y0 = from_ // 4, from_ % 4
x1, y1 = to // 4, to % 4
dx = x1 - x0
dy = y1 - y0
is_neighbour = ((dx != 0) | (dy != 0)) & (jnp.abs(dx) <= 1) & (jnp.abs(dy) <= 1)
return jax.lax.switch(
piece,
[
lambda: (dx == 0) & (dy == -1), # PAWN
lambda: is_neighbour & ((dx == dy) | (dx == -dy)), # BISHOP
lambda: is_neighbour & ((dx == 0) | (dy == 0)), # ROOK
lambda: is_neighbour, # KING
lambda: is_neighbour & ((dx == 0) | (dy != +1)), # GOLD
],
)
def _can_move_geometry(piece, from_, to):
"""Can <piece> move from <from_> to <to>? (movement geometry only)"""
x0, y0 = from_ // 4, from_ % 4
x1, y1 = to // 4, to % 4
dx = x1 - x0
dy = y1 - y0
is_neighbour = ((dx != 0) | (dy != 0)) & (jnp.abs(dx) <= 1) & (jnp.abs(dy) <= 1)
return jax.lax.switch(
piece,
[
lambda: (dx == 0) & (dy == -1), # PAWN
lambda: is_neighbour & ((dx == dy) | (dx == -dy)), # BISHOP
lambda: is_neighbour & ((dx == 0) | (dy == 0)), # ROOK
lambda: is_neighbour, # KING
lambda: is_neighbour & ((dx == 0) | (dy != +1)), # GOLD
],
)

# fmt: off
# CAN_MOVE[piece, from_, to] = Can <piece> move from <from_> to <to>?
CAN_MOVE = jax.vmap(jax.vmap(jax.vmap(
can_move, (None, None, 0)), (None, 0, None)), (0, None, None)
)(jnp.arange(5), jnp.arange(12), jnp.arange(12))
# fmt: on

# CAN_MOVE[piece, from_, to] = Can <piece> move from <from_> to <to>? It depends only on the
# movement geometry, so it is precomputed once here rather than rebuilt on every _can_move call
# (which runs ~1700x per _legal_action_mask).
# fmt: off
CAN_MOVE = jax.vmap(jax.vmap(jax.vmap(
_can_move_geometry, (None, None, 0)), (None, 0, None)), (0, None, None)
)(jnp.arange(5), jnp.arange(12), jnp.arange(12))
# fmt: on


def _can_move(piece, from_, to):
return CAN_MOVE[piece, from_, to]


Expand Down