diff --git a/pgx/backgammon.py b/pgx/backgammon.py index 6724f7218..dfc75e83c 100644 --- a/pgx/backgammon.py +++ b/pgx/backgammon.py @@ -401,16 +401,24 @@ def _decompose_action(action: Array): return src, die, tgt -def _is_action_legal(board: Array, action: Array) -> bool: +def _is_action_legal(board: Array, action: Array, rear=None, all_on_home=None) -> bool: """ Check if the action is legal. action = src * 6 + die src = [no op., from bar, 0, .., 23] + + ``rear`` (rear-checker distance) and ``all_on_home`` (bear-off readiness) depend only on + the board. Callers evaluating many actions on the same board (the legal-action mask) pass + them in precomputed so they are not rebuilt per action; standalone callers omit them. """ + if rear is None: + rear = _rear_distance(board) + if all_on_home is None: + all_on_home = _is_all_on_home_board(board) src, die, tgt = _decompose_action(action) _is_to_point = (0 <= tgt) & (tgt <= 23) & (src >= 0) return _is_to_point & _is_to_point_legal(board, src, tgt) | (~_is_to_point) & _is_to_off_legal( - board, src, tgt, die + board, src, tgt, die, rear, all_on_home ) # type: ignore @@ -421,18 +429,21 @@ def _distance_to_goal(src: int) -> int: return 24 - src # type: ignore -def _is_to_off_legal(board: Array, src: int, tgt: int, die: int): +def _is_to_off_legal(board: Array, src: int, tgt: int, die: int, rear: Array, all_on_home: Array): """ Check if the action is legal when the target is off. The conditions are: 1. src has checkers. 2. All checkers are on home board. 3. The distance from the src to the goal is the same as the die or the src is the farthest checker and the die is bigger than the distance. + + ``rear`` (= _rear_distance(board)) and ``all_on_home`` (= _is_all_on_home_board(board)) + are board-only; they are passed in so the legal-action mask computes them once per board + rather than once per candidate action. """ - r = _rear_distance(board) d = _distance_to_goal(src) return ( - (src >= 0) & _exists(board, src) & _is_all_on_home_board(board) & ((d == die) | ((r <= die) & (r == d))) + (src >= 0) & _exists(board, src) & all_on_home & ((d == die) | ((rear <= die) & (rear == d))) ) # type: ignore @@ -514,16 +525,15 @@ def _legal_action_mask_for_valid_single_dice(board: Array, die) -> Array: """ Legal action mask for a single die when the die is valid. """ - src_indices = jnp.arange(26, dtype=jnp.int32) # calc legal action for all src indices - - def _is_legal(idx: Array): - action = idx * 6 + die - legal_action_mask = jnp.zeros(26 * 6, dtype=jnp.bool_) - legal_action_mask = legal_action_mask.at[action].set(_is_action_legal(board, action)) - return legal_action_mask - - legal_action_mask = jax.vmap(_is_legal)(src_indices).any(axis=0) # (26 * 6) - return legal_action_mask + # micro action = 6 * src + die. Compute the 26 src legalities as a (26,) vector and + # scatter them into the mask once, instead of materializing a (26, 26 * 6) intermediate + # (a full zero vector per src, each with a single bit set) and OR-reducing it. + # The board-only predicates are computed once here, not per src inside the vmap. + rear = _rear_distance(board) + all_on_home = _is_all_on_home_board(board) + actions = jnp.arange(26, dtype=jnp.int32) * 6 + die + legal = jax.vmap(lambda action: _is_action_legal(board, action, rear, all_on_home))(actions) + return jnp.zeros(26 * 6, dtype=jnp.bool_).at[actions].set(legal) def _get_abs_board(state: State) -> Array: