From 3220e2a38d0fefc78d9a9796c8a803f2d10ae150 Mon Sep 17 00:00:00 2001 From: RLaxDev Date: Fri, 17 Jan 2025 08:10:38 -0800 Subject: [PATCH] Stop target gradients in quantile_regression_loss. PiperOrigin-RevId: 716674227 --- rlax/_src/value_learning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/rlax/_src/value_learning.py b/rlax/_src/value_learning.py index 5f02392..d1bdfcf 100644 --- a/rlax/_src/value_learning.py +++ b/rlax/_src/value_learning.py @@ -796,7 +796,9 @@ def quantile_regression_loss( chex.assert_type([dist_src, tau_src, dist_target], float) # Calculate quantile error. - delta = dist_target[None, :] - dist_src[:, None] + target = jax.lax.select(stop_target_gradients, + jax.lax.stop_gradient(dist_target), dist_target) + delta = target[None, :] - dist_src[:, None] delta_neg = (delta < 0.).astype(jnp.float32) delta_neg = jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(delta_neg), delta_neg)