Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 34 additions & 29 deletions pgx/animal_shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,31 @@
MAX_TERMINATION_STEPS = 256


def _compute_hash(board, hand, turn):
# Zobrist hash as a PURE function of the canonical (turn-0 frame) position. The previous
# incremental updates were not invertible (captured piece removed from the wrong square; drop
# hand-count off-by-one; hand side indexed by the post-flip turn), so the hash was
# path-dependent and repetition detection silently undercounted. Recomputing from scratch makes
# equal positions hash equally by construction.
sq = jnp.arange(12)
csq = jnp.where(turn == 0, sq, 11 - sq)
cpiece = jnp.where(turn == 0, board, (board + 5) % 10)
occ = board != EMPTY
terms = jnp.where(occ[:, None], ZOBRIST_BOARD[csq, cpiece], jnp.uint32(0)) # (12, 2)
h = jnp.zeros(2, dtype=jnp.uint32)
for i in range(12):
h = h ^ terms[i]
cside = jnp.array([turn, 1 - turn]) # hand[0]=mine -> side `turn`, hand[1]=opp -> side `1-turn`
for s in range(2):
for t in range(3):
h = h ^ ZOBRIST_HAND[cside[s], t, hand[s, t]]
h = h ^ jnp.where(turn == 1, ZOBRIST_SIDE, jnp.zeros(2, dtype=jnp.uint32))
return h


INIT_ZOBRIST_HASH = _compute_hash(INIT_BOARD, jnp.zeros((2, 3), dtype=jnp.int32), jnp.int32(0))


@dataclass
class State(core.State):
current_player: Array = jnp.int32(0)
Expand All @@ -79,9 +104,9 @@ class State(core.State):
_turn: Array = jnp.int32(0)
_board: Array = INIT_BOARD # (12,)
_hand: Array = jnp.zeros((2, 3), dtype=jnp.int32)
_zobrist_hash: Array = jnp.uint32([233882788, 593924309])
_zobrist_hash: Array = INIT_ZOBRIST_HASH
_hash_history: Array = (
jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(jnp.uint32([233882788, 593924309]))
jnp.zeros((MAX_TERMINATION_STEPS + 1, 2), dtype=jnp.uint32).at[0].set(INIT_ZOBRIST_HASH)
)
_board_history: Array = (-jnp.ones((8, 12), dtype=jnp.int32)).at[0, :].set(INIT_BOARD)
_hand_history: Array = jnp.zeros((8, 6), dtype=jnp.int32)
Expand Down Expand Up @@ -168,6 +193,9 @@ def _step(state: State, action: Array):
state = state.replace(_hand_history=hand_history) # type:ignore

state = _flip(state)
# recompute the Zobrist hash from scratch (pure function of position) to avoid the
# path-dependence of the incremental updates in _step_move/_step_drop
state = state.replace(_zobrist_hash=_compute_hash(state._board, state._hand, state._turn)) # type: ignore
state = state.replace( # type: ignore
_hash_history=state._hash_history.at[state._step_count].set(state._zobrist_hash),
)
Expand Down Expand Up @@ -235,9 +263,6 @@ def _step_move(state: State, action: Action) -> State:
piece = state._board[action.from_]
# remove piece from the original position
board = state._board.at[action.from_].set(EMPTY)
zb_from_ = jax.lax.select(state._turn == 0, action.from_, 11 - action.from_)
zb_piece = jax.lax.select(state._turn == 0, piece, (piece + 5) % 10)
zobrist_hash = state._zobrist_hash ^ ZOBRIST_BOARD[zb_from_, zb_piece]
# capture the opponent if exists
captured = board[action.to] # suppose >= OPP_PAWN, -1 if EMPTY
hand = jax.lax.cond(
Expand All @@ -248,42 +273,22 @@ def _step_move(state: State, action: Action) -> State:
# (2) filtering promoted piece by x % 4
lambda: state._hand.at[0, (captured % 5) % 4].add(1),
)
zobrist_hash = jax.lax.select(
captured == EMPTY,
zobrist_hash,
zobrist_hash
^ ZOBRIST_BOARD[
zb_from_,
jax.lax.select(state._turn == 0, captured, (captured + 5) % 10),
],
)
num_hand = hand[0, (captured % 5) % 4]
zobrist_hash ^= ZOBRIST_HAND[state._turn, (captured % 5) % 4, num_hand - 1]
zobrist_hash ^= ZOBRIST_HAND[state._turn, (captured % 5) % 4, num_hand]
# promote piece (PAWN to GOLD)
is_promotion = (action.from_ % 4 == 1) & (piece == PAWN)
piece = jax.lax.select(is_promotion, GOLD, piece)
# set piece to the target position
board = board.at[action.to].set(piece)
zb_to_ = jax.lax.select(state._turn == 0, action.to, 11 - action.to)
zb_piece = jax.lax.select(state._turn == 0, piece, (piece + 5) % 10)
zobrist_hash ^= ZOBRIST_BOARD[zb_to_, zb_piece]
# apply piece moves
return state.replace(_board=board, _hand=hand, _zobrist_hash=zobrist_hash) # type: ignore
# _zobrist_hash is recomputed from scratch in _step (pure function of position)
return state.replace(_board=board, _hand=hand) # type: ignore


