diff --git a/rlax/_src/multistep.py b/rlax/_src/multistep.py index 407f591..9ed6d26 100644 --- a/rlax/_src/multistep.py +++ b/rlax/_src/multistep.py @@ -36,7 +36,7 @@ def lambda_returns( discount_t: Array, v_t: Array, lambda_: Numeric = 1., - stop_target_gradients: bool = False, + stop_target_gradients: bool = True, ) -> Array: """Estimates a multistep truncated lambda return from a trajectory. @@ -124,7 +124,7 @@ def n_step_bootstrapped_returns( v_t: Array, n: int, lambda_t: Numeric = 1., - stop_target_gradients: bool = False, + stop_target_gradients: bool = True, ) -> Array: """Computes strided n-step bootstrapped return targets over a sequence. @@ -182,7 +182,7 @@ def discounted_returns( r_t: Array, discount_t: Array, v_t: Array, - stop_target_gradients: bool = False, + stop_target_gradients: bool = True, ) -> Array: """Calculates a discounted return from a trajectory. @@ -218,7 +218,7 @@ def importance_corrected_td_errors( rho_tm1: Array, lambda_: Array, values: Array, - stop_target_gradients: bool = False, + stop_target_gradients: bool = True, ) -> Array: """Computes the multistep td errors with per decision importance sampling. @@ -281,7 +281,7 @@ def truncated_generalized_advantage_estimation( discount_t: Array, lambda_: Union[Array, Scalar], values: Array, - stop_target_gradients: bool = False, + stop_target_gradients: bool = True, ) -> Array: """Computes truncated generalized advantage estimates for a sequence length k. @@ -334,7 +334,7 @@ def general_off_policy_returns_from_action_values( discount_t: Array, c_t: Array, pi_t: Array, - stop_target_gradients: bool = False, + stop_target_gradients: bool = True, ) -> Array: """Calculates targets for various off-policy correction algorithms. @@ -392,7 +392,7 @@ def general_off_policy_returns_from_q_and_v( r_t: Array, discount_t: Array, c_t: Array, - stop_target_gradients: bool = False, + stop_target_gradients: bool = True, ) -> Array: """Calculates targets for various off-policy evaluation algorithms. diff --git a/rlax/_src/multistep_test.py b/rlax/_src/multistep_test.py index 006cedd..af968ca 100644 --- a/rlax/_src/multistep_test.py +++ b/rlax/_src/multistep_test.py @@ -253,6 +253,98 @@ def test_gae_as_special_case_of_importance_corrected_td_errors(self, lambda_): self.v_t) np.testing.assert_allclose(gae_result, ictd_errors_result, atol=1e-3) +class StopTargetGradientsDefaultTest(absltest.TestCase): + """Regression tests for stop_target_gradients=True default (GitHub #28). + + In standard RL, bootstrap targets should not propagate gradients back into + the value network — otherwise the training objective becomes a moving target + and training can diverge. All multistep return / advantage functions now + default stop_target_gradients=True, matching vtrace's long-standing default. + + These tests verify: + 1. Default behavior (True): gradients through the output are zero w.r.t. v_t. + 2. Explicit False: gradients do flow (needed for meta-gradient methods). + 3. Forward values are identical regardless of stop_target_gradients. + """ + + def _make_inputs(self): + r_t = jnp.array([1.0, 0.0, -1.0, 0.5]) + discount_t = jnp.array([0.9, 0.8, 1.0, 0.9]) + v_t = jnp.array([1.0, 2.0, 1.5, 0.5]) + return r_t, discount_t, v_t + + # ── lambda_returns ──────────────────────────────────────────────────────── + + def test_lambda_returns_default_stops_gradients(self): + r_t, discount_t, v_t = self._make_inputs() + def fn(v): return multistep.lambda_returns(r_t, discount_t, v).sum() + grad = jax.grad(fn)(v_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(v_t)) + + def test_lambda_returns_explicit_false_passes_gradients(self): + r_t, discount_t, v_t = self._make_inputs() + def fn(v): + return multistep.lambda_returns( + r_t, discount_t, v, stop_target_gradients=False).sum() + grad = jax.grad(fn)(v_t) + self.assertFalse(jnp.all(grad == 0)) + + def test_lambda_returns_forward_values_unchanged(self): + r_t, discount_t, v_t = self._make_inputs() + out_true = multistep.lambda_returns(r_t, discount_t, v_t, + stop_target_gradients=True) + out_false = multistep.lambda_returns(r_t, discount_t, v_t, + stop_target_gradients=False) + np.testing.assert_allclose(out_true, out_false, atol=1e-6) + + # ── truncated_generalized_advantage_estimation ──────────────────────────── + + def test_gae_default_stops_gradients(self): + r_t, discount_t, v_t = self._make_inputs() + values = jnp.concatenate([v_t, jnp.array([0.0])]) + def fn(v): + return multistep.truncated_generalized_advantage_estimation( + r_t, discount_t, 0.95, v).sum() + grad = jax.grad(fn)(values) + np.testing.assert_array_equal(grad, jnp.zeros_like(values)) + + def test_gae_explicit_false_passes_gradients(self): + r_t, discount_t, v_t = self._make_inputs() + values = jnp.concatenate([v_t, jnp.array([0.0])]) + def fn(v): + return multistep.truncated_generalized_advantage_estimation( + r_t, discount_t, 0.95, v, stop_target_gradients=False).sum() + grad = jax.grad(fn)(values) + self.assertFalse(jnp.all(grad == 0)) + + def test_gae_forward_values_unchanged(self): + r_t, discount_t, v_t = self._make_inputs() + values = jnp.concatenate([v_t, jnp.array([0.0])]) + out_true = multistep.truncated_generalized_advantage_estimation( + r_t, discount_t, 0.95, values, stop_target_gradients=True) + out_false = multistep.truncated_generalized_advantage_estimation( + r_t, discount_t, 0.95, values, stop_target_gradients=False) + np.testing.assert_allclose(out_true, out_false, atol=1e-6) + + # ── n_step_bootstrapped_returns ──────────────────────────────────────────── + + def test_n_step_returns_default_stops_gradients(self): + r_t, discount_t, v_t = self._make_inputs() + def fn(v): + return multistep.n_step_bootstrapped_returns( + r_t, discount_t, v, n=2).sum() + grad = jax.grad(fn)(v_t) + np.testing.assert_array_equal(grad, jnp.zeros_like(v_t)) + + def test_n_step_returns_forward_values_unchanged(self): + r_t, discount_t, v_t = self._make_inputs() + out_true = multistep.n_step_bootstrapped_returns( + r_t, discount_t, v_t, n=2, stop_target_gradients=True) + out_false = multistep.n_step_bootstrapped_returns( + r_t, discount_t, v_t, n=2, stop_target_gradients=False) + np.testing.assert_allclose(out_true, out_false, atol=1e-6) + + if __name__ == '__main__': jax.config.update('jax_numpy_rank_promotion', 'raise') absltest.main()