diff --git a/tests/test_othello.py b/tests/test_othello.py index cae8981e8..61d09cf7c 100644 --- a/tests/test_othello.py +++ b/tests/test_othello.py @@ -8,12 +8,25 @@ observe = jax.jit(env.observe) +def _init_with_current_player(player: int): + """Return an initial state whose current_player is `player`. + + The starting player is decided by an RNG coin flip in `_init`, and the exact + flip outcome for a given key is not stable across JAX versions. Searching keys + keeps these tests deterministic and version-agnostic without depending on a + specific `jax.random` implementation. + """ + for seed in range(1000): + state = init(jax.random.PRNGKey(seed)) + if int(state.current_player) == player: + return state + raise AssertionError(f"no key produced current_player={player}") + + def test_init(): - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key=key) - assert state.current_player == 0 + # Both starting players are reachable, and current_player is always 0 or 1. + players = {int(_init_with_current_player(p).current_player) for p in (0, 1)} + assert players == {0, 1} def test_step(): @@ -42,10 +55,7 @@ def test_step(): def test_terminated(): # wipe out - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key) + state = _init_with_current_player(0) for i in [37, 43, 34, 29, 52, 45, 38, 44]: state = step(state, i) assert not state.terminated @@ -56,10 +66,7 @@ def test_terminated(): def test_legal_action(): # cannot put - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key) + state = _init_with_current_player(0) assert state.current_player == 0 for i in [37, 29, 18, 44, 53, 46, 30, 60, 62, 38, 39]: state = step(state, i) @@ -76,10 +83,7 @@ def test_legal_action(): def test_observe(): - key = jax.random.PRNGKey(0) - _, key = jax.random.split(key) # due to API update - _, key = jax.random.split(key) # due to API update - state = init(key) + state = _init_with_current_player(0) assert state.current_player == 0 obs = observe(state, state.current_player) @@ -125,6 +129,7 @@ def test_observe(): def test_api(): import pgx + env = pgx.make("othello") pgx.api_test(env, 3, use_key=False) pgx.api_test(env, 3, use_key=True)