def _step_drop(state: State, action: Action) -> State:
# add piece to board
board = state._board.at[action.to].set(action.drop_piece)
zb_to_ = jax.lax.select(state._turn == 0, action.to, 11 - action.to)
zb_piece = jax.lax.select(state._turn == 0, action.drop_piece, (action.drop_piece + 5) % 10)
zobrist_hash = state._zobrist_hash ^ ZOBRIST_BOARD[zb_to_, zb_piece]
# remove piece from hand
hand = state._hand.at[0, action.drop_piece].add(-1)
num_hand = state._hand[0, action.drop_piece]
zobrist_hash ^= ZOBRIST_HAND[state._turn, action.drop_piece, num_hand + 1]
zobrist_hash ^= ZOBRIST_HAND[state._turn, action.drop_piece, num_hand]
return state.replace(_board=board, _hand=hand, _zobrist_hash=zobrist_hash) # type: ignore
# _zobrist_hash is recomputed from scratch in _step (pure function of position)
return state.replace(_board=board, _hand=hand) # type: ignore


def _legal_action_mask(state: State):
Expand Down
32 changes: 32 additions & 0 deletions tests/test_animal_shogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,35 @@ def test_buggy_samples():
state = step(state, 0 * 12 + 11) # Black: Right Up Bishop
DROP_PAWN_TO_0 = 8 * 12 + 0
assert state.legal_action_mask[DROP_PAWN_TO_0]


def test_zobrist_hash_is_pure_position_function():
# Regression: the Zobrist hash must depend only on (board, hand, turn). The previous
# incremental updates were path-dependent -- the captured piece was removed from the hash at the
# moving piece's `from` square instead of its `to` square, the drop hand-count update was
# off-by-one vs the capture path, and the hand term was indexed by the post-flip turn. As a
# result repetition detection silently undercounted (e.g. position at ply N == ply N+k but with
# different hashes). Here we roll out random self-play (which exercises captures and drops) and
# assert that any two states with identical (board, hand, turn) have identical Zobrist hashes.
key = jax.random.PRNGKey(0)
seen = {}
for _ in range(50):
key, k = jax.random.split(key)
state = init(k)
for _ in range(60):
if bool(state.terminated):
break
pos = (
tuple(int(x) for x in state._board.tolist()),
tuple(int(x) for x in state._hand.flatten().tolist()),
int(state._turn),
)
h = tuple(int(x) for x in state._zobrist_hash.tolist())
if pos in seen:
assert seen[pos] == h, f"equal position with different Zobrist hash: {pos}"
else:
seen[pos] = h
key, ak = jax.random.split(key)
legal = jnp.nonzero(state.legal_action_mask)[0]
a = legal[jax.random.randint(ak, (), 0, legal.shape[0])]
state = step(state, a)