From 5098faf89538cbd803d931e933c5ea15715e887d Mon Sep 17 00:00:00 2001 From: Taro Date: Thu, 11 Jun 2026 02:44:08 +0200 Subject: [PATCH] Othello tests: make starting-player-dependent tests RNG-version-agnostic The bernoulli coin flip that picks the starting player in Othello._init is not stable across jax.random versions, so tests that hard-coded the outcome of a specific PRNGKey (via a double split chosen to land on current_player==0) broke on newer JAX. Replace that with a small _init_with_current_player(player) helper that searches keys for the desired starting player, keeping every downstream board/reward/observation assertion deterministic on any JAX version. Verified on jax 0.4.30 and jax 0.10.1. --- tests/test_othello.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/test_othello.py b/tests/test_othello.py index cae8981e8..61d09cf7c 100644 --- a/tests/test_othello.py +++ b/tests/test_othello.py @@ -8,12 +8,25 @@ observe = jax.jit(env.observe) +def _init_with_current_player(player: int): + """Return an initial state whose current_player is `player`. + + The starting player is decided by an RNG coin flip in `_init`, and the exact + flip outcome for a given key is not stable across JAX versions. Searching keys + keeps these tests deterministic and version-agnostic without depending on a + specific `jax.random` implementation. + """ + for seed in range(1000): + state = init(jax.random.PRNGKey(seed)) + if int(state.current_player) == player: + return state + raise AssertionError(f"no key produced current_player={player}") + + def test_init(): - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key=key) - assert state.current_player == 0 + # Both starting players are reachable, and current_player is always 0 or 1. + players = {int(_init_with_current_player(p).current_player) for p in (0, 1)} + assert players == {0, 1} def test_step(): @@ -42,10 +55,7 @@ def test_step(): def test_terminated(): # wipe out - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key) + state = _init_with_current_player(0) for i in [37, 43, 34, 29, 52, 45, 38, 44]: state = step(state, i) assert not state.terminated @@ -56,10 +66,7 @@ def test_terminated(): def test_legal_action(): # cannot put - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key) + state = _init_with_current_player(0) assert state.current_player == 0 for i in [37, 29, 18, 44, 53, 46, 30, 60, 62, 38, 39]: state = step(state, i) @@ -76,10 +83,7 @@ def test_legal_action(): def test_observe(): - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key) + state = _init_with_current_player(0) assert state.current_player == 0 obs = observe(state, state.current_player) @@ -125,6 +129,7 @@ def test_observe(): def test_api(): import pgx + env = pgx.make("othello") pgx.api_test(env, 3, use_key=False) pgx.api_test(env, 3, use_key=True)