From a0eb44e57c13ed7abca8dea50ad75799cbc36dfb Mon Sep 17 00:00:00 2001 From: Taro Date: Wed, 10 Jun 2026 20:03:54 +0200 Subject: [PATCH 1/2] [Backgammon] legal_action_mask: scatter once instead of a (26, 156) intermediate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit _legal_action_mask_for_valid_single_dice built 26 full (26*6,) zero vectors — one per src — set a single bit in each, then OR-reduced them. Compute the 26 src legalities as a (26,) vector and scatter them in one shot. Behaviour identical; ~2x faster. Verified: outputs match the previous implementation across 400 boards x 6 dice and the full mask over 4 dice-sets (0 mismatches); ~1.99x faster full _legal_action_mask (B=256, CPU). --- pgx/backgammon.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/pgx/backgammon.py b/pgx/backgammon.py index 6724f7218..283efd416 100644 --- a/pgx/backgammon.py +++ b/pgx/backgammon.py @@ -514,16 +514,12 @@ def _legal_action_mask_for_valid_single_dice(board: Array, die) -> Array: """ Legal action mask for a single die when the die is valid. """ - src_indices = jnp.arange(26, dtype=jnp.int32) # calc legal action for all src indices - - def _is_legal(idx: Array): - action = idx * 6 + die - legal_action_mask = jnp.zeros(26 * 6, dtype=jnp.bool_) - legal_action_mask = legal_action_mask.at[action].set(_is_action_legal(board, action)) - return legal_action_mask - - legal_action_mask = jax.vmap(_is_legal)(src_indices).any(axis=0) # (26 * 6) - return legal_action_mask + # micro action = 6 * src + die. Compute the 26 src legalities as a (26,) vector and + # scatter them into the mask once, instead of materializing a (26, 26 * 6) intermediate + # (a full zero vector per src, each with a single bit set) and OR-reducing it. + actions = jnp.arange(26, dtype=jnp.int32) * 6 + die + legal = jax.vmap(lambda action: _is_action_legal(board, action))(actions) + return jnp.zeros(26 * 6, dtype=jnp.bool_).at[actions].set(legal) def _get_abs_board(state: State) -> Array: From 56add75a1e9b817c14cde086cb580607f31e0aa2 Mon Sep 17 00:00:00 2001 From: Taro Date: Wed, 10 Jun 2026 20:19:19 +0200 Subject: [PATCH 2/2] [Backgammon] legal_action_mask: hoist board-only predicates out of the per-action vmap _is_to_off_legal recomputed _rear_distance(board) and _is_all_on_home_board(board) for every candidate action, though both depend only on the board. Compute them once per board in the legal-mask leaf and thread them in. _is_action_legal keeps its 2-arg form (computes them itself when omitted), so external callers and tests are unaffected. ~1.12x faster full _legal_action_mask on top of the previous scatter change. Identical output across 300 boards x 5 dice-sets; tests/test_backgammon.py passes. --- pgx/backgammon.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/pgx/backgammon.py b/pgx/backgammon.py index 283efd416..dfc75e83c 100644 --- a/pgx/backgammon.py +++ b/pgx/backgammon.py @@ -401,16 +401,24 @@ def _decompose_action(action: Array): return src, die, tgt -def _is_action_legal(board: Array, action: Array) -> bool: +def _is_action_legal(board: Array, action: Array, rear=None, all_on_home=None) -> bool: """ Check if the action is legal. action = src * 6 + die src = [no op., from bar, 0, .., 23] + + ``rear`` (rear-checker distance) and ``all_on_home`` (bear-off readiness) depend only on + the board. Callers evaluating many actions on the same board (the legal-action mask) pass + them in precomputed so they are not rebuilt per action; standalone callers omit them. """ + if rear is None: + rear = _rear_distance(board) + if all_on_home is None: + all_on_home = _is_all_on_home_board(board) src, die, tgt = _decompose_action(action) _is_to_point = (0 <= tgt) & (tgt <= 23) & (src >= 0) return _is_to_point & _is_to_point_legal(board, src, tgt) | (~_is_to_point) & _is_to_off_legal( - board, src, tgt, die + board, src, tgt, die, rear, all_on_home ) # type: ignore @@ -421,18 +429,21 @@ def _distance_to_goal(src: int) -> int: return 24 - src # type: ignore -def _is_to_off_legal(board: Array, src: int, tgt: int, die: int): +def _is_to_off_legal(board: Array, src: int, tgt: int, die: int, rear: Array, all_on_home: Array): """ Check if the action is legal when the target is off. The conditions are: 1. src has checkers. 2. All checkers are on home board. 3. The distance from the src to the goal is the same as the die or the src is the farthest checker and the die is bigger than the distance. + + ``rear`` (= _rear_distance(board)) and ``all_on_home`` (= _is_all_on_home_board(board)) + are board-only; they are passed in so the legal-action mask computes them once per board + rather than once per candidate action. """ - r = _rear_distance(board) d = _distance_to_goal(src) return ( - (src >= 0) & _exists(board, src) & _is_all_on_home_board(board) & ((d == die) | ((r <= die) & (r == d))) + (src >= 0) & _exists(board, src) & all_on_home & ((d == die) | ((rear <= die) & (rear == d))) ) # type: ignore @@ -517,8 +528,11 @@ def _legal_action_mask_for_valid_single_dice(board: Array, die) -> Array: # micro action = 6 * src + die. Compute the 26 src legalities as a (26,) vector and # scatter them into the mask once, instead of materializing a (26, 26 * 6) intermediate # (a full zero vector per src, each with a single bit set) and OR-reducing it. + # The board-only predicates are computed once here, not per src inside the vmap. + rear = _rear_distance(board) + all_on_home = _is_all_on_home_board(board) actions = jnp.arange(26, dtype=jnp.int32) * 6 + die - legal = jax.vmap(lambda action: _is_action_legal(board, action))(actions) + legal = jax.vmap(lambda action: _is_action_legal(board, action, rear, all_on_home))(actions) return jnp.zeros(26 * 6, dtype=jnp.bool_).at[actions].set(legal)