diff --git a/examples/simple_dqn.py b/examples/simple_dqn.py index e817281..78c7d0b 100644 --- a/examples/simple_dqn.py +++ b/examples/simple_dqn.py @@ -25,7 +25,7 @@ import jax.numpy as jnp import optax import rlax -from rlax.examples import experiment +import experiment Params = collections.namedtuple("Params", "online target") ActorState = collections.namedtuple("ActorState", "count")