[Go] Skip territory flood fill for non-terminal states (3-12x faster rewards); int8 board history #1321
Open
gweber wants to merge 1 commit into
rewards() computes Tromp-Taylor scores on every step, but the result is
discarded for non-terminal states. Under vmap/jit the territory flood fill
(lax.while_loop) still runs as many rounds as the worst board in the batch
needs - on a near-empty board that is ~2*size rounds, making rewards() the
single most expensive component of a Go env step (0.56 ms at ply 4 vs
0.16 ms for game.step itself, batch 256, 19x19).
Fix: feed the fill a fully-occupied dummy board for non-terminal states
(enable=False), which converges in one round. Only actually-terminal
boards pay the real fill. Semantics are exact: the discarded scores were
never observable, terminal states use the unchanged algorithm.
rewards() drops from 0.16-0.71 ms (phase-dependent) to a flat ~0.06 ms,
3-12x faster.
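
A minimal JAX sketch of the masking idea described above, not the PR's actual code. The board encoding (1 = black, -1 = white, 0 = empty), the helpers `neighbors_any` and `reach`, and the 5x5 board size are all assumptions made for illustration. The round count is returned next to the score so the effect of enable=False is visible.

```python
import jax
import jax.numpy as jnp
from jax import lax

SIZE = 5  # small board for illustration (assumption; the PR measures 19x19)


def neighbors_any(mask):
    """True where at least one 4-neighbour of a point is True in `mask`."""
    p = jnp.pad(mask, 1)  # pads with False
    return p[:-2, 1:-1] | p[2:, 1:-1] | p[1:-1, :-2] | p[1:-1, 2:]


def reach(board, color):
    """Stones of `color` plus empty points connected to them; also returns rounds."""
    empty = board == 0

    def cond(carry):
        _, changed, _ = carry
        return changed

    def body(carry):
        reached, _, rounds = carry
        grown = reached | (neighbors_any(reached) & empty)
        return grown, jnp.any(grown != reached), rounds + 1

    init = (board == color, jnp.array(True), jnp.int32(0))
    reached, _, rounds = lax.while_loop(cond, body, init)
    return reached, rounds


def score(board, enable):
    """Black-minus-white area score; zero (and a trivial fill) when not enabled."""
    # Hypothetical dummy: a board with no empty points, so the fill cannot grow.
    dummy = jnp.ones_like(board)
    b = jnp.where(enable, board, dummy)
    black, r_black = reach(b, 1)
    white, r_white = reach(b, -1)
    diff = jnp.sum(black & ~white) - jnp.sum(white & ~black)
    return jnp.where(enable, diff, 0), r_black + r_white


# Two copies of the same near-empty board: one treated as terminal, one not.
boards = jnp.zeros((2, SIZE, SIZE), jnp.int8).at[:, 0, 0].set(1)
terminal = jnp.array([True, False])
scores, rounds = jax.vmap(score)(boards, terminal)
print(scores, rounds)
```

Under vmap the loop keeps iterating until the slowest lane has converged, so the batch cost follows the largest round count. With the mask, non-terminal lanes stop contributing to that maximum.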
Also store board_history as int8 (values are in {-1,0,1,2}): 8.7 KB less
per 19x19 env state, relevant when states are carried as MCTS embeddings.
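
The 8.7 KB figure can be reproduced with a short arithmetic check. This sketch assumes a history of 8 board planes and a previous dtype of int32; neither is stated in the text above, but together they match the quoted saving.

```python
import numpy as np

HISTORY = 8       # assumed number of stored board planes (not stated in the PR)
POINTS = 19 * 19  # points on a 19x19 board

old = np.zeros((HISTORY, POINTS), dtype=np.int32)  # assumed previous dtype
new = old.astype(np.int8)                          # {-1, 0, 1, 2} fits in int8

saved_bytes = old.nbytes - new.nbytes
print(saved_bytes)  # 8664 bytes, about 8.7 KB
```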
Validation: 300-ply lockstep vs the previous implementation (board,
rewards, terminal, observation bit-identical), two-pass terminal scoring,
and force-enabled scoring on 256 dense boards identical to the old scorer.
A pointer-jumping CCL variant (one shared component labeling for both
colors) was benchmarked and rejected: it wins on sparse boards but loses
2-4x on realistic mid-game boards, where empty regions are corridor-shaped
(large diameter) yet every empty point is adjacent to a stone (old fill
converges in 2-3 rounds).
(cherry picked from commit d3e4b50)
Problem
Game.rewards() computes Tromp-Taylor scores on every step, but the result is discarded for non-terminal states (lax.select(self.is_terminal(state), rewards, zeros)). Under vmap/jit the discarded computation is not free: the territory flood fill (lax.while_loop in _count_ji) runs as many rounds as the worst board in the batch needs. On a near-empty board that is ~2*size rounds, which makes rewards() the single most expensive component of a Go env step:
[Table: rewards() timings, with game.step() for comparison]
Fix
Feed the flood fill a fully-occupied dummy board for non-terminal states (enable=False): it converges in one round, and only actually-terminal boards pay the real fill. Semantics are exact: the discarded scores were never observable, and terminal states use the unchanged algorithm. rewards() drops to a flat ~0.06 ms across all phases (3-12x faster).
Also stores board_history as int8 (values are in {-1, 0, 1, 2}): 8.7 KB less per 19x19 state, which matters when states are carried as MCTS embeddings (AlphaZero-style tree search holds one state per node).
Validation
Alternatives considered
A shared connected-component labeling of the empty points (pointer jumping, one labeling for both colors; related to #1205's torch port and the sotetsuk/go-avoid-while branch) was implemented and benchmarked: it wins on sparse boards (0.15 ms at ply 4) but loses 2-4x on realistic mid-game boards, where empty regions are corridor-shaped (large component diameter) yet every empty point is adjacent to a stone, so the existing fill converges in 2-3 rounds. The fori_loop(2*size-2) approach from sotetsuk/go-avoid-while is also slower (0.39 ms vs 0.23 ms at ply 120) because it always pays the worst case. Masking out non-terminal states beats both without touching the algorithm.
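
To illustrate why a fixed trip count always pays the worst case, here is a rough sketch of a fill driven by lax.fori_loop. It is written in the spirit of the fori_loop(2*size-2) approach mentioned above and is not taken from that branch. The board encoding (1 = black, 0 = empty) and the helper names are assumptions, and the sketch does not claim that 2*size-2 rounds suffice for every board shape.

```python
import jax.numpy as jnp
from jax import lax

SIZE = 5  # small board for illustration (assumption)


def neighbors_any(mask):
    """True where at least one 4-neighbour of a point is True in `mask`."""
    p = jnp.pad(mask, 1)  # pads with False
    return p[:-2, 1:-1] | p[2:, 1:-1] | p[1:-1, :-2] | p[1:-1, 2:]


def fill_fixed(board, color):
    """Grow `color` into adjacent empty points for a fixed 2*SIZE-2 rounds."""
    empty = board == 0

    def body(_, reached):
        return reached | (neighbors_any(reached) & empty)

    # No convergence test: every board runs the same number of rounds.
    return lax.fori_loop(0, 2 * SIZE - 2, body, board == color)


# One corner stone on an open board needs all 8 rounds to reach the far corner.
open_board = jnp.zeros((SIZE, SIZE), jnp.int8).at[0, 0].set(1)
reached_open = fill_fixed(open_board, 1)

# A board with no empty points needs no growth at all, yet still runs 8 rounds.
full_board = jnp.ones((SIZE, SIZE), jnp.int8)
reached_full = fill_fixed(full_board, 1)
```

Both calls execute the same eight rounds, which is the cost a while_loop with a convergence test avoids on boards that settle early.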