Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions pgx/backgammon.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,16 +401,24 @@ def _decompose_action(action: Array):
return src, die, tgt


def _is_action_legal(board: Array, action: Array, rear=None, all_on_home=None) -> bool:
    """
    Check if the action is legal.
    action = src * 6 + die
    src = [no op., from bar, 0, .., 23]

    ``rear`` (rear-checker distance) and ``all_on_home`` (bear-off readiness) depend only on
    the board. Callers evaluating many actions on the same board (the legal-action mask) pass
    them in precomputed so they are not rebuilt per action; standalone callers omit them.

    Returns a boolean value that is True when the action is legal on ``board``.
    """
    # ``None`` is a plain Python sentinel, so these branches are resolved before any array work.
    if rear is None:
        rear = _rear_distance(board)
    if all_on_home is None:
        all_on_home = _is_all_on_home_board(board)

    src, die, tgt = _decompose_action(action)

    # The move lands on a board point when tgt is within 0..23 and src is non-negative;
    # every other action is judged by the bear-off ("to off") rule instead.
    _is_to_point = (0 <= tgt) & (tgt <= 23) & (src >= 0)
    to_point_legal = _is_to_point_legal(board, src, tgt)
    to_off_legal = _is_to_off_legal(board, src, tgt, die, rear, all_on_home)

    # Both rules are evaluated and then selected with boolean masks (no Python branching on
    # array values). Parentheses spell out the `&`-before-`|` precedence.
    return (_is_to_point & to_point_legal) | ((~_is_to_point) & to_off_legal)  # type: ignore


Expand All @@ -421,18 +429,21 @@ def _distance_to_goal(src: int) -> int:
return 24 - src # type: ignore


def _is_to_off_legal(board: Array, src: int, tgt: int, die: int, rear=None, all_on_home=None):
    """
    Check if the action is legal when the target is off.
    The conditions are:
    1. src has checkers.
    2. All checkers are on home board.
    3. The distance from the src to the goal is the same as the die or the src is the farthest checker and the die is bigger than the distance.

    ``rear`` (= _rear_distance(board)) and ``all_on_home`` (= _is_all_on_home_board(board))
    are board-only; they are passed in so the legal-action mask computes them once per board
    rather than once per candidate action. Both default to ``None``, in which case they are
    computed here, so a call that supplies only ``(board, src, tgt, die)`` still works.

    ``tgt`` is accepted for signature compatibility and is not used by the check.
    """
    # Fall back to computing the board-only predicates when the caller did not supply them.
    if rear is None:
        rear = _rear_distance(board)
    if all_on_home is None:
        all_on_home = _is_all_on_home_board(board)

    d = _distance_to_goal(src)
    # Condition 3: an exact roll, or src is the rear-most checker and the die covers its distance.
    exact_roll = d == die
    rear_overshoot = (rear <= die) & (rear == d)
    return (src >= 0) & _exists(board, src) & all_on_home & (exact_roll | rear_overshoot)  # type: ignore


def _legal_action_mask_for_valid_single_dice(board: Array, die) -> Array:
    """
    Legal action mask for a single die when the die is valid.

    Returns a boolean vector of length 26 * 6 in which entry ``6 * src_index + die`` holds the
    legality of that micro action for each of the 26 src indices; all other entries are False.
    """
    # micro action = 6 * src + die. Compute the 26 src legalities as a (26,) vector and
    # scatter them into the mask once, instead of materializing a (26, 26 * 6) intermediate
    # (a full zero vector per src, each with a single bit set) and OR-reducing it.
    # The board-only predicates are computed once here, not per src inside the vmap.
    rear = _rear_distance(board)
    all_on_home = _is_all_on_home_board(board)

    actions = jnp.arange(26, dtype=jnp.int32) * 6 + die
    legal = jax.vmap(lambda action: _is_action_legal(board, action, rear, all_on_home))(actions)

    # The 26 action indices are distinct (stride 6), so the scatter never writes a slot twice.
    return jnp.zeros(26 * 6, dtype=jnp.bool_).at[actions].set(legal)


def _get_abs_board(state: State) -> Array:
Expand Down