From 457da00863158c409b4c342825326cafa8b69738 Mon Sep 17 00:00:00 2001 From: Surya Bhupatiraju Date: Thu, 8 Dec 2022 13:56:55 -0800 Subject: [PATCH] rlax: Upstream Muesli utilities to rlax. We now provide methods for constructing the clipped MPO (CMPO) policy targets used as part of the Muesli agent loss. These CMPO targets are in expectation proportional to: `prior(a|s) * exp(clip(norm(Q(s, a))))` where the prior is computed by the actor policy head, and the Q values are computed using the learned model's reward and value heads. See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al. (https://arxiv.org/pdf/2104.06159.pdf) for more details. PiperOrigin-RevId: 493987878 --- docs/api.rst | 14 +++ rlax/__init__.py | 4 + rlax/_src/policy_targets.py | 190 +++++++++++++++++++++++++++++++++++- 3 files changed, 207 insertions(+), 1 deletion(-) diff --git a/docs/api.rst b/docs/api.rst index 662c9c8..54d2122 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -229,6 +229,7 @@ Policy Optimization .. autosummary:: clipped_surrogate_pg_loss + cmpo_policy_targets constant_policy_targets dpg_loss entropy_loss @@ -238,6 +239,7 @@ Policy Optimization qpg_loss rm_loss rpg_loss + sampled_cmpo_policy_targets sampled_policy_distillation_loss zero_policy_targets @@ -247,6 +249,18 @@ Clipped Surrogate PG Loss .. autofunction:: clipped_surrogate_pg_loss +CMPO Policy Targets +~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: cmpo_policy_targets + + +Sampled CMPO Policy Targets +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autofunction:: sampled_cmpo_policy_targets + + Compute Parametric KL Penalty and Dual Loss ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/rlax/__init__.py b/rlax/__init__.py index 5a69b60..57af086 100644 --- a/rlax/__init__.py +++ b/rlax/__init__.py @@ -88,8 +88,10 @@ from rlax._src.policy_gradients import qpg_loss from rlax._src.policy_gradients import rm_loss from rlax._src.policy_gradients import rpg_loss +from rlax._src.policy_targets import cmpo_policy_targets from rlax._src.policy_targets import constant_policy_targets from rlax._src.policy_targets import PolicyTarget +from rlax._src.policy_targets import sampled_cmpo_policy_targets from rlax._src.policy_targets import sampled_policy_distillation_loss from rlax._src.policy_targets import zero_policy_targets from rlax._src.pop_art import art @@ -159,6 +161,7 @@ "categorical_td_learning", "clip_gradient", "clipped_surrogate_pg_loss", + "cmpo_policy_targets", "compose_tx", "conditional_update", "constant_policy_targets", @@ -230,6 +233,7 @@ "rpg_loss", "sample_start_indices", "sampled_policy_distillation_loss", + "sampled_cmpo_policy_targets", "sarsa", "sarsa_lambda", "sigmoid", diff --git a/rlax/_src/policy_targets.py b/rlax/_src/policy_targets.py index 13736d1..05283df 100644 --- a/rlax/_src/policy_targets.py +++ b/rlax/_src/policy_targets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Utilities to construct and learn from policy targets.""" +"""Construct and learn from policy targets. Used by Muesli-based agents.""" import functools @@ -20,6 +20,7 @@ import distrax import jax import jax.numpy as jnp +from rlax._src import base @chex.dataclass(frozen=True) @@ -106,3 +107,190 @@ def sampled_policy_distillation_loss( # We average over the samples, over time and batch, and if the actions are # a continuous vector also over the actions. return -jnp.mean(weights * jnp.maximum(log_probs, min_logp)) + + +def cmpo_policy_targets( + prior_distribution, + embeddings, + rng_key, + baseline_value, + q_provider, + advantage_normalizer, + *, + num_actions, + min_target_advantage=-jnp.inf, + max_target_advantage=1.0, + kl_weight=1.0, +) -> PolicyTarget: + """Policy targets for Clipped MPO. + + The policy targets are in-expectation proportional to: + `prior(a|s) * exp(clip(norm(Q(s, a))))` + + See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al. + (https://arxiv.org/pdf/2104.06159.pdf). + + Args: + prior_distribution: the prior policy distribution. + embeddings: embeddings for the `q_provider`. + rng_key: a JAX pseudo random number generator key. + baseline_value: the baseline for `advantage_normalizer`. + q_provider: a fn to compute q values. + advantage_normalizer: a fn to normalise advantages. + *, + num_actions: The total number of discrete actions. + min_target_advantage: The minimum advantage of a policy target. + max_target_advantage: The max advantage of a policy target. + kl_weight: The coefficient for the KL regularizer. + + Returns: + the clipped MPO policy targets. + """ + # Expecting shape [B]. + chex.assert_rank(baseline_value, 1) + rng_key, query_rng_key = jax.random.split(rng_key) + del rng_key + + # Producing all actions with shape [num_actions, B]. + batch_size, = baseline_value.shape + actions = jnp.broadcast_to( + jnp.expand_dims(jnp.arange(num_actions, dtype=jnp.int32), axis=-1), + [num_actions, batch_size]) + + # Using vmap over the num_actions in axis=0. + def _query_q(actions): + return q_provider( + # Using the same rng_key for the all actions samples. + rng_key=query_rng_key, + action=actions, + embeddings=embeddings) + qvalues = jax.vmap(_query_q)(actions) + + # Using the same advantage normalization as for policy gradients. + raw_advantage = advantage_normalizer( + returns=qvalues, baseline_value=baseline_value) + clipped_advantage = jnp.clip( + raw_advantage, min_target_advantage, + max_target_advantage) + + # Construct and normalise the weights. + log_prior = prior_distribution.log_prob(actions) + weights = softmax_policy_target_normalizer( + log_prior + clipped_advantage / kl_weight) + policy_targets = PolicyTarget(actions=actions, weights=weights) + return policy_targets + + +def sampled_cmpo_policy_targets( + prior_distribution, + embeddings, + rng_key, + baseline_value, + q_provider, + advantage_normalizer, + *, + num_actions=2, + min_target_advantage=-jnp.inf, + max_target_advantage=1.0, + kl_weight=1.0, +) -> PolicyTarget: + """Policy targets for sampled CMPO. + + As in CMPO the policy targets are in-expectation proportional to: + `prior(a|s) * exp(clip(norm(Q(s, a))))` + However we only sample a subset of the actions, this allows to scale to + large discrete action spaces and to continuous actions. + + See "Muesli: Combining Improvements in Policy Optimization" by Hessel et al. + (https://arxiv.org/pdf/2104.06159.pdf). + + Args: + prior_distribution: the prior policy distribution. + embeddings: embeddings for the `q_provider`. + rng_key: a JAX pseudo random number generator key. + baseline_value: the baseline for `advantage_normalizer`. + q_provider: a fn to compute q values. + advantage_normalizer: a fn to normalise advantages. + *, + num_actions: The number of actions to expand on each step. + min_target_advantage: The minimum advantage of a policy target. + max_target_advantage: The max advantage of a policy target. + kl_weight: The coefficient for the KL regularizer. + + Returns: + the sampled clipped MPO policy targets. + """ + # Expecting shape [B]. + chex.assert_rank(baseline_value, 1) + query_rng_key, action_key = jax.random.split(rng_key) + del rng_key + + # Sampling the actions from the prior. + actions = prior_distribution.sample( + seed=action_key, sample_shape=[num_actions]) + + # Using vmap over the num_expanded in axis=0. + def _query_q(actions): + return q_provider( + # Using the same rng_key for the all actions samples. + rng_key=query_rng_key, + action=actions, + embeddings=embeddings) + qvalues = jax.vmap(_query_q)(actions) + + # Using the same advantage normalization as for policy gradients. + raw_advantage = advantage_normalizer( + returns=qvalues, baseline_value=baseline_value) + clipped_advantage = jnp.clip( + raw_advantage, min_target_advantage, max_target_advantage) + + # The expected normalized weight would be 1.0. The weights would be + # normalized, if the baseline_value is the log of the expected weight. I.e., + # if the baseline_value is log(sum_a(prior(a|s) * exp(Q(s, a)/c))). + weights = jnp.exp(clipped_advantage / kl_weight) + + # The weights are tiled, if using multiple continuous actions. + # It is OK to use multiple continuous actions inside the Q(s, a), + # because the action is sampled from the joint distribution + # and weight is not based on non-joint probabilities. + log_prob = prior_distribution.log_prob(actions) + weights = jnp.broadcast_to( + base.lhs_broadcast(weights, log_prob), log_prob.shape) + return PolicyTarget(actions=actions, weights=weights) + + +def softmax_policy_target_normalizer(log_weights): + """Returns self-normalized weights. + + The self-normalizing weights introduce a significant bias, + if computing the average weight from a small number of samples. + + Args: + log_weights: log unnormalized weights, shape `[num_targets, ...]`. + + Returns: + Weights divided by average weight from sample. Weights sum to `num_targets`. + """ + num_targets = log_weights.shape[0] + return num_targets * jax.nn.softmax(log_weights, axis=0) + + +def loo_policy_target_normalizer(log_weights): + """A leave-one-out normalizer. + + Args: + log_weights: log unnormalized weights, shape `[num_targets, ...]`. + + Returns: + Weights divided by a consistent estimate of the average weight. The weights + are not guaranteed to sum to `num_targets`. + """ + num_targets = log_weights.shape[0] + weights = jnp.exp(log_weights) + # Using a safe consistent estimator of the average weight, independently of + # the numerator. + # The unnormalized weight are already approximately normalized by a + # baseline_value, so we use `1` as the initial estimate of the average weight. + avg_weight = ( + 1 + jnp.sum(weights, axis=0, keepdims=True) - weights) / num_targets + return weights / avg_weight