diff --git a/pgx/_src/games/go.py b/pgx/_src/games/go.py index 9593a406a..f00aef85b 100644 --- a/pgx/_src/games/go.py +++ b/pgx/_src/games/go.py @@ -25,7 +25,7 @@ class GameState(NamedTuple): step_count: Array = jnp.int32(0) # ids of representative stone (smallest) in the connected stones board: Array = jnp.zeros(19 * 19, dtype=jnp.int32) # b > 0, w < 0, empty = 0 - board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32) # for obs + board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int8) # for obs; values in {-1, 0, 1, 2} num_captured: Array = jnp.zeros(2, dtype=jnp.int32) # (b, w) consecutive_pass_count: Array = jnp.int32(0) ko: Array = jnp.int32(-1) # by SSK @@ -49,7 +49,7 @@ def __init__( def init(self) -> GameState: return GameState( board=jnp.zeros(self.size**2, dtype=jnp.int32), - board_history=jnp.full((self.history_length, self.size**2), 2, dtype=jnp.int32), + board_history=jnp.full((self.history_length, self.size**2), 2, dtype=jnp.int8), hash_history=jnp.zeros((self.max_termination_steps, 2), dtype=jnp.uint32), ) @@ -63,7 +63,7 @@ def step(self, state: GameState, action: Array) -> GameState: ) # update board history board_history = jnp.roll(state.board_history, self.size**2) - board_history = board_history.at[0].set(jnp.clip(state.board, -1, 1).astype(jnp.int32)) + board_history = board_history.at[0].set(jnp.clip(state.board, -1, 1).astype(jnp.int8)) state = state._replace(board_history=board_history) # check PSK hash_ = _compute_hash(state) @@ -111,12 +111,13 @@ def is_terminal(self, state: GameState) -> Array: return two_consecutive_pass | state.is_psk | timeover def rewards(self, state: GameState) -> Array: - scores = _count_scores(state, self.size) + is_terminal = self.is_terminal(state) + scores = _count_scores(state, self.size, enable=is_terminal) is_black_win = scores[0] - self.komi > scores[1] rewards = lax.select(is_black_win, jnp.float32([1, -1]), jnp.float32([-1, 1])) to_play = state.color rewards = lax.select(state.is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), rewards) - rewards = lax.select(self.is_terminal(state), rewards, jnp.zeros(2, dtype=jnp.float32)) + rewards = lax.select(is_terminal, rewards, jnp.zeros(2, dtype=jnp.float32)) return rewards @@ -214,15 +215,21 @@ def _is_psk(state: GameState): return not_passed & has_same_hash -def _count_scores(state: GameState, size): +def _count_scores(state: GameState, size, enable=True): + # `enable=False` replaces the board with a fully-occupied dummy whose flood fill converges + # immediately. rewards() discards the scores of non-terminal states anyway, but under + # vmap/jit the while_loop below runs as many rounds as the worst board in the batch needs — + # without the dummy, every step pays the full territory fill even though almost no state + # in the batch is terminal (the empty early-game board is the worst case at ~2*size rounds). def calc_point(c): - return _count_ji(state, c, size) + jnp.count_nonzero(state.board * c > 0) + return _count_ji(state, c, size, enable) + jnp.count_nonzero(state.board * c > 0) return jax.vmap(calc_point)(jnp.int32([1, -1])) -def _count_ji(state: GameState, color: int, size: int): +def _count_ji(state: GameState, color: int, size: int, enable=True): board = jnp.clip(state.board * color, -1, 1) # my stone: 1, opp stone: -1 + board = jnp.where(enable, board, 1) adj_mat = jax.vmap(_adj_ixs, in_axes=(0, None))(jnp.arange(size**2), size) # (size**2, 4) def fill_opp(x):