diff --git a/rlax/_src/value_learning.py b/rlax/_src/value_learning.py index 5f02392..d1bdfcf 100644 --- a/rlax/_src/value_learning.py +++ b/rlax/_src/value_learning.py @@ -796,7 +796,9 @@ def quantile_regression_loss( chex.assert_type([dist_src, tau_src, dist_target], float) # Calculate quantile error. - delta = dist_target[None, :] - dist_src[:, None] + target = jax.lax.select(stop_target_gradients, + jax.lax.stop_gradient(dist_target), dist_target) + delta = target[None, :] - dist_src[:, None] delta_neg = (delta < 0.).astype(jnp.float32) delta_neg = jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(delta_neg), delta_neg)