diff --git a/rlax/_src/value_learning_test.py b/rlax/_src/value_learning_test.py index 21f2487..22fd451 100644 --- a/rlax/_src/value_learning_test.py +++ b/rlax/_src/value_learning_test.py @@ -896,6 +896,108 @@ def test_quantile_expected_sarsa_batch_uniform(self, huber_param): self.uniform_expected[huber_param], actual, rtol=1e-5) +class StopTargetGradientsDefaultTest(absltest.TestCase): + """Regression tests for stop_target_gradients=True default in value_learning. + + All value_learning functions default to stop_target_gradients=True, which + means gradients do NOT flow through the bootstrap targets. This is the + correct behaviour for standard TD-learning and matches vtrace.py. + + These tests verify: + 1. Default (True): gradient w.r.t. bootstrap value v_t is zero. + 2. Explicit False: gradient does flow through v_t (meta-gradient path). + 3. Forward values are identical regardless of the flag. + """ + + # value_learning functions operate on scalars; use vmap for batched tests. + + # ── td_learning ──────────────────────────────────────────────────────────── + + def test_td_learning_default_stops_gradient_on_v_t(self): + """Default stop_target_gradients=True: grad wrt v_t is zero.""" + v_tm1 = jnp.array([1.0, 2.0, 3.0]) + r_t = jnp.array([0.5, -0.5, 1.0]) + disc = jnp.array([0.9, 0.8, 1.0]) + v_t = jnp.array([1.5, 2.5, 2.0]) + batched = jax.vmap(value_learning.td_learning) + def fn(vt): + return batched(v_tm1, r_t, disc, vt).sum() + grad = jax.grad(fn)(v_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(v_t)) + + def test_td_learning_explicit_false_passes_gradient(self): + """stop_target_gradients=False: gradient does flow through v_t.""" + v_tm1 = jnp.array([1.0, 2.0]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + v_t = jnp.array([1.5, 2.5]) + batched = jax.vmap(functools.partial( + value_learning.td_learning, stop_target_gradients=False)) + def fn(vt): + return batched(v_tm1, r_t, disc, vt).sum() + grad = jax.grad(fn)(v_t) + self.assertFalse(jnp.all(grad == 0)) + + def test_td_learning_forward_unchanged(self): + """Forward values must be identical regardless of stop_target_gradients.""" + v_tm1 = jnp.array([1.0, 2.0]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + v_t = jnp.array([1.5, 2.5]) + batched_true = jax.vmap(value_learning.td_learning) + batched_false = jax.vmap(functools.partial( + value_learning.td_learning, stop_target_gradients=False)) + np.testing.assert_allclose( + batched_true(v_tm1, r_t, disc, v_t), + batched_false(v_tm1, r_t, disc, v_t), atol=1e-6) + + # ── sarsa ───────────────────────────────────────────────────────────────── + # sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, a_t, stop_target_gradients) + + def test_sarsa_default_stops_gradient_on_q_t(self): + """Default stop_target_gradients=True: grad wrt q_t is zero.""" + q_tm1 = jnp.array([[1.0, 2.0], [3.0, 1.5]]) + a_tm1 = jnp.array([0, 1]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) + a_t = jnp.array([1, 0]) + batched = jax.vmap(value_learning.sarsa) + def fn(qt): + return batched(q_tm1, a_tm1, r_t, disc, qt, a_t).sum() + grad = jax.grad(fn)(q_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(q_t)) + + def test_sarsa_forward_unchanged(self): + q_tm1 = jnp.array([[1.0, 2.0], [3.0, 1.5]]) + a_tm1 = jnp.array([0, 1]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) + a_t = jnp.array([1, 0]) + batched_true = jax.vmap(value_learning.sarsa) + batched_false = jax.vmap(functools.partial( + value_learning.sarsa, stop_target_gradients=False)) + np.testing.assert_allclose( + batched_true(q_tm1, a_tm1, r_t, disc, q_t, a_t), + batched_false(q_tm1, a_tm1, r_t, disc, q_t, a_t), atol=1e-6) + + # ── q_learning ──────────────────────────────────────────────────────────── + + def test_q_learning_default_stops_gradient_on_q_t(self): + """Default stop_target_gradients=True: grad wrt q_t is zero.""" + q_tm1 = jnp.array([[1.0, 2.0], [3.0, 1.5]]) + a_tm1 = jnp.array([0, 1]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) + batched = jax.vmap(value_learning.q_learning) + def fn(qt): + return batched(q_tm1, a_tm1, r_t, disc, qt).sum() + grad = jax.grad(fn)(q_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(q_t)) + + if __name__ == '__main__': jax.config.update('jax_numpy_rank_promotion', 'raise') absltest.main()