Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 109 additions & 27 deletions pgx/_src/games/chess.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,24 @@
break
BETWEEN[from_, to, i] = c * 8 + r

FROM_PLANE, TO_PLANE, INIT_LEGAL_ACTION_MASK, LEGAL_DEST, LEGAL_DEST_NEAR, LEGAL_DEST_FAR, CAN_MOVE, BETWEEN = (
jnp.array(x) for x in (FROM_PLANE, TO_PLANE, INIT_LEGAL_ACTION_MASK, LEGAL_DEST, LEGAL_DEST_NEAR, LEGAL_DEST_FAR, CAN_MOVE, BETWEEN)
# Queen-line lookup tables, indexed by square number (square = column * 8 + row).
#   RAYS[sq, d]     : squares met when walking from sq in direction d, closest first;
#                     slots past the board edge keep their -1 padding.
#   RAY_DIR[sq, to] : index d of the direction whose ray from sq contains `to`, else -1.
#   IS_DIAG_DIR[d]  : True for the four diagonal directions of _DIRS.
_DIRS = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]
RAYS = np.full((64, 8, 7), -1, dtype=np.int32)
RAY_DIR = np.full((64, 64), -1, dtype=np.int32)
for sq in range(64):
    c0, r0 = divmod(sq, 8)
    for d, (dr, dc) in enumerate(_DIRS):
        # Walk outwards one step at a time until the ray leaves the 8x8 board
        # (at most 7 steps, which is exactly the width of the last RAYS axis).
        row, col, slot = r0 + dr, c0 + dc, 0
        while 0 <= row < 8 and 0 <= col < 8:
            target = col * 8 + row
            RAYS[sq, d, slot] = target
            RAY_DIR[sq, target] = d
            row, col, slot = row + dr, col + dc, slot + 1
IS_DIAG_DIR = np.array([bool(dr and dc) for dr, dc in _DIRS], dtype=np.bool_)

# Convert every precomputed lookup table to a JAX array in one pass, rebinding each
# module-level name to its converted counterpart.
(
    FROM_PLANE,
    TO_PLANE,
    INIT_LEGAL_ACTION_MASK,
    LEGAL_DEST,
    LEGAL_DEST_NEAR,
    LEGAL_DEST_FAR,
    CAN_MOVE,
    BETWEEN,
    RAYS,
    RAY_DIR,
    IS_DIAG_DIR,
) = map(
    jnp.array,
    (
        FROM_PLANE,
        TO_PLANE,
        INIT_LEGAL_ACTION_MASK,
        LEGAL_DEST,
        LEGAL_DEST_NEAR,
        LEGAL_DEST_FAR,
        CAN_MOVE,
        BETWEEN,
        RAYS,
        RAY_DIR,
        IS_DIAG_DIR,
    ),
)

keys = jax.random.split(jax.random.PRNGKey(12345), 4)
Expand Down Expand Up @@ -215,7 +231,8 @@ def is_terminal(self, state: GameState) -> Array:
terminated = ~state.legal_action_mask.any()
terminated |= state.halfmove_count >= 100
terminated |= has_insufficient_pieces(state)
rep = (state.hash_history == _zobrist_hash(state)).all(axis=1).sum() - 1
# hash_history[0] always holds the current position's hash (set by _update_history)
rep = (state.hash_history == state.hash_history[0]).all(axis=1).sum() - 1
terminated |= rep >= 2
terminated |= MAX_TERMINATION_STEPS <= state.step_count
return terminated
Expand Down Expand Up @@ -308,17 +325,84 @@ def _flip(state: GameState) -> GameState:


def _legal_action_mask(state: GameState) -> Array:
# Stockfish-style legality: compute checkers, pin rays, and king-danger squares once,
# then every pseudo-legal move is decided by table lookups — no per-move make/unmake.
# En passant is the lone exception (two candidate moves, validated by make-move).
board = state.board
king_pos = jnp.argmin(jnp.abs(board - KING))

