From 2a9c42a86a70a07d9c3741f6596fe52a9c8a08bf Mon Sep 17 00:00:00 2001 From: gweber Date: Wed, 10 Jun 2026 13:37:53 +0200 Subject: [PATCH 1/3] Make Zobrist hashing Apple Metal (jax-metal) compatible lax.reduce with bitwise_xor fails to legalize on the Metal XLA backend (UNIMPLEMENTED: mhlo.reduce). Replace the XOR-reduction in chess and go Zobrist hashing with an equivalent bit-parity computation built from sum/ shifts (utils.xor_reduce), numerically identical on every backend. This unblocks chess and go self-play on Apple Silicon GPUs. --- pgx/_src/games/chess.py | 6 ++++-- pgx/_src/games/go.py | 4 +++- pgx/_src/utils.py | 20 ++++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 433fdb5e0..310ff8d1c 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -19,6 +19,8 @@ import numpy as np from jax import Array, lax +from pgx._src.utils import xor_reduce + EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece MAX_TERMINATION_STEPS = 512 # from AlphaZero paper @@ -405,8 +407,8 @@ def _is_checked(state: GameState): def _zobrist_hash(state: GameState) -> Array: hash_ = lax.select(state.color == 0, ZOBRIST_SIDE, jnp.zeros_like(ZOBRIST_SIDE)) to_reduce = ZOBRIST_BOARD[jnp.arange(64), state.board + 6] # 0, ..., 12 (w:pawn, ..., b:king) - hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,)) + hash_ ^= xor_reduce(to_reduce, 0) to_reduce = jnp.where(state.castling_rights.reshape(-1, 1), ZOBRIST_CASTLING, 0) - hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,)) + hash_ ^= xor_reduce(to_reduce, 0) hash_ ^= ZOBRIST_EN_PASSANT[state.en_passant] return hash_ diff --git a/pgx/_src/games/go.py b/pgx/_src/games/go.py index 9593a406a..5001849e9 100644 --- a/pgx/_src/games/go.py +++ b/pgx/_src/games/go.py @@ -18,6 +18,8 @@ from jax import Array, lax from jax import numpy as jnp +from pgx._src.utils import xor_reduce + ZOBRIST_BOARD = jax.random.randint(jax.random.PRNGKey(12345), (3, 19 * 19, 2), 0, 2**31 - 1, jnp.uint32) @@ -204,7 +206,7 @@ def _adj_ixs(xy, size): def _compute_hash(state: GameState): board = jnp.clip(state.board, -1, 1) to_reduce = ZOBRIST_BOARD[board, jnp.arange(board.shape[-1])] - return lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,)) + return xor_reduce(to_reduce, 0) def _is_psk(state: GameState): diff --git a/pgx/_src/utils.py b/pgx/_src/utils.py index 2e50b2252..7045d4ff1 100644 --- a/pgx/_src/utils.py +++ b/pgx/_src/utils.py @@ -1,6 +1,26 @@ import sys from urllib.request import urlopen +import jax.numpy as jnp +from jax import Array + + +def xor_reduce(operand: Array, axis: int = 0) -> Array: + """XOR-reduce an unsigned-integer array along ``axis`` via per-bit parity. + + Equivalent to ``lax.reduce(operand, 0, lax.bitwise_xor, (axis,))`` but built only from + ``sum`` / shifts / bitwise-and, because the ``bitwise_xor`` reduction primitive fails to + legalize on the Apple Metal (``jax-metal``) XLA backend + (``UNIMPLEMENTED: failed to legalize operation 'mhlo.reduce'``). Numerically identical + on every backend; used for Zobrist hashing so chess/go/etc. run on Apple Silicon GPUs. + """ + nbits = jnp.iinfo(operand.dtype).bits + bitpos = jnp.arange(nbits, dtype=operand.dtype) + bits = (jnp.expand_dims(operand, -1) >> bitpos) & 1 # (..., nbits) + parity = jnp.sum(bits, axis=axis) & 1 # reduce the requested axis, keep bit axis + weights = jnp.ones((), operand.dtype) << bitpos + return jnp.sum(parity.astype(operand.dtype) * weights, axis=-1).astype(operand.dtype) + def _download(url, filename): try: From 62e1c8465114baacadd84d998521d86d7e69ae60 Mon Sep 17 00:00:00 2001 From: Taro Date: Fri, 12 Jun 2026 03:18:36 +0200 Subject: [PATCH 2/3] xor_reduce: use native bitwise_xor reduction except on Metal Keep the per-bit-parity path (which makes the reduction legalize on jax-metal) only on the Metal backend, and use the native bitwise_xor reduction on CPU/CUDA/TPU. The bit-parity fallback expands every value into its bits, which is ~15x slower on CPU and ~100x on CUDA; since this runs once per env step for Zobrist hashing it measurably slowed chess/go on the non-Metal backends. Backend is resolved at trace time, so jitted code pays no runtime cost. Hashes are bit-identical on all backends. --- pgx/_src/utils.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/pgx/_src/utils.py b/pgx/_src/utils.py index 7045d4ff1..8328c4c4c 100644 --- a/pgx/_src/utils.py +++ b/pgx/_src/utils.py @@ -1,19 +1,17 @@ import sys from urllib.request import urlopen +import jax import jax.numpy as jnp -from jax import Array +from jax import Array, lax -def xor_reduce(operand: Array, axis: int = 0) -> Array: - """XOR-reduce an unsigned-integer array along ``axis`` via per-bit parity. - - Equivalent to ``lax.reduce(operand, 0, lax.bitwise_xor, (axis,))`` but built only from - ``sum`` / shifts / bitwise-and, because the ``bitwise_xor`` reduction primitive fails to - legalize on the Apple Metal (``jax-metal``) XLA backend - (``UNIMPLEMENTED: failed to legalize operation 'mhlo.reduce'``). Numerically identical - on every backend; used for Zobrist hashing so chess/go/etc. run on Apple Silicon GPUs. - """ +def _xor_reduce_bitparity(operand: Array, axis: int) -> Array: + # Metal fallback: the ``bitwise_xor`` reduction primitive fails to legalize on the + # Apple Metal (``jax-metal``) XLA backend + # (``UNIMPLEMENTED: failed to legalize operation 'mhlo.reduce'``). Compute the XOR via + # per-bit parity using only ``sum`` / shifts / bitwise-and. Numerically identical to the + # native reduction, but expands each value into its bits, so it is used only on Metal. nbits = jnp.iinfo(operand.dtype).bits bitpos = jnp.arange(nbits, dtype=operand.dtype) bits = (jnp.expand_dims(operand, -1) >> bitpos) & 1 # (..., nbits) @@ -22,6 +20,21 @@ def xor_reduce(operand: Array, axis: int = 0) -> Array: return jnp.sum(parity.astype(operand.dtype) * weights, axis=-1).astype(operand.dtype) +def xor_reduce(operand: Array, axis: int = 0) -> Array: + """XOR-reduce an unsigned-integer array along ``axis``. + + Uses the native ``bitwise_xor`` reduction on CPU/CUDA/TPU, and falls back to a + numerically identical per-bit-parity implementation only on the Apple Metal + (``jax-metal``) backend, where the native reduction fails to legalize. The native path + avoids the per-bit expansion of the fallback (~15x faster on CPU, ~100x on CUDA), which + matters because this runs once per env step for Zobrist hashing. The backend is resolved + at trace time, so jitted code pays no runtime cost for the check. + """ + if "metal" in jax.default_backend().lower(): + return _xor_reduce_bitparity(operand, axis) + return lax.reduce(operand, jnp.zeros((), operand.dtype), lax.bitwise_xor, (axis,)) + + def _download(url, filename): try: print(f"Downloading from {url} ...", file=sys.stderr) From b4d97c6c5a10d5e8dc9b6e075338c5a4b5e4c860 Mon Sep 17 00:00:00 2001 From: Taro Date: Fri, 12 Jun 2026 09:29:14 +0200 Subject: [PATCH 3/3] xor_reduce: signal-based Metal detection + PGX_XOR_REDUCE override Harden the backend check from a default_backend() string match to a device SIGNAL (platform + device_kind contains 'metal'/'apple'), so a CUDA device reporting platform 'gpu' is never misread as Metal, and a Metal device reporting platform 'gpu' is still caught by its Apple device_kind. Add a PGX_XOR_REDUCE=native|parity escape hatch; fall back to the parity path if devices can't be introspected. No behavior change on CPU/CUDA/Metal (verified native==parity==ground-truth). --- pgx/_src/utils.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/pgx/_src/utils.py b/pgx/_src/utils.py index 8328c4c4c..f67a43a64 100644 --- a/pgx/_src/utils.py +++ b/pgx/_src/utils.py @@ -1,3 +1,4 @@ +import os import sys from urllib.request import urlopen @@ -20,6 +21,26 @@ def _xor_reduce_bitparity(operand: Array, axis: int) -> Array: return jnp.sum(parity.astype(operand.dtype) * weights, axis=-1).astype(operand.dtype) +def _use_native_xor() -> bool: + """True everywhere the native ``bitwise_xor`` reduction legalizes (CPU/CUDA/ROCm/TPU); + False only on Apple Metal. Metal is detected by SIGNAL (a 'metal' platform or an 'apple' + device_kind) rather than by an exact backend string, so a CUDA device (platform 'gpu', + NVIDIA device_kind) is never misread as Metal, and a Metal device that reported platform + 'gpu' would still be caught by its Apple device_kind. Override with + ``PGX_XOR_REDUCE=native|parity``. + """ + forced = os.environ.get("PGX_XOR_REDUCE", "").lower() + if forced in ("native", "parity"): + return forced == "native" + try: + sig = " ".join( + f"{getattr(d, 'platform', '')} {getattr(d, 'device_kind', '')}" for d in jax.devices() + ).lower() + except Exception: + return False # can't introspect devices -> safe portable fallback + return not ("metal" in sig or "apple" in sig) + + def xor_reduce(operand: Array, axis: int = 0) -> Array: """XOR-reduce an unsigned-integer array along ``axis``. @@ -28,11 +49,12 @@ def xor_reduce(operand: Array, axis: int = 0) -> Array: (``jax-metal``) backend, where the native reduction fails to legalize. The native path avoids the per-bit expansion of the fallback (~15x faster on CPU, ~100x on CUDA), which matters because this runs once per env step for Zobrist hashing. The backend is resolved - at trace time, so jitted code pays no runtime cost for the check. + at trace time, so jitted code pays no runtime cost for the check; override the choice with + the ``PGX_XOR_REDUCE=native|parity`` environment variable. """ - if "metal" in jax.default_backend().lower(): - return _xor_reduce_bitparity(operand, axis) - return lax.reduce(operand, jnp.zeros((), operand.dtype), lax.bitwise_xor, (axis,)) + if _use_native_xor(): + return lax.reduce(operand, jnp.zeros((), operand.dtype), lax.bitwise_xor, (axis,)) + return _xor_reduce_bitparity(operand, axis) def _download(url, filename):