diff --git a/pgx/experimental/chess.py b/pgx/experimental/chess.py index 36313d069..125b4f5f1 100644 --- a/pgx/experimental/chess.py +++ b/pgx/experimental/chess.py @@ -66,6 +66,9 @@ def from_fen(fen: str): halfmove_count=jnp.int32(halfmove_cnt), fullmove_count=jnp.int32(fullmove_cnt), ) + # default GameState carries the startpos in hash_history[0]/board_history[0]; clear both + # so the restored position does not inherit a phantom repetition/observation entry + x = x._replace(hash_history=jnp.zeros_like(x.hash_history), board_history=jnp.zeros_like(x.board_history)) legal_action_mask = jax.jit(_legal_action_mask)(x) x = x._replace(legal_action_mask=legal_action_mask) x = _update_history(x)