Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 22 additions & 17 deletions tests/test_othello.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,25 @@
observe = jax.jit(env.observe)


def _init_with_current_player(player: int):
"""Return an initial state whose current_player is `player`.

The starting player is decided by an RNG coin flip in `_init`, and the exact
flip outcome for a given key is not stable across JAX versions. Searching keys
keeps these tests deterministic and version-agnostic without depending on a
specific `jax.random` implementation.
"""
for seed in range(1000):
state = init(jax.random.PRNGKey(seed))
if int(state.current_player) == player:
return state
raise AssertionError(f"no key produced current_player={player}")


def test_init():
key = jax.random.PRNGKey(0)
_, key = jax.random.split(key) # due to API update
_, key = jax.random.split(key) # due to API update
state = init(key=key)
assert state.current_player == 0
# Both starting players are reachable, and current_player is always 0 or 1.
players = {int(_init_with_current_player(p).current_player) for p in (0, 1)}
assert players == {0, 1}


def test_step():
Expand Down Expand Up @@ -42,10 +55,7 @@ def test_step():

def test_terminated():
# wipe out
key = jax.random.PRNGKey(0)
_, key = jax.random.split(key) # due to API update
_, key = jax.random.split(key) # due to API update
state = init(key)
state = _init_with_current_player(0)
for i in [37, 43, 34, 29, 52, 45, 38, 44]:
state = step(state, i)
assert not state.terminated
Expand All @@ -56,10 +66,7 @@ def test_terminated():

def test_legal_action():
# cannot put
key = jax.random.PRNGKey(0)
_, key = jax.random.split(key) # due to API update
_, key = jax.random.split(key) # due to API update
state = init(key)
state = _init_with_current_player(0)
assert state.current_player == 0
for i in [37, 29, 18, 44, 53, 46, 30, 60, 62, 38, 39]:
state = step(state, i)
Expand All @@ -76,10 +83,7 @@ def test_legal_action():


def test_observe():
key = jax.random.PRNGKey(0)
_, key = jax.random.split(key) # due to API update
_, key = jax.random.split(key) # due to API update
state = init(key)
state = _init_with_current_player(0)
assert state.current_player == 0

obs = observe(state, state.current_player)
Expand Down Expand Up @@ -125,6 +129,7 @@ def test_observe():

def test_api():
import pgx

env = pgx.make("othello")
pgx.api_test(env, 3, use_key=False)
pgx.api_test(env, 3, use_key=True)