diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 433fdb5e0..493d2adf9 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -134,7 +134,20 @@ ZOBRIST_SIDE = jax.random.randint(keys[1], shape=(2,), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) ZOBRIST_CASTLING = jax.random.randint(keys[2], shape=(4, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) ZOBRIST_EN_PASSANT = jax.random.randint(keys[3], shape=(65, 2), minval=0, maxval=2**31 - 1, dtype=jnp.uint32) -INIT_ZOBRIST_HASH = jnp.uint32([1455170221, 1478960862]) + + +def _init_zobrist_hash() -> Array: + # Derived from the tables above instead of hardcoded: the tables come from jax.random, + # whose outputs are not stable across JAX versions — a stale constant silently breaks + # repetition counting for the initial position. + hash_ = ZOBRIST_SIDE # color == 0 + hash_ ^= lax.reduce(ZOBRIST_BOARD[jnp.arange(64), INIT_BOARD + 6], 0, lax.bitwise_xor, (0,)) + hash_ ^= lax.reduce(ZOBRIST_CASTLING, 0, lax.bitwise_xor, (0,)) # all castling rights + hash_ ^= ZOBRIST_EN_PASSANT[-1] # en_passant == -1 + return hash_ + + +INIT_ZOBRIST_HASH = _init_zobrist_hash() class GameState(NamedTuple):