From 26ad22f305ed2863bece29cb983914fa419c06a8 Mon Sep 17 00:00:00 2001
From: Sumukh Chaluvaraju
Date: Thu, 4 Jun 2026 15:57:51 +0100
Subject: [PATCH] fix(multistep): change stop_target_gradients default to True,
 matching vtrace
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

All seven return/advantage functions in multistep.py previously defaulted
stop_target_gradients=False. This silently allowed gradients to flow
through bootstrap targets during RL training — incorrect for standard
agents and inconsistent with vtrace.py, where every function has defaulted
to True since the library's inception.

Functions changed:
  lambda_returns, n_step_bootstrapped_returns, discounted_returns,
  importance_corrected_td_errors, truncated_generalized_advantage_estimation,
  general_off_policy_returns_from_action_values,
  general_off_policy_returns_from_q_and_v

The False case is still reachable by passing stop_target_gradients=False
explicitly; it is only needed for meta-gradient methods (a rare use case
that should be opt-in, not the default).

Note: stop_gradient does not affect forward-pass values, so the values
returned by these functions are unchanged. Gradients are not: any caller
that differentiates through these outputs without passing the flag will now
see the targets treated as constants, which can change training results.
Users who were relying on the old default (gradients through targets), e.g.
for meta-learning, must now pass stop_target_gradients=False explicitly.

Also adds StopTargetGradientsDefaultTest covering:
  - Default (True) blocks gradients for lambda_returns, GAE, n_step_returns
  - Explicit False still passes gradients (opt-in meta-gradient path)
  - Forward values are identical regardless of the flag

Fixes: google-deepmind/rlax#28
---
Review notes (below the cut line, not part of the commit message):
  - Line structure and indentation were restored from a whitespace-collapsed
    copy of this patch. Indentation assumes the 2-space Google style with
    4-space continuation lines; verify context lines against the tree (or
    apply with --ignore-whitespace) before merging.
  - The test hunk grew from 92 to 102 added lines (second blank line before
    the class, fn() split onto two lines, new n-step explicit-False test).
    The post-image blob id on the multistep_test.py index line is therefore
    stale; regenerate the patch with git format-patch after applying.

 rlax/_src/multistep.py      |  14 +++---
 rlax/_src/multistep_test.py | 102 +++++++++++++++++++++++++++++++++++++
 2 files changed, 109 insertions(+), 7 deletions(-)

diff --git a/rlax/_src/multistep.py b/rlax/_src/multistep.py
index 407f591..9ed6d26 100644
--- a/rlax/_src/multistep.py
+++ b/rlax/_src/multistep.py
@@ -36,7 +36,7 @@ def lambda_returns(
     discount_t: Array,
     v_t: Array,
     lambda_: Numeric = 1.,
-    stop_target_gradients: bool = False,
+    stop_target_gradients: bool = True,
 ) -> Array:
   """Estimates a multistep truncated lambda return from a trajectory.
 
@@ -124,7 +124,7 @@ def n_step_bootstrapped_returns(
     v_t: Array,
     n: int,
     lambda_t: Numeric = 1.,
-    stop_target_gradients: bool = False,
+    stop_target_gradients: bool = True,
 ) -> Array:
   """Computes strided n-step bootstrapped return targets over a sequence.
 
@@ -182,7 +182,7 @@ def discounted_returns(
     r_t: Array,
     discount_t: Array,
     v_t: Array,
-    stop_target_gradients: bool = False,
+    stop_target_gradients: bool = True,
 ) -> Array:
   """Calculates a discounted return from a trajectory.
 
@@ -218,7 +218,7 @@ def importance_corrected_td_errors(
     rho_tm1: Array,
     lambda_: Array,
     values: Array,
-    stop_target_gradients: bool = False,
+    stop_target_gradients: bool = True,
 ) -> Array:
   """Computes the multistep td errors with per decision importance sampling.
 
@@ -281,7 +281,7 @@ def truncated_generalized_advantage_estimation(
     discount_t: Array,
     lambda_: Union[Array, Scalar],
     values: Array,
-    stop_target_gradients: bool = False,
+    stop_target_gradients: bool = True,
 ) -> Array:
   """Computes truncated generalized advantage estimates for a sequence length k.
 
@@ -334,7 +334,7 @@ def general_off_policy_returns_from_action_values(
     discount_t: Array,
     c_t: Array,
     pi_t: Array,
-    stop_target_gradients: bool = False,
+    stop_target_gradients: bool = True,
 ) -> Array:
   """Calculates targets for various off-policy correction algorithms.
 