# opponent pieces currently giving check
def near_checker(to): # knight/pawn/king patterns
ok = (to >= 0) & (board[to] < 0)
piece = jnp.abs(board[to])
ok &= CAN_MOVE[piece, king_pos, to]
ok &= ~((piece == PAWN) & (to // 8 == king_pos // 8)) # pawns only check diagonally
return jnp.where(ok, to, -1)

def far_checker(to): # distant sliders
ok = (to >= 0) & (board[to] < 0)
piece = jnp.abs(board[to])
ok &= (piece == QUEEN) | (piece == ROOK) | (piece == BISHOP)
between_ixs = BETWEEN[king_pos, to]
ok &= CAN_MOVE[piece, king_pos, to] & ((between_ixs < 0) | (board[between_ixs] == EMPTY)).all()
return jnp.where(ok, to, -1)

checker_sqs = jnp.hstack((
jax.vmap(near_checker)(LEGAL_DEST_NEAR[king_pos]),
jax.vmap(far_checker)(LEGAL_DEST_FAR[king_pos]),
))
checker_mask = jnp.zeros(65, dtype=jnp.bool_).at[checker_sqs].set(True)[:64]
num_checkers = checker_mask.sum()
single_checker = jnp.argmax(checker_mask)

# non-king moves must capture the single checker or block its line (none if double check)
blocking_mask = jnp.zeros(65, dtype=jnp.bool_).at[BETWEEN[king_pos, single_checker]].set(True)[:64]
check_target = jnp.where(
num_checkers == 0,
jnp.ones(64, dtype=jnp.bool_),
jnp.where(num_checkers == 1, blocking_mask | (jnp.arange(64) == single_checker), jnp.zeros(64, dtype=jnp.bool_)),
)

# absolute pins: first own piece along each king ray, backed by a matching enemy slider
def pin_dir(d):
ray = RAYS[king_pos, d]
vals = jnp.where(ray >= 0, board[ray], 0)
occ = vals != 0
i1 = jnp.argmax(occ)
first_is_mine = occ.any() & (vals[i1] > 0)
occ2 = occ & (jnp.arange(7) > i1)
v2 = jnp.where(occ2.any(), vals[jnp.argmax(occ2)], 0)
slider = jnp.where(IS_DIAG_DIR[d], (v2 == -QUEEN) | (v2 == -BISHOP), (v2 == -QUEEN) | (v2 == -ROOK))
return jnp.where(first_is_mine & slider, ray[i1], -1)

pinned_sqs = jax.vmap(pin_dir)(jnp.arange(8))
pinned_dir = jnp.full(65, -1, dtype=jnp.int32).at[pinned_sqs].set(jnp.arange(8, dtype=jnp.int32))

# squares the king may not step onto, with the king itself lifted off the board
# (a slider keeps attacking "through" the square the king vacates)
board_wo_king = board.at[king_pos].set(EMPTY)
# A king has at most 8 destinations; LEGAL_DEST is padded to 27 (the queen's max), so
# the tail is always -1 for the king. Slicing to [:8] drops 19 guaranteed-empty lanes,
# each of which would otherwise run a full _is_attacked probe.
king_dests = LEGAL_DEST[KING, king_pos, :8]
danger = jax.vmap(lambda to: (to >= 0) & _is_attacked(board_wo_king, to))(king_dests)
king_danger = jnp.zeros(65, dtype=jnp.bool_).at[jnp.where(danger, king_dests, 64)].set(True)[:64]

def legal_normal_moves(from_):
piece = state.board[from_]
piece = board[from_]

def legal_label(to):
ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (state.board[to] <= 0)
ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (board[to] <= 0)
between_ixs = BETWEEN[from_, to]
ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (board[between_ixs] == EMPTY)).all()
c0, c1 = from_ // 8, to // 8
pawn_should = ((c1 == c0) & (state.board[to] == EMPTY)) | ((c1 != c0) & (state.board[to] < 0))
pawn_should = ((c1 == c0) & (board[to] == EMPTY)) | ((c1 != c0) & (board[to] < 0))
ok &= (piece != PAWN) | pawn_should
return lax.select(ok, Action(from_=from_, to=to)._to_label(), -1)
# check/pin legality via the precomputed masks
pin_d = pinned_dir[from_]
non_king_ok = check_target[to] & ((pin_d < 0) | (RAY_DIR[king_pos, to] == pin_d))
ok &= jnp.where(piece == KING, ~king_danger[to], non_king_ok)
return jnp.where(ok, Action(from_=from_, to=to)._to_label(), -1)

return jax.vmap(legal_label)(LEGAL_DEST[piece, from_])

Expand All @@ -333,8 +417,8 @@ def legal_labels(from_):
return jax.vmap(legal_labels)(jnp.int32([to - 9, to + 7]))

def is_not_checked(label):
a = Action._from_label(label)
return ~_is_checked(_apply_move(state, a))
a = Action._from_label(jnp.maximum(label, 0))
return (label >= 0) & ~_is_checked(_apply_move(state, a))

def legal_underpromotions(mask):
def legal_labels(label):
Expand All @@ -346,26 +430,24 @@ def legal_labels(label):
labels = jnp.int32([from_ * 73 + i for i in range(9) for from_ in [6, 14, 22, 30, 38, 46, 54, 62]])
return jax.vmap(legal_labels)(labels)

# normal move and en passant
# normal moves (already fully legal thanks to the masks above)
possible_piece_positions = jnp.nonzero(state.board > 0, size=16, fill_value=-1)[0]
a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten()
a2 = legal_en_passants()
actions = jnp.hstack((a1, a2)) # include -1
# filter out -1. 200 is big enough for normal play.
ixs = jnp.nonzero(actions >= 0, size=200, fill_value=0)[0]
actions = actions[ixs] # size: 19 * 27 -> 200
# filter ignoring checks and suicides
actions = jnp.where(jax.vmap(is_not_checked)(actions), actions, -1)
mask = jnp.zeros(64 * 73 + 1, dtype=jnp.bool_) # +1 for sentinel
mask = mask.at[actions].set(True)
mask = mask.at[a1].set(True)

# en passant: rare and full of edge cases (rank pins, capturing the checker) — make-move test
a2 = legal_en_passants()
a2 = jnp.where(jax.vmap(is_not_checked)(a2), a2, -1)
mask = mask.at[a2].set(True)

# castling
b = state.board
can_castle_queen_side = state.castling_rights[0, 0]
can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING)
can_castle_king_side = state.castling_rights[0, 1]
can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK)
not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48]))
not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state.board, jnp.int32([16, 24, 32, 40, 48]))
mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all()))
mask = mask.at[2367].set(mask[2367] | (can_castle_king_side & not_checked[2:].all()))

