diff --git a/rlax/_src/multistep.py b/rlax/_src/multistep.py index 407f591..5c95003 100644 --- a/rlax/_src/multistep.py +++ b/rlax/_src/multistep.py @@ -37,6 +37,7 @@ def lambda_returns( v_t: Array, lambda_: Numeric = 1., stop_target_gradients: bool = False, + unroll: int | bool = 1, ) -> Array: """Estimates a multistep truncated lambda return from a trajectory. @@ -93,6 +94,7 @@ def lambda_returns( lambda_: mixing parameter; a scalar or a vector for timesteps t in [1, T]. stop_target_gradients: bool indicating whether or not to apply stop gradient to targets. + unroll: how many scan iterations to unroll. Returns: Multistep lambda returns. @@ -111,7 +113,12 @@ def _body(acc, xs): return acc, acc _, returns = jax.lax.scan( - _body, v_t[-1], (r_t, discount_t, v_t, lambda_), reverse=True) + _body, + v_t[-1], + (r_t, discount_t, v_t, lambda_), + reverse=True, + unroll=unroll, + ) return jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(returns), @@ -219,6 +226,7 @@ def importance_corrected_td_errors( lambda_: Array, values: Array, stop_target_gradients: bool = False, + unroll: int | bool = 1, ) -> Array: """Computes the multistep td errors with per decision importance sampling. @@ -246,6 +254,7 @@ def importance_corrected_td_errors( values: sequence of state values under π for all timesteps t in [0, T]. stop_target_gradients: bool indicating whether or not to apply stop gradient to targets. + unroll: how many scan iterations to unroll. Returns: Off-policy estimates of the multistep td errors. @@ -269,7 +278,12 @@ def _body(acc, xs): return acc, acc _, errors = jax.lax.scan( - _body, 0.0, (one_step_delta, discount_t, rho_t, lambda_), reverse=True) + _body, + 0.0, + (one_step_delta, discount_t, rho_t, lambda_), + reverse=True, + unroll=unroll, + ) errors = rho_tm1 * errors return jax.lax.select(stop_target_gradients, @@ -282,6 +296,7 @@ def truncated_generalized_advantage_estimation( lambda_: Union[Array, Scalar], values: Array, stop_target_gradients: bool = False, + unroll: int | bool = 1, ) -> Array: """Computes truncated generalized advantage estimates for a sequence length k. @@ -303,6 +318,7 @@ def truncated_generalized_advantage_estimation( values: Sequence of values under π at times [0, k] stop_target_gradients: bool indicating whether or not to apply stop gradient to targets. + unroll: how many scan iterations to unroll. Returns: Multistep truncated generalized advantage estimation at times [0, k-1]. @@ -320,7 +336,12 @@ def _body(acc, xs): return acc, acc _, advantage_t = jax.lax.scan( - _body, 0.0, (delta_t, discount_t, lambda_), reverse=True) + _body, + 0.0, + (delta_t, discount_t, lambda_), + reverse=True, + unroll=unroll, + ) return jax.lax.select(stop_target_gradients, jax.lax.stop_gradient(advantage_t), @@ -393,6 +414,7 @@ def general_off_policy_returns_from_q_and_v( discount_t: Array, c_t: Array, stop_target_gradients: bool = False, + unroll: int | bool = 1, ) -> Array: """Calculates targets for various off-policy evaluation algorithms. @@ -421,6 +443,7 @@ def general_off_policy_returns_from_q_and_v( c_t: weights at times [1, ..., K - 1]. stop_target_gradients: bool indicating whether or not to apply stop gradient to targets. + unroll: how many scan iterations to unroll. Returns: Off-policy estimates of the generalized returns from states visited at times @@ -438,7 +461,12 @@ def _body(acc, xs): return acc, acc _, returns = jax.lax.scan( - _body, g, (r_t[:-1], discount_t[:-1], c_t, v_t[:-1], q_t), reverse=True) + _body, + g, + (r_t[:-1], discount_t[:-1], c_t, v_t[:-1], q_t), + reverse=True, + unroll=unroll, + ) returns = jnp.concatenate([returns, g[jnp.newaxis]], axis=0) return jax.lax.select(stop_target_gradients, diff --git a/rlax/_src/vtrace.py b/rlax/_src/vtrace.py index 99e3a66..158b98f 100644 --- a/rlax/_src/vtrace.py +++ b/rlax/_src/vtrace.py @@ -44,6 +44,7 @@ def vtrace( lambda_: Numeric = 1.0, clip_rho_threshold: float = 1.0, stop_target_gradients: bool = True, + unroll: int | bool = 1, ) -> Array: """Calculates V-Trace errors from importance weights. @@ -62,6 +63,7 @@ def vtrace( lambda_: mixing parameter; a scalar or a vector for timesteps t. clip_rho_threshold: clip threshold for importance weights. stop_target_gradients: whether or not to apply stop gradient to targets. + unroll: how many scan iterations to unroll. Returns: V-Trace error. @@ -86,7 +88,12 @@ def _body(acc, xs): return acc, acc _, errors = jax.lax.scan( - _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True) + _body, + 0.0, + (td_errors, discount_t, c_tm1), + reverse=True, + unroll=unroll, + ) # Return errors, maybe disabling gradient flow through bootstrap targets. return jax.lax.select( @@ -104,7 +111,9 @@ def leaky_vtrace( alpha_: float = 1.0, lambda_: Numeric = 1.0, clip_rho_threshold: float = 1.0, - stop_target_gradients: bool = True): + stop_target_gradients: bool = True, + unroll: int | bool = 1, +): """Calculates Leaky V-Trace errors from importance weights. Leaky-Vtrace is a combination of Importance sampling and V-trace, where the @@ -123,6 +132,7 @@ def leaky_vtrace( lambda_: mixing parameter; a scalar or a vector for timesteps t. clip_rho_threshold: clip threshold for importance weights. stop_target_gradients: whether or not to apply stop gradient to targets. + unroll: how many scan iterations to unroll. Returns: Leaky V-Trace error. @@ -150,7 +160,12 @@ def _body(acc, xs): return acc, acc _, errors = jax.lax.scan( - _body, 0.0, (td_errors, discount_t, c_tm1), reverse=True) + _body, + 0.0, + (td_errors, discount_t, c_tm1), + reverse=True, + unroll=unroll, + ) # Return errors, maybe disabling gradient flow through bootstrap targets. return jax.lax.select(