Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import numpy as np
from jax import Array, lax

from pgx._src.utils import xor_reduce

EMPTY, PAWN, KNIGHT, BISHOP, ROOK, QUEEN, KING = tuple(range(7)) # opponent: -1 * piece
MAX_TERMINATION_STEPS = 512 # from AlphaZero paper

Expand Down Expand Up @@ -405,8 +407,8 @@ def _is_checked(state: GameState):
def _zobrist_hash(state: GameState) -> Array:
hash_ = lax.select(state.color == 0, ZOBRIST_SIDE, jnp.zeros_like(ZOBRIST_SIDE))
to_reduce = ZOBRIST_BOARD[jnp.arange(64), state.board + 6] # 0, ..., 12 (w:pawn, ..., b:king)
hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,))
hash_ ^= xor_reduce(to_reduce, 0)
to_reduce = jnp.where(state.castling_rights.reshape(-1, 1), ZOBRIST_CASTLING, 0)
hash_ ^= lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,))
hash_ ^= xor_reduce(to_reduce, 0)
hash_ ^= ZOBRIST_EN_PASSANT[state.en_passant]
return hash_
4 changes: 3 additions & 1 deletion pgx/_src/games/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from jax import Array, lax
from jax import numpy as jnp

from pgx._src.utils import xor_reduce

ZOBRIST_BOARD = jax.random.randint(jax.random.PRNGKey(12345), (3, 19 * 19, 2), 0, 2**31 - 1, jnp.uint32)


Expand Down Expand Up @@ -204,7 +206,7 @@ def _adj_ixs(xy, size):
def _compute_hash(state: GameState):
board = jnp.clip(state.board, -1, 1)
to_reduce = ZOBRIST_BOARD[board, jnp.arange(board.shape[-1])]
return lax.reduce(to_reduce, 0, lax.bitwise_xor, (0,))
return xor_reduce(to_reduce, 0)


def _is_psk(state: GameState):
Expand Down
55 changes: 55 additions & 0 deletions pgx/_src/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,61 @@
import os
import sys
from urllib.request import urlopen

import jax
import jax.numpy as jnp
from jax import Array, lax


def _xor_reduce_bitparity(operand: Array, axis: int) -> Array:
# Metal fallback: the ``bitwise_xor`` reduction primitive fails to legalize on the
# Apple Metal (``jax-metal``) XLA backend
# (``UNIMPLEMENTED: failed to legalize operation 'mhlo.reduce'``). Compute the XOR via
# per-bit parity using only ``sum`` / shifts / bitwise-and. Numerically identical to the
# native reduction, but expands each value into its bits, so it is used only on Metal.
nbits = jnp.iinfo(operand.dtype).bits
bitpos = jnp.arange(nbits, dtype=operand.dtype)
bits = (jnp.expand_dims(operand, -1) >> bitpos) & 1 # (..., nbits)
parity = jnp.sum(bits, axis=axis) & 1 # reduce the requested axis, keep bit axis
weights = jnp.ones((), operand.dtype) << bitpos
return jnp.sum(parity.astype(operand.dtype) * weights, axis=-1).astype(operand.dtype)


def _use_native_xor() -> bool:
"""True everywhere the native ``bitwise_xor`` reduction legalizes (CPU/CUDA/ROCm/TPU);
False only on Apple Metal. Metal is detected by SIGNAL (a 'metal' platform or an 'apple'
device_kind) rather than by an exact backend string, so a CUDA device (platform 'gpu',
NVIDIA device_kind) is never misread as Metal, and a Metal device that reported platform
'gpu' would still be caught by its Apple device_kind. Override with
``PGX_XOR_REDUCE=native|parity``.
"""
forced = os.environ.get("PGX_XOR_REDUCE", "").lower()
if forced in ("native", "parity"):
return forced == "native"
try:
sig = " ".join(
f"{getattr(d, 'platform', '')} {getattr(d, 'device_kind', '')}" for d in jax.devices()
).lower()
except Exception:
return False # can't introspect devices -> safe portable fallback
return not ("metal" in sig or "apple" in sig)


def xor_reduce(operand: Array, axis: int = 0) -> Array:
"""XOR-reduce an unsigned-integer array along ``axis``.

Uses the native ``bitwise_xor`` reduction on CPU/CUDA/TPU, and falls back to a
numerically identical per-bit-parity implementation only on the Apple Metal
(``jax-metal``) backend, where the native reduction fails to legalize. The native path
avoids the per-bit expansion of the fallback (~15x faster on CPU, ~100x on CUDA), which
matters because this runs once per env step for Zobrist hashing. The backend is resolved
at trace time, so jitted code pays no runtime cost for the check; override the choice with
the ``PGX_XOR_REDUCE=native|parity`` environment variable.
"""
if _use_native_xor():
return lax.reduce(operand, jnp.zeros((), operand.dtype), lax.bitwise_xor, (axis,))
return _xor_reduce_bitparity(operand, axis)


def _download(url, filename):
try:
Expand Down