@@ -392,7 +392,7 @@ def general_off_policy_returns_from_q_and_v(
     r_t: Array,
     discount_t: Array,
     c_t: Array,
-    stop_target_gradients: bool = False,
+    stop_target_gradients: bool = True,
 ) -> Array:
   """Calculates targets for various off-policy evaluation algorithms.
 
diff --git a/rlax/_src/multistep_test.py b/rlax/_src/multistep_test.py
index 006cedd..af968ca 100644
--- a/rlax/_src/multistep_test.py
+++ b/rlax/_src/multistep_test.py
@@ -253,6 +253,108 @@ def test_gae_as_special_case_of_importance_corrected_td_errors(self, lambda_):
         self.v_t)
     np.testing.assert_allclose(gae_result, ictd_errors_result, atol=1e-3)
 
+
+class StopTargetGradientsDefaultTest(absltest.TestCase):
+  """Regression tests for stop_target_gradients=True default (GitHub #28).
+
+  In standard RL, bootstrap targets should not propagate gradients back into
+  the value network — otherwise the training objective becomes a moving target
+  and training can diverge. All multistep return / advantage functions now
+  default stop_target_gradients=True, matching vtrace's long-standing default.
+
+  These tests verify:
+  1. Default behavior (True): gradients through the output are zero w.r.t. v_t.
+  2. Explicit False: gradients do flow (needed for meta-gradient methods).
+  3. Forward values are identical regardless of stop_target_gradients.
+  """
+
+  def _make_inputs(self):
+    r_t = jnp.array([1.0, 0.0, -1.0, 0.5])
+    discount_t = jnp.array([0.9, 0.8, 1.0, 0.9])
+    v_t = jnp.array([1.0, 2.0, 1.5, 0.5])
+    return r_t, discount_t, v_t
+
+  # ── lambda_returns ────────────────────────────────────────────────────────
+
+  def test_lambda_returns_default_stops_gradients(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    def fn(v):
+      return multistep.lambda_returns(r_t, discount_t, v).sum()
+    grad = jax.grad(fn)(v_t)
+    np.testing.assert_array_equal(grad, jnp.zeros_like(v_t))
+
+  def test_lambda_returns_explicit_false_passes_gradients(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    def fn(v):
+      return multistep.lambda_returns(
+          r_t, discount_t, v, stop_target_gradients=False).sum()
+    grad = jax.grad(fn)(v_t)
+    self.assertFalse(jnp.all(grad == 0))
+
+  def test_lambda_returns_forward_values_unchanged(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    out_true = multistep.lambda_returns(r_t, discount_t, v_t,
+                                        stop_target_gradients=True)
+    out_false = multistep.lambda_returns(r_t, discount_t, v_t,
+                                         stop_target_gradients=False)
+    np.testing.assert_allclose(out_true, out_false, atol=1e-6)
+
+  # ── truncated_generalized_advantage_estimation ────────────────────────────
+
+  def test_gae_default_stops_gradients(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    values = jnp.concatenate([v_t, jnp.array([0.0])])
+    def fn(v):
+      return multistep.truncated_generalized_advantage_estimation(
+          r_t, discount_t, 0.95, v).sum()
+    grad = jax.grad(fn)(values)
+    np.testing.assert_array_equal(grad, jnp.zeros_like(values))
+
+  def test_gae_explicit_false_passes_gradients(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    values = jnp.concatenate([v_t, jnp.array([0.0])])
+    def fn(v):
+      return multistep.truncated_generalized_advantage_estimation(
+          r_t, discount_t, 0.95, v, stop_target_gradients=False).sum()
+    grad = jax.grad(fn)(values)
+    self.assertFalse(jnp.all(grad == 0))
+
+  def test_gae_forward_values_unchanged(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    values = jnp.concatenate([v_t, jnp.array([0.0])])
+    out_true = multistep.truncated_generalized_advantage_estimation(
+        r_t, discount_t, 0.95, values, stop_target_gradients=True)
+    out_false = multistep.truncated_generalized_advantage_estimation(
+        r_t, discount_t, 0.95, values, stop_target_gradients=False)
+    np.testing.assert_allclose(out_true, out_false, atol=1e-6)
+
+  # ── n_step_bootstrapped_returns ────────────────────────────────────────────
+
+  def test_n_step_returns_default_stops_gradients(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    def fn(v):
+      return multistep.n_step_bootstrapped_returns(
+          r_t, discount_t, v, n=2).sum()
+    grad = jax.grad(fn)(v_t)
+    np.testing.assert_array_equal(grad, jnp.zeros_like(v_t))
+
+  def test_n_step_returns_explicit_false_passes_gradients(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    def fn(v):
+      return multistep.n_step_bootstrapped_returns(
+          r_t, discount_t, v, n=2, stop_target_gradients=False).sum()
+    grad = jax.grad(fn)(v_t)
+    self.assertFalse(jnp.all(grad == 0))
+
+  def test_n_step_returns_forward_values_unchanged(self):
+    r_t, discount_t, v_t = self._make_inputs()
+    out_true = multistep.n_step_bootstrapped_returns(
+        r_t, discount_t, v_t, n=2, stop_target_gradients=True)
+    out_false = multistep.n_step_bootstrapped_returns(
+        r_t, discount_t, v_t, n=2, stop_target_gradients=False)
+    np.testing.assert_allclose(out_true, out_false, atol=1e-6)
+
+
 if __name__ == '__main__':
   jax.config.update('jax_numpy_rank_promotion', 'raise')
   absltest.main()