diff --git a/pgx/_src/games/chess.py b/pgx/_src/games/chess.py index 433fdb5e0..ff5f25731 100644 --- a/pgx/_src/games/chess.py +++ b/pgx/_src/games/chess.py @@ -125,8 +125,24 @@ break BETWEEN[from_, to, i] = c * 8 + r -FROM_PLANE, TO_PLANE, INIT_LEGAL_ACTION_MASK, LEGAL_DEST, LEGAL_DEST_NEAR, LEGAL_DEST_FAR, CAN_MOVE, BETWEEN = ( - jnp.array(x) for x in (FROM_PLANE, TO_PLANE, INIT_LEGAL_ACTION_MASK, LEGAL_DEST, LEGAL_DEST_NEAR, LEGAL_DEST_FAR, CAN_MOVE, BETWEEN) +# RAYS[sq, d]: squares along queen-line direction d from sq, nearest first, -1 padded. +# RAY_DIR[sq, to]: direction index d such that to is on RAYS[sq, d], else -1. +_DIRS = [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)] +RAYS = -np.ones((64, 8, 7), dtype=np.int32) +RAY_DIR = -np.ones((64, 64), dtype=np.int32) +for sq in range(64): + r0, c0 = sq % 8, sq // 8 + for d, (dr, dc) in enumerate(_DIRS): + for i in range(1, 8): + r, c = r0 + dr * i, c0 + dc * i + if not (0 <= r < 8 and 0 <= c < 8): + break + RAYS[sq, d, i - 1] = c * 8 + r + RAY_DIR[sq, c * 8 + r] = d +IS_DIAG_DIR = np.array([dr != 0 and dc != 0 for dr, dc in _DIRS], dtype=np.bool_) + +FROM_PLANE, TO_PLANE, INIT_LEGAL_ACTION_MASK, LEGAL_DEST, LEGAL_DEST_NEAR, LEGAL_DEST_FAR, CAN_MOVE, BETWEEN, RAYS, RAY_DIR, IS_DIAG_DIR = ( + jnp.array(x) for x in (FROM_PLANE, TO_PLANE, INIT_LEGAL_ACTION_MASK, LEGAL_DEST, LEGAL_DEST_NEAR, LEGAL_DEST_FAR, CAN_MOVE, BETWEEN, RAYS, RAY_DIR, IS_DIAG_DIR) ) keys = jax.random.split(jax.random.PRNGKey(12345), 4) @@ -215,7 +231,8 @@ def is_terminal(self, state: GameState) -> Array: terminated = ~state.legal_action_mask.any() terminated |= state.halfmove_count >= 100 terminated |= has_insufficient_pieces(state) - rep = (state.hash_history == _zobrist_hash(state)).all(axis=1).sum() - 1 + # hash_history[0] always holds the current position's hash (set by _update_history) + rep = (state.hash_history == state.hash_history[0]).all(axis=1).sum() - 1 terminated |= rep >= 2 terminated |= MAX_TERMINATION_STEPS <= state.step_count return terminated @@ -308,17 +325,84 @@ def _flip(state: GameState) -> GameState: def _legal_action_mask(state: GameState) -> Array: + # Stockfish-style legality: compute checkers, pin rays, and king-danger squares once, + # then every pseudo-legal move is decided by table lookups — no per-move make/unmake. + # En passant is the lone exception (two candidate moves, validated by make-move). + board = state.board + king_pos = jnp.argmin(jnp.abs(board - KING)) + + # opponent pieces currently giving check + def near_checker(to): # knight/pawn/king patterns + ok = (to >= 0) & (board[to] < 0) + piece = jnp.abs(board[to]) + ok &= CAN_MOVE[piece, king_pos, to] + ok &= ~((piece == PAWN) & (to // 8 == king_pos // 8)) # pawns only check diagonally + return jnp.where(ok, to, -1) + + def far_checker(to): # distant sliders + ok = (to >= 0) & (board[to] < 0) + piece = jnp.abs(board[to]) + ok &= (piece == QUEEN) | (piece == ROOK) | (piece == BISHOP) + between_ixs = BETWEEN[king_pos, to] + ok &= CAN_MOVE[piece, king_pos, to] & ((between_ixs < 0) | (board[between_ixs] == EMPTY)).all() + return jnp.where(ok, to, -1) + + checker_sqs = jnp.hstack(( + jax.vmap(near_checker)(LEGAL_DEST_NEAR[king_pos]), + jax.vmap(far_checker)(LEGAL_DEST_FAR[king_pos]), + )) + checker_mask = jnp.zeros(65, dtype=jnp.bool_).at[checker_sqs].set(True)[:64] + num_checkers = checker_mask.sum() + single_checker = jnp.argmax(checker_mask) + + # non-king moves must capture the single checker or block its line (none if double check) + blocking_mask = jnp.zeros(65, dtype=jnp.bool_).at[BETWEEN[king_pos, single_checker]].set(True)[:64] + check_target = jnp.where( + num_checkers == 0, + jnp.ones(64, dtype=jnp.bool_), + jnp.where(num_checkers == 1, blocking_mask | (jnp.arange(64) == single_checker), jnp.zeros(64, dtype=jnp.bool_)), + ) + + # absolute pins: first own piece along each king ray, backed by a matching enemy slider + def pin_dir(d): + ray = RAYS[king_pos, d] + vals = jnp.where(ray >= 0, board[ray], 0) + occ = vals != 0 + i1 = jnp.argmax(occ) + first_is_mine = occ.any() & (vals[i1] > 0) + occ2 = occ & (jnp.arange(7) > i1) + v2 = jnp.where(occ2.any(), vals[jnp.argmax(occ2)], 0) + slider = jnp.where(IS_DIAG_DIR[d], (v2 == -QUEEN) | (v2 == -BISHOP), (v2 == -QUEEN) | (v2 == -ROOK)) + return jnp.where(first_is_mine & slider, ray[i1], -1) + + pinned_sqs = jax.vmap(pin_dir)(jnp.arange(8)) + pinned_dir = jnp.full(65, -1, dtype=jnp.int32).at[pinned_sqs].set(jnp.arange(8, dtype=jnp.int32)) + + # squares the king may not step onto, with the king itself lifted off the board + # (a slider keeps attacking "through" the square the king vacates) + board_wo_king = board.at[king_pos].set(EMPTY) + # A king has at most 8 destinations; LEGAL_DEST is padded to 27 (the queen's max), so + # the tail is always -1 for the king. Slicing to [:8] drops 19 guaranteed-empty lanes, + # each of which would otherwise run a full _is_attacked probe. + king_dests = LEGAL_DEST[KING, king_pos, :8] + danger = jax.vmap(lambda to: (to >= 0) & _is_attacked(board_wo_king, to))(king_dests) + king_danger = jnp.zeros(65, dtype=jnp.bool_).at[jnp.where(danger, king_dests, 64)].set(True)[:64] + def legal_normal_moves(from_): - piece = state.board[from_] + piece = board[from_] def legal_label(to): - ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (state.board[to] <= 0) + ok = (from_ >= 0) & (piece > 0) & (to >= 0) & (board[to] <= 0) between_ixs = BETWEEN[from_, to] - ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all() + ok &= CAN_MOVE[piece, from_, to] & ((between_ixs < 0) | (board[between_ixs] == EMPTY)).all() c0, c1 = from_ // 8, to // 8 - pawn_should = ((c1 == c0) & (state.board[to] == EMPTY)) | ((c1 != c0) & (state.board[to] < 0)) + pawn_should = ((c1 == c0) & (board[to] == EMPTY)) | ((c1 != c0) & (board[to] < 0)) ok &= (piece != PAWN) | pawn_should - return lax.select(ok, Action(from_=from_, to=to)._to_label(), -1) + # check/pin legality via the precomputed masks + pin_d = pinned_dir[from_] + non_king_ok = check_target[to] & ((pin_d < 0) | (RAY_DIR[king_pos, to] == pin_d)) + ok &= jnp.where(piece == KING, ~king_danger[to], non_king_ok) + return jnp.where(ok, Action(from_=from_, to=to)._to_label(), -1) return jax.vmap(legal_label)(LEGAL_DEST[piece, from_]) @@ -333,8 +417,8 @@ def legal_labels(from_): return jax.vmap(legal_labels)(jnp.int32([to - 9, to + 7])) def is_not_checked(label): - a = Action._from_label(label) - return ~_is_checked(_apply_move(state, a)) + a = Action._from_label(jnp.maximum(label, 0)) + return (label >= 0) & ~_is_checked(_apply_move(state, a)) def legal_underpromotions(mask): def legal_labels(label): @@ -346,18 +430,16 @@ def legal_labels(label): labels = jnp.int32([from_ * 73 + i for i in range(9) for from_ in [6, 14, 22, 30, 38, 46, 54, 62]]) return jax.vmap(legal_labels)(labels) - # normal move and en passant + # normal moves (already fully legal thanks to the masks above) possible_piece_positions = jnp.nonzero(state.board > 0, size=16, fill_value=-1)[0] a1 = jax.vmap(legal_normal_moves)(possible_piece_positions).flatten() - a2 = legal_en_passants() - actions = jnp.hstack((a1, a2)) # include -1 - # filter out -1. 200 is big enough for normal play. - ixs = jnp.nonzero(actions >= 0, size=200, fill_value=0)[0] - actions = actions[ixs] # size: 19 * 27 -> 200 - # filter ignoring checks and suicides - actions = jnp.where(jax.vmap(is_not_checked)(actions), actions, -1) mask = jnp.zeros(64 * 73 + 1, dtype=jnp.bool_) # +1 for sentinel - mask = mask.at[actions].set(True) + mask = mask.at[a1].set(True) + + # en passant: rare and full of edge cases (rank pins, capturing the checker) — make-move test + a2 = legal_en_passants() + a2 = jnp.where(jax.vmap(is_not_checked)(a2), a2, -1) + mask = mask.at[a2].set(True) # castling b = state.board @@ -365,7 +447,7 @@ def legal_labels(label): can_castle_queen_side &= (b[0] == ROOK) & (b[8] == EMPTY) & (b[16] == EMPTY) & (b[24] == EMPTY) & (b[32] == KING) can_castle_king_side = state.castling_rights[0, 1] can_castle_king_side &= (b[32] == KING) & (b[40] == EMPTY) & (b[48] == EMPTY) & (b[56] == ROOK) - not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state, jnp.int32([16, 24, 32, 40, 48])) + not_checked = ~jax.vmap(_is_attacked, in_axes=(None, 0))(state.board, jnp.int32([16, 24, 32, 40, 48])) mask = mask.at[2364].set(mask[2364] | (can_castle_queen_side & not_checked[:3].all())) mask = mask.at[2367].set(mask[2367] | (can_castle_king_side & not_checked[2:].all())) @@ -376,18 +458,18 @@ def legal_labels(label): return mask[:-1] -def _is_attacked(state: GameState, pos: Array): +def _is_attacked(board: Array, pos: Array): def attacked_far(to): - ok = (to >= 0) & (state.board[to] < 0) # should be opponent's - piece = jnp.abs(state.board[to]) + ok = (to >= 0) & (board[to] < 0) # should be opponent's + piece = jnp.abs(board[to]) ok &= (piece == QUEEN) | (piece == ROOK) | (piece == BISHOP) between_ixs = BETWEEN[pos, to] - ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (state.board[between_ixs] == EMPTY)).all() + ok &= CAN_MOVE[piece, pos, to] & ((between_ixs < 0) | (board[between_ixs] == EMPTY)).all() return ok def attacked_near(to): - ok = (to >= 0) & (state.board[to] < 0) # should be opponent's - piece = jnp.abs(state.board[to]) + ok = (to >= 0) & (board[to] < 0) # should be opponent's + piece = jnp.abs(board[to]) ok &= CAN_MOVE[piece, pos, to] ok &= ~((piece == PAWN) & (to // 8 == pos // 8)) # should move diagonally to capture return ok @@ -399,7 +481,7 @@ def attacked_near(to): def _is_checked(state: GameState): king_pos = jnp.argmin(jnp.abs(state.board - KING)) - return _is_attacked(state, king_pos) + return _is_attacked(state.board, king_pos) def _zobrist_hash(state: GameState) -> Array: diff --git a/tests/diff_vs_python_chess.py b/tests/diff_vs_python_chess.py new file mode 100644 index 000000000..d07a59073 --- /dev/null +++ b/tests/diff_vs_python_chess.py @@ -0,0 +1,160 @@ +# Differential test: pgx chess legal move generation vs python-chess (ground truth). +# Standalone dev tool, intentionally not named test_* (pgx's pytest suite is heavy). +# +# Usage (keep it to ONE process, the platform allocator returns memory to the OS): +# XLA_PYTHON_CLIENT_PREALLOCATE=false XLA_PYTHON_CLIENT_ALLOCATOR=platform \ +# python tests/diff_vs_python_chess.py +import random + +import jax +import jax.numpy as jnp +import numpy as np + +import chess as pychess + +from pgx._src.games import chess as C + +FROM_PLANE = np.asarray(C.FROM_PLANE) +TO_PLANE = np.asarray(C.TO_PLANE) +UNDER = [pychess.ROOK, pychess.BISHOP, pychess.KNIGHT] + +game = C.Game() +jit_mask = jax.jit(C._legal_action_mask) +jit_step = jax.jit(game.step) +jit_terminal = jax.jit(game.is_terminal) + + +def pgx_sq(sq_pychess: int, black: bool) -> int: + rank, file = sq_pychess // 8, sq_pychess % 8 + if black: + rank = 7 - rank + return file * 8 + rank + + +def pychess_sq(sq_pgx: int, black: bool) -> int: + file, rank = sq_pgx // 8, sq_pgx % 8 + if black: + rank = 7 - rank + return rank * 8 + file + + +def decode(mask, black: bool, board: pychess.Board) -> set: + moves = set() + for label in np.nonzero(np.asarray(mask))[0]: + f, plane = label // 73, label % 73 + t = FROM_PLANE[f, plane] + fr, to = pychess_sq(int(f), black), pychess_sq(int(t), black) + promo = None + if plane < 9: + promo = UNDER[plane // 3] + elif board.piece_type_at(fr) == pychess.PAWN and pychess.square_rank(to) in (0, 7): + promo = pychess.QUEEN + moves.add(pychess.Move(fr, to, promotion=promo)) + return moves + + +def encode(mv: pychess.Move, black: bool) -> int: + f, t = pgx_sq(mv.from_square, black), pgx_sq(mv.to_square, black) + if mv.promotion in (None, pychess.QUEEN): + return f * 73 + int(TO_PLANE[f, t]) + direc = {1: 0, 9: 1, -7: 2}[t - f] # up, right, left + return f * 73 + {pychess.ROOK: 0, pychess.BISHOP: 1, pychess.KNIGHT: 2}[mv.promotion] * 3 + direc + + +def state_from_board(board: pychess.Board) -> C.GameState: + """Build a GameState from a python-chess board (current player always positive/up).""" + black = board.turn == pychess.BLACK + arr = np.zeros(64, dtype=np.int32) + for sq, piece in board.piece_map().items(): + sign = 1 if (piece.color == board.turn) else -1 + arr[pgx_sq(sq, black)] = sign * piece.piece_type + my, opp = board.turn, not board.turn + castling = jnp.bool_([ + [board.has_queenside_castling_rights(my), board.has_kingside_castling_rights(my)], + [board.has_queenside_castling_rights(opp), board.has_kingside_castling_rights(opp)], + ]) + ep = jnp.int32(-1 if board.ep_square is None else pgx_sq(board.ep_square, black)) + x = C.GameState( + board=jnp.asarray(arr), + color=jnp.int32(1 if black else 0), + castling_rights=castling, + en_passant=ep, + hash_history=jnp.zeros_like(C.GameState().hash_history), + board_history=jnp.zeros_like(C.GameState().board_history), + ) + return x._replace(legal_action_mask=jit_mask(x)) + + +def compare(board: pychess.Board, x: C.GameState, ctx: str) -> bool: + black = board.turn == pychess.BLACK + got = decode(x.legal_action_mask, black, board) + ref = set(board.legal_moves) + if got != ref: + print(f"MISMATCH {ctx}\n fen={board.fen()}\n pgx-only={got - ref}\n ref-only={ref - got}") + return False + return True + + +FENS = [ + "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", # startpos + "r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 1", # kiwipete + "8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 1", # perft pos3 (ep + pins) + "r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", # perft pos4 + "r2q1rk1/pP1p2pp/Q4n2/bbp1p3/Np6/1B3NBn/pPPP1PPP/R3K2R b KQ - 0 1", # pos4 mirrored + "rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", # perft pos5 + "r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", # perft pos6 + "R6R/3Q4/1Q4Q1/4Q3/2Q4Q/Q4Q2/pp1Q4/kBNN1KB1 w - - 0 1", # 218 legal moves (cap regression) + "8/8/8/8/k2Pp2Q/8/8/3K4 b - d3 0 1", # ep illegal: horizontal x-ray after both pawns vanish + "8/8/8/2k5/3Pp3/8/8/4K3 b - d3 0 1", # ep capture of the checking double-pushed pawn + "8/8/4k3/8/2pP4/8/B7/7K b - d3 0 1", # ep with diagonal pin geometry + "3k3r/8/8/8/3n4/8/8/3R2K1 b - - 0 1", # pinned knight: no moves at all + "r3k2r/8/5q2/8/8/8/8/R3K2R w KQkq - 0 1", # castling vs attacked squares + "r3k2r/8/8/8/8/8/8/R3K2R b KQkq - 0 1", # symmetric castling, black to move + "8/P6k/8/8/8/8/7K/8 w - - 0 1", # promotion (queen + under) + "8/8/8/8/8/2k5/1q6/K7 w - - 0 1", # nearly stalemated king + "4k3/8/8/8/8/8/1q6/R3K2N b - - 0 1", # contact + discovered check potential +] + + +def run_fen_suite() -> int: + bad = 0 + for fen in FENS: + board = pychess.Board(fen) + assert board.is_valid(), f"invalid test fen: {fen}" + x = state_from_board(board) + if not compare(board, x, "depth1"): + bad += 1 + continue + for mv in board.legal_moves: # depth 2: every child + child = board.copy() + child.push(mv) + x2 = jit_step(x, jnp.int32(encode(mv, board.turn == pychess.BLACK))) + if not compare(child, x2, f"depth2 after {mv.uci()}"): + bad += 1 + print(f"FEN suite: {len(FENS)} positions, depth-2 expansion -> {bad} mismatches") + return bad + + +def run_random_games(n_games: int = 20, max_plies: int = 160) -> int: + random.seed(42) + bad = 0 + for g in range(n_games): + s = game.init() + board = pychess.Board() + for _ in range(max_plies): + if not compare(board, s, f"random game {g}"): + bad += 1 + break + if bool(jit_terminal(s)) or board.is_game_over(claim_draw=True): + break + mv = random.choice(sorted(board.legal_moves, key=str)) + s = jit_step(s, jnp.int32(encode(mv, board.turn == pychess.BLACK))) + board.push(mv) + print(f"random games: {n_games} -> {bad} mismatches") + return bad + + +if __name__ == "__main__": + failures = run_fen_suite() + run_random_games() + print("RESULT:", "PASS" if failures == 0 else f"FAIL ({failures})") + raise SystemExit(0 if failures == 0 else 1)