Expand All @@ -376,18 +458,18 @@ def legal_labels(label):
return mask[:-1]


def _is_attacked(state: GameState, pos: Array):
def _is_attacked(board: Array, pos: Array):
def attacked_far(to):
ok = (to >= 0) & (state.board[to] < 0) # should be opponent's
piece = jnp.abs(state.board[to])
ok = (to >= 0) & (board[to] < 0) # should be opponent's
piece = jnp.abs(board[to])
ok &= (piece == QUEEN) | (piece == ROOK) | (piece == BISHOP)
between_ixs = BETWEEN[pos, to]
ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all()
ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (board[between_ixs] == EMPTY)).all()
return ok

def attacked_near(to):
ok = (to >= 0) & (state.board[to] < 0) # should be opponent's
piece = jnp.abs(state.board[to])
ok = (to >= 0) & (board[to] < 0) # should be opponent's
piece = jnp.abs(board[to])
ok &= CAN_MOVE[piece, pos, to]
ok &= ~((piece == PAWN) & (to // 8 == pos // 8)) # should move diagonally to capture
return ok
Expand All @@ -399,7 +481,7 @@ def attacked_near(to):

def _is_checked(state: GameState):
    """Return whether the square holding the current player's king is attacked.

    The king is located as the square whose value is closest to ``KING`` (own pieces
    are stored as positive values), and the raw board array is then handed to
    ``_is_attacked``.

    Args:
        state: game state whose ``board`` is inspected.

    Returns:
        Boolean array (scalar) that is True when an opposing piece attacks the king.
    """
    king_pos = jnp.argmin(jnp.abs(state.board - KING))
    # _is_attacked takes the board array, not the GameState, so it can also be probed
    # with a modified board; pass state.board rather than state.
    return _is_attacked(state.board, king_pos)


def _zobrist_hash(state: GameState) -> Array:
Expand Down
160 changes: 160 additions & 0 deletions tests/diff_vs_python_chess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Differential test: pgx chess legal move generation vs python-chess (ground truth).
# Standalone dev tool, intentionally not named test_* (pgx's pytest suite is heavy).
#
# Usage (keep it to ONE process; the platform allocator returns memory to the OS):
# XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform \
# python tests/diff_vs_python_chess.py
import random

import jax
import jax.numpy as jnp
import numpy as np

import chess as pychess

from pgx._src.games import chess as C

# NumPy copies of the pgx move-encoding tables, indexed by the Python-side helpers below:
# FROM_PLANE[from, plane] yields a destination square, TO_PLANE[from, to] yields a plane.
FROM_PLANE, TO_PLANE = (np.asarray(table) for table in (C.FROM_PLANE, C.TO_PLANE))
# Under-promotion piece, looked up by (plane // 3) for planes 0-8.
UNDER = [pychess.ROOK, pychess.BISHOP, pychess.KNIGHT]

# One Game instance plus jit-compiled entry points shared by every check in this script.
game = C.Game()
jit_mask, jit_step, jit_terminal = map(jax.jit, (C._legal_action_mask, game.step, game.is_terminal))


def pgx_sq(sq_pychess: int, black: bool) -> int:
    """Translate a python-chess square index into the pgx square index.

    python-chess numbers squares as ``rank * 8 + file``; the result is numbered as
    ``file * 8 + rank``. When ``black`` is true the rank is mirrored (``7 - rank``).
    """
    rank, file = divmod(sq_pychess, 8)
    return file * 8 + (7 - rank if black else rank)


def pychess_sq(sq_pgx: int, black: bool) -> int:
    """Translate a pgx square index into the python-chess square index.

    The input is numbered as ``file * 8 + rank``; the result is numbered as
    ``rank * 8 + file``. When ``black`` is true the rank is mirrored (``7 - rank``).
    """
    file, rank = divmod(sq_pgx, 8)
    return (7 - rank if black else rank) * 8 + file


def decode(mask, black: bool, board: pychess.Board) -> set:
    """Turn a pgx legal-action mask into a set of python-chess moves.

    Every set entry of ``mask`` is a label ``from * 73 + plane``. Planes 0-8 carry an
    under-promotion piece (``UNDER[plane // 3]``); any other pawn move that lands on
    rank 0 or 7 is reported as a queen promotion, matching python-chess.

    Args:
        mask: boolean legal-action mask (anything ``np.asarray`` accepts).
        black: whether the side to move is black (squares are mirrored back).
        board: python-chess board used to look up the moving piece type.

    Returns:
        Set of ``pychess.Move`` objects, one per set label.
    """
    decoded = set()
    for label in np.nonzero(np.asarray(mask))[0]:
        src, plane = divmod(label, 73)
        origin = pychess_sq(int(src), black)
        target = pychess_sq(int(FROM_PLANE[src, plane]), black)
        if plane < 9:
            promotion = UNDER[plane // 3]
        elif board.piece_type_at(origin) == pychess.PAWN and pychess.square_rank(target) in (0, 7):
            promotion = pychess.QUEEN
        else:
            promotion = None
        decoded.add(pychess.Move(origin, target, promotion=promotion))
    return decoded


def encode(mv: pychess.Move, black: bool) -> int:
    """Turn a python-chess move into a pgx action label (``from * 73 + plane``).

    Plain moves and queen promotions take their plane from ``TO_PLANE``. Under-promotions
    use planes 0-8: ``piece_index * 3 + direction``, where the direction is derived from
    the square delta (+1 up, +9 right, -7 left).
    """
    src = pgx_sq(mv.from_square, black)
    dst = pgx_sq(mv.to_square, black)
    base = src * 73
    if mv.promotion is None or mv.promotion == pychess.QUEEN:
        return base + int(TO_PLANE[src, dst])
    direction_of = {1: 0, 9: 1, -7: 2}  # up, right, left
    piece_of = {pychess.ROOK: 0, pychess.BISHOP: 1, pychess.KNIGHT: 2}
    direction = direction_of[dst - src]
    return base + piece_of[mv.promotion] * 3 + direction


def state_from_board(board: pychess.Board) -> C.GameState:
    """Build a GameState from a python-chess board (current player always positive/up).

    Pieces of the side to move are written as positive piece types and the opponent's
    as negative ones; squares are converted with ``pgx_sq`` so that black positions are
    mirrored. Castling rights are ordered ``[[mine_q, mine_k], [opp_q, opp_k]]`` and the
    en-passant square is -1 when python-chess reports none.

    Args:
        board: position to convert.

    Returns:
        A GameState with cleared histories, ``hash_history[0]`` set to the hash of the
        converted position, and ``legal_action_mask`` filled in.
    """
    black = board.turn == pychess.BLACK
    arr = np.zeros(64, dtype=np.int32)
    for sq, piece in board.piece_map().items():
        sign = 1 if piece.color == board.turn else -1
        arr[pgx_sq(sq, black)] = sign * piece.piece_type
    my, opp = board.turn, not board.turn
    castling = jnp.bool_([
        [board.has_queenside_castling_rights(my), board.has_kingside_castling_rights(my)],
        [board.has_queenside_castling_rights(opp), board.has_kingside_castling_rights(opp)],
    ])
    ep = jnp.int32(-1 if board.ep_square is None else pgx_sq(board.ep_square, black))
    # One default state is enough: it only serves as a shape/dtype template.
    template = C.GameState()
    x = C.GameState(
        board=jnp.asarray(arr),
        color=jnp.int32(1 if black else 0),
        castling_rights=castling,
        en_passant=ep,
        hash_history=jnp.zeros_like(template.hash_history),
        board_history=jnp.zeros_like(template.board_history),
    )
    # The repetition test compares every history row with hash_history[0], which is
    # expected to hold the current position's hash. With an all-zero history every row
    # would match and the fresh state would look like a repetition draw, so record the
    # real hash in slot 0.
    # NOTE(review): board_history[0] is left zeroed here; confirm whether it should
    # mirror the current board as well.
    x = x._replace(hash_history=x.hash_history.at[0].set(C._zobrist_hash(x)))
    return x._replace(legal_action_mask=jit_mask(x))


def compare(board: pychess.Board, x: C.GameState, ctx: str) -> bool:
    """Check that the pgx legal moves of ``x`` equal python-chess's for ``board``.

    On disagreement a MISMATCH report (context, FEN and both one-sided differences) is
    printed.

    Returns:
        True when both move sets are identical, False otherwise.
    """
    got = decode(x.legal_action_mask, board.turn == pychess.BLACK, board)
    ref = set(board.legal_moves)
    if got == ref:
        return True
    print(f"MISMATCH {ctx}\n fen={board.fen()}\n pgx-only={got - ref}\n ref-only={ref - got}")
    return False


# Hand-picked positions for the FEN suite. Each entry is a FEN string; the trailing
# comment names the position or the move-generation feature it is meant to stress.
FENS = [
    "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1",  # startpos
    "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1",  # kiwipete
    "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1",  # perft pos3 (ep + pins)
    "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1",  # perft pos4
    "r2q1rk1/pP1p2pp/Q4n2/bbp1p3/Np6/1B3NBn/pPPP1PPP/R3K2R b KQ - 0 1",  # pos4 mirrored
    "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8",  # perft pos5
    "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10",  # perft pos6
    "R6R/3Q4/1Q4Q1/4Q3/2Q4Q/Q4Q2/pp1Q4/kBNN1KB1 w - - 0 1",  # 218 legal moves (cap regression)
    "8/8/8/8/k2Pp2Q/8/8/3K4 b - d3 0 1",  # ep illegal: horizontal x-ray after both pawns vanish
    "8/8/8/2k5/3Pp3/8/8/4K3 b - d3 0 1",  # ep capture of the checking double-pushed pawn
    "8/8/4k3/8/2pP4/8/B7/7K b - d3 0 1",  # ep with diagonal pin geometry
    "3k3r/8/8/8/3n4/8/8/3R2K1 b - - 0 1",  # pinned knight: no moves at all
    "r3k2r/8/5q2/8/8/8/8/R3K2R w KQkq - 0 1",  # castling vs attacked squares
    "r3k2r/8/8/8/8/8/8/R3K2R b KQkq - 0 1",  # symmetric castling, black to move
    "8/P6k/8/8/8/8/7K/8 w - - 0 1",  # promotion (queen + under)
    "8/8/8/8/8/2k5/1q6/K7 w - - 0 1",  # nearly stalemated king
    "4k3/8/8/8/8/8/1q6/R3K2N b - - 0 1",  # contact + discovered check potential
]


def run_fen_suite() -> int:
    """Run the depth-1 and depth-2 differential check over every position in FENS.

    Each position is compared directly; if it matches, every legal move is played on
    both engines and the resulting child position is compared too. A position that
    already fails at depth 1 counts once and is not expanded.

    Returns:
        Number of mismatching comparisons.
    """
    bad = 0
    for fen in FENS:
        root = pychess.Board(fen)
        assert root.is_valid(), f"invalid test fen: {fen}"
        root_state = state_from_board(root)
        if not compare(root, root_state, "depth1"):
            bad += 1
            continue
        black_to_move = root.turn == pychess.BLACK  # constant for all children of root
        for mv in root.legal_moves:  # depth 2: every child
            child = root.copy()
            child.push(mv)
            child_state = jit_step(root_state, jnp.int32(encode(mv, black_to_move)))
            if compare(child, child_state, f"depth2 after {mv.uci()}"):
                continue
            bad += 1
    print(f"FEN suite: {len(FENS)} positions, depth-2 expansion -> {bad} mismatches")
    return bad


def run_random_games(n_games: int = 20, max_plies: int = 160, seed: int = 42) -> int:
    """Play random games on both engines in lockstep and compare legal moves each ply.

    Moves are drawn from python-chess's legal moves (sorted by their string form so the
    choice is reproducible) and applied to both the pgx state and the python-chess
    board. A game stops at the first mismatch, when either engine reports the game as
    over, or after ``max_plies`` plies.

    Args:
        n_games: number of games to play.
        max_plies: maximum number of plies per game.
        seed: seed for the move-selection RNG.

    Returns:
        Number of games in which a mismatch was found.
    """
    # Private generator: same stream as seeding the module-level RNG with `seed`, but
    # without disturbing the process-wide `random` state.
    rng = random.Random(seed)
    bad = 0
    for g in range(n_games):
        s = game.init()
        board = pychess.Board()
        for _ in range(max_plies):
            if not compare(board, s, f"random game {g}"):
                bad += 1
                break
            if bool(jit_terminal(s)) or board.is_game_over(claim_draw=True):
                break
            mv = rng.choice(sorted(board.legal_moves, key=str))
            s = jit_step(s, jnp.int32(encode(mv, board.turn == pychess.BLACK)))
            board.push(mv)
    print(f"random games: {n_games} -> {bad} mismatches")
    return bad


if __name__ == "__main__":
    # Run both suites, report the combined verdict, and exit non-zero on any mismatch.
    failures = run_fen_suite() + run_random_games()
    verdict = "PASS" if failures == 0 else f"FAIL ({failures})"
    print("RESULT:", verdict)
    raise SystemExit(int(failures != 0))