diff --git a/rlax/__init__.py b/rlax/__init__.py
index 63f986b..053ad9e 100644
--- a/rlax/__init__.py
+++ b/rlax/__init__.py
@@ -101,6 +101,7 @@
 from rlax._src.value_learning import categorical_td_learning
 from rlax._src.value_learning import double_q_learning
 from rlax._src.value_learning import expected_sarsa
+from rlax._src.value_learning import expectile_naive_q_learning
 from rlax._src.value_learning import persistent_q_learning
 from rlax._src.value_learning import q_lambda
 from rlax._src.value_learning import q_learning
@@ -143,6 +144,7 @@
     "epsilon_greedy",
     "epsilon_softmax",
     "expected_sarsa",
+    "expectile_naive_q_learning",
     "feature_control_rewards",
     "gaussian_diagonal",
     "HYPERBOLIC_SIN_PAIR",
diff --git a/rlax/_src/value_learning.py b/rlax/_src/value_learning.py
index 3ea8a88..c93abca 100644
--- a/rlax/_src/value_learning.py
+++ b/rlax/_src/value_learning.py
@@ -852,3 +852,94 @@ def quantile_expected_sarsa(
 
   return _quantile_regression_loss(
       dist_qa_tm1, tau_q_tm1, dist_target, huber_param)
+
+
+def _expectile_naive_regression_loss(
+    dist_src: Array,
+    tau_src: Array,
+    dist_target: Array
+) -> Numeric:
+  """Computes the naive expectile regression (ER-naive) loss.
+
+  Each entry of `dist_src` is compared with each entry of `dist_target`. The
+  squared difference is weighted by `|tau_src - 1|` when the target entry is
+  below the source entry and by `|tau_src|` otherwise.
+
+  See "Statistics and Samples in Distributional Reinforcement Learning" by
+  Rowland et al. (http://proceedings.mlr.press/v97/rowland19a).
+
+  Args:
+    dist_src: rank-1 array of source expectile estimates.
+    tau_src: rank-1 array of expectile levels, aligned with `dist_src`.
+    dist_target: rank-1 array of target values. Gradients are not stopped
+      inside this function.
+
+  Returns:
+    Scalar expectile regression (naive) loss.
+  """
+  chex.assert_rank([dist_src, tau_src, dist_target], 1)
+  chex.assert_type([dist_src, tau_src, dist_target], float)
+
+  # Calculate expectile error: rows index source, columns index target.
+  delta = dist_target[None, :] - dist_src[:, None]
+  delta_neg = (delta < 0.).astype(jnp.float32)
+  delta_neg = jax.lax.stop_gradient(delta_neg)
+  weight = jnp.abs(tau_src[:, None] - delta_neg)
+
+  # Calculate expectile regression (naive) loss.
+  loss = jnp.square(delta)
+  loss *= weight
+
+  # Average over target-samples dimension, sum over src-samples dimension.
+  return jnp.sum(jnp.mean(loss, axis=-1))
+
+
+def expectile_naive_q_learning(
+    dist_q_tm1: Array,
+    tau_q_tm1: Array,
+    a_tm1: Numeric,
+    r_t: Numeric,
+    discount_t: Numeric,
+    dist_q_t_selector: Array,
+    dist_q_t: Array,
+) -> Numeric:
+  """Implements Q-learning for expectile-valued Q distributions.
+
+  See "Statistics and Samples in Distributional Reinforcement Learning" by
+  Rowland et al. (http://proceedings.mlr.press/v97/rowland19a).
+
+  Args:
+    dist_q_tm1: Q distribution at time t-1, indexed as [sample, action].
+    tau_q_tm1: expectile levels of the Q distribution at time t-1.
+    a_tm1: action index at time t-1.
+    r_t: reward at time t.
+    discount_t: discount at time t.
+    dist_q_t_selector: Q distribution at time t for selecting greedy action in
+      target policy. This is separate from dist_q_t as in Double Q-Learning, but
+      can be computed with the target network and a separate set of samples.
+    dist_q_t: target Q distribution at time t.
+
+  Returns:
+    Expectile regression (naive) Q learning loss.
+  """
+  chex.assert_rank([
+      dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector, dist_q_t
+  ], [2, 1, 0, 0, 0, 2, 2])
+  chex.assert_type([
+      dist_q_tm1, tau_q_tm1, a_tm1, r_t, discount_t, dist_q_t_selector, dist_q_t
+  ], [float, float, int, float, float, float, float])
+
+  # Only update the taken actions.
+  dist_qa_tm1 = dist_q_tm1[:, a_tm1]
+
+  # Select target action according to greedy policy w.r.t. dist_q_t_selector.
+  q_t_selector = jnp.mean(dist_q_t_selector, axis=0)
+  a_t = jnp.argmax(q_t_selector)
+  dist_qa_t = dist_q_t[:, a_t]
+
+  # Compute target, do not backpropagate into it.
+  dist_target = r_t + discount_t * dist_qa_t
+  dist_target = jax.lax.stop_gradient(dist_target)
+
+  return _expectile_naive_regression_loss(
+      dist_qa_tm1, tau_q_tm1, dist_target)