From dff23b9672cd5e605cfe635ef50f5b6d483d9da5 Mon Sep 17 00:00:00 2001
From: Taro
Date: Wed, 10 Jun 2026 15:15:07 +0200
Subject: [PATCH] Chess: derive INIT_ZOBRIST_HASH from the Zobrist tables at import time

The Zobrist tables are generated at import time from jax.random, whose
outputs are not stable across JAX versions (e.g. the threefry
partitionable change). The hardcoded INIT_ZOBRIST_HASH no longer matches
_zobrist_hash(GameState()) under current JAX, so the initial position's
hash_history entry never matches later recomputed hashes:

- repetitions involving the start position are counted one short
  (a g1f3/g8f6/f3g1/f6g8 knight shuffle terminates at ply 9 instead of 8)
- the repetition observation planes are wrong for the startpos entry

Compute the constant from the actual tables at import time so it can
never go stale again. No behavior change on JAX versions where the old
constant happened to match.
---
 pgx/_src/games/chess.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py
index 433fdb5e0..493d2adf9 100644
--- a/pgx/_src/games/chess.py
+++ b/pgx/_src/games/chess.py
@@ -134,7 +134,20 @@
 ZOBRIST_SIDE = jax.random.randint(keys[1], shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
 ZOBRIST_CASTLING = jax.random.randint(keys[2], shape=(4, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
 ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32)
-INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862])
+
+
+def _init_zobrist_hash() -> Array:
+    """Return the Zobrist hash of the initial position, built from the tables above."""
+    # Derived rather than hardcoded: the tables come from jax.random, whose outputs are not stable
+    # across JAX versions, so a stale constant would silently break repetition counting at startpos.
+    hash_ = ZOBRIST_SIDE  # side key is included because the initial position has color == 0
+    hash_ ^= lax.reduce(ZOBRIST_BOARD[jnp.arange(64), INIT_BOARD + 6], 0, lax.bitwise_xor, (0,))
+    hash_ ^= lax.reduce(ZOBRIST_CASTLING, 0, lax.bitwise_xor, (0,))  # all four castling rights are set
+    hash_ ^= ZOBRIST_EN_PASSANT[-1]  # no en passant target: en_passant == -1 selects the last row
+    return hash_
+
+
+INIT_ZOBRIST_HASH = _init_zobrist_hash()
 
 
 class GameState(NamedTuple):