Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions pgx/_src/games/go.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class GameState(NamedTuple):
step_count: Array = jnp.int32(0)
# ids of representative stone (smallest) in the connected stones
board: Array = jnp.zeros(19 * 19, dtype=jnp.int32) # b > 0, w < 0, empty = 0
board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int32) # for obs
board_history: Array = jnp.full((8, 19 * 19), 2, dtype=jnp.int8) # for obs; values in {-1, 0, 1, 2}
num_captured: Array = jnp.zeros(2, dtype=jnp.int32) # (b, w)
consecutive_pass_count: Array = jnp.int32(0)
ko: Array = jnp.int32(-1) # by SSK
Expand All @@ -49,7 +49,7 @@ def __init__(
def init(self) -> GameState:
return GameState(
board=jnp.zeros(self.size**2, dtype=jnp.int32),
board_history=jnp.full((self.history_length, self.size**2), 2, dtype=jnp.int32),
board_history=jnp.full((self.history_length, self.size**2), 2, dtype=jnp.int8),
hash_history=jnp.zeros((self.max_termination_steps, 2), dtype=jnp.uint32),
)

Expand All @@ -63,7 +63,7 @@ def step(self, state: GameState, action: Array) -> GameState:
)
# update board history
board_history = jnp.roll(state.board_history, self.size**2)
board_history = board_history.at[0].set(jnp.clip(state.board, -1, 1).astype(jnp.int32))
board_history = board_history.at[0].set(jnp.clip(state.board, -1, 1).astype(jnp.int8))
state = state._replace(board_history=board_history)
# check PSK
hash_ = _compute_hash(state)
Expand Down Expand Up @@ -111,12 +111,13 @@ def is_terminal(self, state: GameState) -> Array:
return two_consecutive_pass | state.is_psk | timeover

def rewards(self, state: GameState) -> Array:
scores = _count_scores(state, self.size)
is_terminal = self.is_terminal(state)
scores = _count_scores(state, self.size, enable=is_terminal)
is_black_win = scores[0] - self.komi > scores[1]
rewards = lax.select(is_black_win, jnp.float32([1, -1]), jnp.float32([-1, 1]))
to_play = state.color
rewards = lax.select(state.is_psk, jnp.float32([-1, -1]).at[to_play].set(1.0), rewards)
rewards = lax.select(self.is_terminal(state), rewards, jnp.zeros(2, dtype=jnp.float32))
rewards = lax.select(is_terminal, rewards, jnp.zeros(2, dtype=jnp.float32))
return rewards


Expand Down Expand Up @@ -214,15 +215,21 @@ def _is_psk(state: GameState):
return not_passed & has_same_hash


def _count_scores(state: GameState, size):
def _count_scores(state: GameState, size, enable=True):
# `enable=False` replaces the board with a fully-occupied dummy whose flood fill converges
# immediately. rewards() discards the scores of non-terminal states anyway, but under
# vmap/jit the while_loop below runs as many rounds as the worst board in the batch needs —
# without the dummy, every step pays the full territory fill even though almost no state
# in the batch is terminal (the empty early-game board is the worst case at ~2*size rounds).
def calc_point(c):
return _count_ji(state, c, size) + jnp.count_nonzero(state.board * c > 0)
return _count_ji(state, c, size, enable) + jnp.count_nonzero(state.board * c > 0)

return jax.vmap(calc_point)(jnp.int32([1, -1]))


def _count_ji(state: GameState, color: int, size: int):
def _count_ji(state: GameState, color: int, size: int, enable=True):
board = jnp.clip(state.board * color, -1, 1) # my stone: 1, opp stone: -1
board = jnp.where(enable, board, 1)
adj_mat = jax.vmap(_adj_ixs, in_axes=(0, None))(jnp.arange(size**2), size) # (size**2, 4)

def fill_opp(x):
Expand Down