From b5ac77a90818cf7bb5dedd3b905e11f90d561a4d Mon Sep 17 00:00:00 2001 From: Sumukh Chaluvaraju Date: Thu, 4 Jun 2026 23:17:07 +0100 Subject: [PATCH 1/2] test(value_learning): add stop_target_gradients regression tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit value_learning functions (td_learning, sarsa, q_learning) correctly default to stop_target_gradients=True — matching vtrace.py — but had no tests verifying that the gradient is actually blocked. Add StopTargetGradientsDefaultTest covering: - Default (True): gradient wrt bootstrap target (v_t / q_t) is zero - Explicit False: gradient does flow (opt-in meta-gradient path) - Forward values are identical regardless of the flag (stop_gradient is transparent in forward computation) These tests are the value_learning counterpart of the regression tests added to multistep_test.py in PR #161, completing the coverage story. --- rlax/_src/value_learning_test.py | 98 ++++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/rlax/_src/value_learning_test.py b/rlax/_src/value_learning_test.py index 21f2487..0b164a5 100644 --- a/rlax/_src/value_learning_test.py +++ b/rlax/_src/value_learning_test.py @@ -896,6 +896,104 @@ def test_quantile_expected_sarsa_batch_uniform(self, huber_param): self.uniform_expected[huber_param], actual, rtol=1e-5) +class StopTargetGradientsDefaultTest(absltest.TestCase): + """Regression tests verifying stop_target_gradients=True default in value_learning. + + All value_learning functions default to stop_target_gradients=True, which + means gradients do NOT flow through the bootstrap targets. This is the + correct behaviour for standard TD-learning and matches vtrace.py. + + These tests verify: + 1. Default (True): gradient w.r.t. bootstrap value v_t is zero. + 2. Explicit False: gradient does flow through v_t (meta-gradient path). + 3. Forward values are identical regardless of the flag. + """ + + # value_learning functions operate on scalars; use vmap for batched tests. + + # ── td_learning ──────────────────────────────────────────────────────────── + + def test_td_learning_default_stops_gradient_on_v_t(self): + """Default stop_target_gradients=True: grad wrt v_t is zero.""" + v_tm1 = jnp.array([1.0, 2.0, 3.0]) + r_t = jnp.array([0.5, -0.5, 1.0]) + disc = jnp.array([0.9, 0.8, 1.0]) + v_t = jnp.array([1.5, 2.5, 2.0]) + batched = jax.vmap(value_learning.td_learning) + def fn(vt): return batched(v_tm1, r_t, disc, vt).sum() + grad = jax.grad(fn)(v_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(v_t)) + + def test_td_learning_explicit_false_passes_gradient(self): + """stop_target_gradients=False: gradient does flow through v_t.""" + v_tm1 = jnp.array([1.0, 2.0]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + v_t = jnp.array([1.5, 2.5]) + batched = jax.vmap(functools.partial( + value_learning.td_learning, stop_target_gradients=False)) + def fn(vt): return batched(v_tm1, r_t, disc, vt).sum() + grad = jax.grad(fn)(v_t) + self.assertFalse(jnp.all(grad == 0)) + + def test_td_learning_forward_unchanged(self): + """Forward values must be identical regardless of stop_target_gradients.""" + v_tm1 = jnp.array([1.0, 2.0]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + v_t = jnp.array([1.5, 2.5]) + batched_true = jax.vmap(value_learning.td_learning) + batched_false = jax.vmap(functools.partial( + value_learning.td_learning, stop_target_gradients=False)) + np.testing.assert_allclose( + batched_true(v_tm1, r_t, disc, v_t), + batched_false(v_tm1, r_t, disc, v_t), atol=1e-6) + + # ── sarsa ───────────────────────────────────────────────────────────────── + # sarsa(q_tm1, a_tm1, r_t, discount_t, q_t, a_t, stop_target_gradients) + + def test_sarsa_default_stops_gradient_on_q_t(self): + """Default stop_target_gradients=True: grad wrt q_t is zero.""" + q_tm1 = jnp.array([[1.0, 2.0], [3.0, 1.5]]) + a_tm1 = jnp.array([0, 1]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) + a_t = jnp.array([1, 0]) + batched = jax.vmap(value_learning.sarsa) + def fn(qt): return batched(q_tm1, a_tm1, r_t, disc, qt, a_t).sum() + grad = jax.grad(fn)(q_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(q_t)) + + def test_sarsa_forward_unchanged(self): + q_tm1 = jnp.array([[1.0, 2.0], [3.0, 1.5]]) + a_tm1 = jnp.array([0, 1]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) + a_t = jnp.array([1, 0]) + batched_true = jax.vmap(value_learning.sarsa) + batched_false = jax.vmap(functools.partial( + value_learning.sarsa, stop_target_gradients=False)) + np.testing.assert_allclose( + batched_true(q_tm1, a_tm1, r_t, disc, q_t, a_t), + batched_false(q_tm1, a_tm1, r_t, disc, q_t, a_t), atol=1e-6) + + # ── q_learning ──────────────────────────────────────────────────────────── + + def test_q_learning_default_stops_gradient_on_q_t(self): + """Default stop_target_gradients=True: grad wrt q_t is zero.""" + q_tm1 = jnp.array([[1.0, 2.0], [3.0, 1.5]]) + a_tm1 = jnp.array([0, 1]) + r_t = jnp.array([0.5, -0.5]) + disc = jnp.array([0.9, 0.8]) + q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) + batched = jax.vmap(value_learning.q_learning) + def fn(qt): return batched(q_tm1, a_tm1, r_t, disc, qt).sum() + grad = jax.grad(fn)(q_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(q_t)) + + if __name__ == '__main__': jax.config.update('jax_numpy_rank_promotion', 'raise') absltest.main() From b59c316506cd43e12abc5d4ff0d419861b2ffdd8 Mon Sep 17 00:00:00 2001 From: Sumukh Chaluvaraju Date: Fri, 5 Jun 2026 13:00:06 +0100 Subject: [PATCH 2/2] fix(value_learning_test): fix pylint style violations in regression tests - Break inline lambda-style defs onto two lines (C0321 multiple-statements) - Shorten StopTargetGradientsDefaultTest docstring to fit 80 chars (C0301) --- rlax/_src/value_learning_test.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/rlax/_src/value_learning_test.py b/rlax/_src/value_learning_test.py index 0b164a5..22fd451 100644 --- a/rlax/_src/value_learning_test.py +++ b/rlax/_src/value_learning_test.py @@ -897,7 +897,7 @@ def test_quantile_expected_sarsa_batch_uniform(self, huber_param): class StopTargetGradientsDefaultTest(absltest.TestCase): - """Regression tests verifying stop_target_gradients=True default in value_learning. + """Regression tests for stop_target_gradients=True default in value_learning. All value_learning functions default to stop_target_gradients=True, which means gradients do NOT flow through the bootstrap targets. This is the @@ -920,7 +920,8 @@ def test_td_learning_default_stops_gradient_on_v_t(self): disc = jnp.array([0.9, 0.8, 1.0]) v_t = jnp.array([1.5, 2.5, 2.0]) batched = jax.vmap(value_learning.td_learning) - def fn(vt): return batched(v_tm1, r_t, disc, vt).sum() + def fn(vt): + return batched(v_tm1, r_t, disc, vt).sum() grad = jax.grad(fn)(v_t) np.testing.assert_array_equal(grad, jnp.zeros_like(v_t)) @@ -932,7 +933,8 @@ def test_td_learning_explicit_false_passes_gradient(self): v_t = jnp.array([1.5, 2.5]) batched = jax.vmap(functools.partial( value_learning.td_learning, stop_target_gradients=False)) - def fn(vt): return batched(v_tm1, r_t, disc, vt).sum() + def fn(vt): + return batched(v_tm1, r_t, disc, vt).sum() grad = jax.grad(fn)(v_t) self.assertFalse(jnp.all(grad == 0)) @@ -961,7 +963,8 @@ def test_sarsa_default_stops_gradient_on_q_t(self): q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) a_t = jnp.array([1, 0]) batched = jax.vmap(value_learning.sarsa) - def fn(qt): return batched(q_tm1, a_tm1, r_t, disc, qt, a_t).sum() + def fn(qt): + return batched(q_tm1, a_tm1, r_t, disc, qt, a_t).sum() grad = jax.grad(fn)(q_t) np.testing.assert_array_equal(grad, jnp.zeros_like(q_t)) @@ -989,7 +992,8 @@ def test_q_learning_default_stops_gradient_on_q_t(self): disc = jnp.array([0.9, 0.8]) q_t = jnp.array([[1.5, 2.5], [2.0, 1.0]]) batched = jax.vmap(value_learning.q_learning) - def fn(qt): return batched(q_tm1, a_tm1, r_t, disc, qt).sum() + def fn(qt): + return batched(q_tm1, a_tm1, r_t, disc, qt).sum() grad = jax.grad(fn)(q_t) np.testing.assert_array_equal(grad, jnp.zeros_like(q_t))