From 3e25a5c2e6cc8a90fecc4e7c115236880fe54567 Mon Sep 17 00:00:00 2001 From: Taro Date: Mon, 15 Jun 2026 09:39:59 +0200 Subject: [PATCH] animal_shogi: fix path-dependent Zobrist hash that undercounted repetitions The Zobrist hash was not a pure function of position: - on capture, the captured piece was removed from the hash at the moving piece's `from` square instead of its actual `to` square (_step_move) - the drop hand-count update used {num_hand+1, num_hand} while capture used {num_hand-1, num_hand}, so drop was not the inverse of capture (_step_drop) - the hand term was indexed by the post-flip turn This made the hash path-dependent, so repetition detection silently undercounted (two identical positions could hash differently), affecting sennichite draws and the repetition observation planes. Fix: recompute the hash from scratch as a pure function of the canonical position each step. Adds a regression test (fails on the old code). --- pgx/animal_shogi.py | 63 ++++++++++++++++++++------------------ tests/test_animal_shogi.py | 32 +++++++++++++++++++ 2 files changed, 66 insertions(+), 29 deletions(-) diff --git a/pgx/animal_shogi.py b/pgx/animal_shogi.py index 0cbe788a7..6653cef5c 100644 --- a/pgx/animal_shogi.py +++ b/pgx/animal_shogi.py @@ -66,6 +66,31 @@ MAX_TERMINATION_STEPS = 256 +def _compute_hash(board, hand, turn): + # Zobrist hash as a PURE function of the canonical (turn-0 frame) position. The previous + # incremental updates were not invertible (captured piece removed from the wrong square; drop + # hand-count off-by-one; hand side indexed by the post-flip turn), so the hash was + # path-dependent and repetition detection silently undercounted. Recomputing from scratch makes + # equal positions hash equally by construction. + sq = jnp.arange(12) + csq = jnp.where(turn == 0, sq, 11 - sq) + cpiece = jnp.where(turn == 0, board, (board + 5) % 10) + occ = board != EMPTY + terms = jnp.where(occ[:, None], ZOBRIST_BOARD[csq, cpiece], jnp.uint32(0)) # (12, 2) + h = jnp.zeros(2, dtype=jnp.uint32) + for i in range(12): + h = h ^ terms[i] + cside = jnp.array([turn, 1 - turn]) # hand[0]=mine -> side `turn`, hand[1]=opp -> side `1-turn` + for s in range(2): + for t in range(3): + h = h ^ ZOBRIST_HAND[cside[s], t, hand[s, t]] + h = h ^ jnp.where(turn == 1, ZOBRIST_SIDE, jnp.zeros(2, dtype=jnp.uint32)) + return h + + +INIT_ZOBRIST_HASH = _compute_hash(INIT_BOARD, jnp.zeros((2, 3), dtype=jnp.int32), jnp.int32(0)) + + @dataclass class State(core.State): current_player: Array = jnp.int32(0) @@ -79,9 +104,9 @@ class State(core.State): _turn: Array = jnp.int32(0) _board: Array = INIT_BOARD # (12,) _hand: Array = jnp.zeros((2, 3), dtype=jnp.int32) - _zobrist_hash: Array = jnp.uint32([233882788, 593924309]) + _zobrist_hash: Array = INIT_ZOBRIST_HASH _hash_history: Array = ( - jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(jnp.uint32([233882788, 593924309])) + jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH) ) _board_history: Array = (-jnp.ones((8, 12), dtype=jnp.int32)).at[0, :].set(INIT_BOARD) _hand_history: Array = jnp.zeros((8, 6), dtype=jnp.int32) @@ -168,6 +193,9 @@ def _step(state: State, action: Array): state = state.replace(_hand_history=hand_history) # type:ignore state = _flip(state) + # recompute the Zobrist hash from scratch (pure function of position) to avoid the + # path-dependence of the incremental updates in _step_move/_step_drop + state = state.replace(_zobrist_hash=_compute_hash(state._board, state._hand, state._turn)) # type: ignore state = state.replace( # type: ignore _hash_history=state._hash_history.at[state._step_count].set(state._zobrist_hash), ) @@ -235,9 +263,6 @@ def _step_move(state: State, action: Action) -> State: piece = state._board[action.from_] # remove piece from the original position board = state._board.at[action.from_].set(EMPTY) - zb_from_ = jax.lax.select(state._turn == 0, action.from_, 11 - action.from_) - zb_piece = jax.lax.select(state._turn == 0, piece, (piece + 5) % 10) - zobrist_hash = state._zobrist_hash ^ ZOBRIST_BOARD[zb_from_, zb_piece] # capture the opponent if exists captured = board[action.to] # suppose >= OPP_PAWN, -1 if EMPTY hand = jax.lax.cond( @@ -248,42 +273,22 @@ def _step_move(state: State, action: Action) -> State: # (2) filtering promoted piece by x % 4 lambda: state._hand.at[0, (captured % 5) % 4].add(1), ) - zobrist_hash = jax.lax.select( - captured == EMPTY, - zobrist_hash, - zobrist_hash - ^ ZOBRIST_BOARD[ - zb_from_, - jax.lax.select(state._turn == 0, captured, (captured + 5) % 10), - ], - ) - num_hand = hand[0, (captured % 5) % 4] - zobrist_hash ^= ZOBRIST_HAND[state._turn, (captured % 5) % 4, num_hand - 1] - zobrist_hash ^= ZOBRIST_HAND[state._turn, (captured % 5) % 4, num_hand] # promote piece (PAWN to GOLD) is_promotion = (action.from_ % 4 == 1) & (piece == PAWN) piece = jax.lax.select(is_promotion, GOLD, piece) # set piece to the target position board = board.at[action.to].set(piece) - zb_to_ = jax.lax.select(state._turn == 0, action.to, 11 - action.to) - zb_piece = jax.lax.select(state._turn == 0, piece, (piece + 5) % 10) - zobrist_hash ^= ZOBRIST_BOARD[zb_to_, zb_piece] - # apply piece moves - return state.replace(_board=board, _hand=hand, _zobrist_hash=zobrist_hash) # type: ignore + # _zobrist_hash is recomputed from scratch in _step (pure function of position) + return state.replace(_board=board, _hand=hand) # type: ignore def _step_drop(state: State, action: Action) -> State: # add piece to board board = state._board.at[action.to].set(action.drop_piece) - zb_to_ = jax.lax.select(state._turn == 0, action.to, 11 - action.to) - zb_piece = jax.lax.select(state._turn == 0, action.drop_piece, (action.drop_piece + 5) % 10) - zobrist_hash = state._zobrist_hash ^ ZOBRIST_BOARD[zb_to_, zb_piece] # remove piece from hand hand = state._hand.at[0, action.drop_piece].add(-1) - num_hand = state._hand[0, action.drop_piece] - zobrist_hash ^= ZOBRIST_HAND[state._turn, action.drop_piece, num_hand + 1] - zobrist_hash ^= ZOBRIST_HAND[state._turn, action.drop_piece, num_hand] - return state.replace(_board=board, _hand=hand, _zobrist_hash=zobrist_hash) # type: ignore + # _zobrist_hash is recomputed from scratch in _step (pure function of position) + return state.replace(_board=board, _hand=hand) # type: ignore def _legal_action_mask(state: State): diff --git a/tests/test_animal_shogi.py b/tests/test_animal_shogi.py index eb3ea612b..0b58e8b61 100644 --- a/tests/test_animal_shogi.py +++ b/tests/test_animal_shogi.py @@ -253,3 +253,35 @@ def test_buggy_samples(): state = step(state, 0 * 12 + 11) # Black: Right Up Bishop DROP_PAWN_TO_0 = 8 * 12 + 0 assert state.legal_action_mask[DROP_PAWN_TO_0] + + +def test_zobrist_hash_is_pure_position_function(): + # Regression: the Zobrist hash must depend only on (board, hand, turn). The previous + # incremental updates were path-dependent -- the captured piece was removed from the hash at the + # moving piece's `from` square instead of its `to` square, the drop hand-count update was + # off-by-one vs the capture path, and the hand term was indexed by the post-flip turn. As a + # result repetition detection silently undercounted (e.g. position at ply N == ply N+k but with + # different hashes). Here we roll out random self-play (which exercises captures and drops) and + # assert that any two states with identical (board, hand, turn) have identical Zobrist hashes. + key = jax.random.PRNGKey(0) + seen = {} + for _ in range(50): + key, k = jax.random.split(key) + state = init(k) + for _ in range(60): + if bool(state.terminated): + break + pos = ( + tuple(int(x) for x in state._board.tolist()), + tuple(int(x) for x in state._hand.flatten().tolist()), + int(state._turn), + ) + h = tuple(int(x) for x in state._zobrist_hash.tolist()) + if pos in seen: + assert seen[pos] == h, f"equal position with different Zobrist hash: {pos}" + else: + seen[pos] = h + key, ak = jax.random.split(key) + legal = jnp.nonzero(state.legal_action_mask)[0] + a = legal[jax.random.randint(ak, (), 0, legal.shape[0])] + state = step(state, a)