Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions rlax/_src/multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def lambda_returns(
discount_t: Array,
v_t: Array,
lambda_: Numeric = 1.,
stop_target_gradients: bool = False,
stop_target_gradients: bool = True,
) -> Array:
"""Estimates a multistep truncated lambda return from a trajectory.

Expand Down Expand Up @@ -124,7 +124,7 @@ def n_step_bootstrapped_returns(
v_t: Array,
n: int,
lambda_t: Numeric = 1.,
stop_target_gradients: bool = False,
stop_target_gradients: bool = True,
) -> Array:
"""Computes strided n-step bootstrapped return targets over a sequence.

Expand Down Expand Up @@ -182,7 +182,7 @@ def discounted_returns(
r_t: Array,
discount_t: Array,
v_t: Array,
stop_target_gradients: bool = False,
stop_target_gradients: bool = True,
) -> Array:
"""Calculates a discounted return from a trajectory.

Expand Down Expand Up @@ -218,7 +218,7 @@ def importance_corrected_td_errors(
rho_tm1: Array,
lambda_: Array,
values: Array,
stop_target_gradients: bool = False,
stop_target_gradients: bool = True,
) -> Array:
"""Computes the multistep td errors with per decision importance sampling.

Expand Down Expand Up @@ -281,7 +281,7 @@ def truncated_generalized_advantage_estimation(
discount_t: Array,
lambda_: Union[Array, Scalar],
values: Array,
stop_target_gradients: bool = False,
stop_target_gradients: bool = True,
) -> Array:
"""Computes truncated generalized advantage estimates for a sequence length k.

Expand Down Expand Up @@ -334,7 +334,7 @@ def general_off_policy_returns_from_action_values(
discount_t: Array,
c_t: Array,
pi_t: Array,
stop_target_gradients: bool = False,
stop_target_gradients: bool = True,
) -> Array:
"""Calculates targets for various off-policy correction algorithms.

Expand Down Expand Up @@ -392,7 +392,7 @@ def general_off_policy_returns_from_q_and_v(
r_t: Array,
discount_t: Array,
c_t: Array,
stop_target_gradients: bool = False,
stop_target_gradients: bool = True,
) -> Array:
"""Calculates targets for various off-policy evaluation algorithms.

Expand Down
92 changes: 92 additions & 0 deletions rlax/_src/multistep_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,98 @@ def test_gae_as_special_case_of_importance_corrected_td_errors(self, lambda_):
self.v_t)
np.testing.assert_allclose(gae_result, ictd_errors_result, atol=1e-3)

class StopTargetGradientsDefaultTest(absltest.TestCase):
"""Regression tests for stop_target_gradients=True default (GitHub #28).

In standard RL, bootstrap targets should not propagate gradients back into
the value network — otherwise the training objective becomes a moving target
and training can diverge. All multistep return / advantage functions now
default stop_target_gradients=True, matching vtrace's long-standing default.

These tests verify:
1. Default behavior (True): gradients through the output are zero w.r.t. v_t.
2. Explicit False: gradients do flow (needed for meta-gradient methods).
3. Forward values are identical regardless of stop_target_gradients.
"""

def _make_inputs(self):
r_t = jnp.array([1.0, 0.0, -1.0, 0.5])
discount_t = jnp.array([0.9, 0.8, 1.0, 0.9])
v_t = jnp.array([1.0, 2.0, 1.5, 0.5])
return r_t, discount_t, v_t

# ── lambda_returns ────────────────────────────────────────────────────────

def test_lambda_returns_default_stops_gradients(self):
r_t, discount_t, v_t = self._make_inputs()
def fn(v): return multistep.lambda_returns(r_t, discount_t, v).sum()
grad = jax.grad(fn)(v_t)
np.testing.assert_array_equal(grad, jnp.zeros_like(v_t))

def test_lambda_returns_explicit_false_passes_gradients(self):
r_t, discount_t, v_t = self._make_inputs()
def fn(v):
return multistep.lambda_returns(
r_t, discount_t, v, stop_target_gradients=False).sum()
grad = jax.grad(fn)(v_t)
self.assertFalse(jnp.all(grad == 0))

def test_lambda_returns_forward_values_unchanged(self):
r_t, discount_t, v_t = self._make_inputs()
out_true = multistep.lambda_returns(r_t, discount_t, v_t,
stop_target_gradients=True)
out_false = multistep.lambda_returns(r_t, discount_t, v_t,
stop_target_gradients=False)
np.testing.assert_allclose(out_true, out_false, atol=1e-6)

# ── truncated_generalized_advantage_estimation ────────────────────────────

def test_gae_default_stops_gradients(self):
r_t, discount_t, v_t = self._make_inputs()
values = jnp.concatenate([v_t, jnp.array([0.0])])
def fn(v):
return multistep.truncated_generalized_advantage_estimation(
r_t, discount_t, 0.95, v).sum()
grad = jax.grad(fn)(values)
np.testing.assert_array_equal(grad, jnp.zeros_like(values))

def test_gae_explicit_false_passes_gradients(self):
r_t, discount_t, v_t = self._make_inputs()
values = jnp.concatenate([v_t, jnp.array([0.0])])
def fn(v):
return multistep.truncated_generalized_advantage_estimation(
r_t, discount_t, 0.95, v, stop_target_gradients=False).sum()
grad = jax.grad(fn)(values)
self.assertFalse(jnp.all(grad == 0))

def test_gae_forward_values_unchanged(self):
r_t, discount_t, v_t = self._make_inputs()
values = jnp.concatenate([v_t, jnp.array([0.0])])
out_true = multistep.truncated_generalized_advantage_estimation(
r_t, discount_t, 0.95, values, stop_target_gradients=True)
out_false = multistep.truncated_generalized_advantage_estimation(
r_t, discount_t, 0.95, values, stop_target_gradients=False)
np.testing.assert_allclose(out_true, out_false, atol=1e-6)

# ── n_step_bootstrapped_returns ────────────────────────────────────────────

def test_n_step_returns_default_stops_gradients(self):
r_t, discount_t, v_t = self._make_inputs()
def fn(v):
return multistep.n_step_bootstrapped_returns(
r_t, discount_t, v, n=2).sum()
grad = jax.grad(fn)(v_t)
np.testing.assert_array_equal(grad, jnp.zeros_like(v_t))

def test_n_step_returns_forward_values_unchanged(self):
r_t, discount_t, v_t = self._make_inputs()
out_true = multistep.n_step_bootstrapped_returns(
r_t, discount_t, v_t, n=2, stop_target_gradients=True)
out_false = multistep.n_step_bootstrapped_returns(
r_t, discount_t, v_t, n=2, stop_target_gradients=False)
np.testing.assert_allclose(out_true, out_false, atol=1e-6)


if __name__ == '__main__':
jax.config.update('jax_numpy_rank_promotion', 'raise')
absltest.main()