From b1fe844a2cf5611dab21b8c916da6aea9214c679 Mon Sep 17 00:00:00 2001 From: Mehdi Jafarnia Date: Fri, 14 Jun 2024 17:24:04 -0700 Subject: [PATCH] Add capability of regularization towards random weights. PiperOrigin-RevId: 643493342 --- enn/losses/utils.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) diff --git a/enn/losses/utils.py b/enn/losses/utils.py index cf87c22..431b4ca 100644 --- a/enn/losses/utils.py +++ b/enn/losses/utils.py @@ -68,7 +68,8 @@ def noisy_loss( def add_l2_weight_decay( loss_fn: _LossFn, scale: Union[float, Callable[[hk.Params], hk.Params]], - predicate: Optional[PredicateFn] = None + predicate: Optional[PredicateFn] = None, + regularize_towards_random_weights: bool = False, ) -> _LossFn: """Adds scale * l2 weight decay to an existing loss function.""" try: # Scale is numeric. @@ -79,14 +80,26 @@ def add_l2_weight_decay( def new_loss( enn: base.EpistemicNetwork[base.Input, base.Output], - params: hk.Params, state: hk.State, batch: base.Data, - key: chex.PRNGKey) -> base.LossOutput: - loss, (state, metrics) = loss_fn(enn, params, state, batch, key) + params: hk.Params, + state: hk.State, + batch: base.Data, + key: chex.PRNGKey, + ) -> base.LossOutput: + loss_key, params_key = jax.random.split(key) + loss, (state, metrics) = loss_fn(enn, params, state, batch, loss_key) + if regularize_towards_random_weights: + random_params = jax.tree_util.tree_map( + lambda x: jax.random.normal(params_key, x.shape), params + ) + params = jax.tree_util.tree_map( + lambda p, rp: p - rp, params, random_params + ) decay = l2_weights_with_predicate(scale_fn(params), predicate) - total_loss = loss + decay + total_loss = loss + decay metrics['decay'] = decay metrics['raw_loss'] = loss return total_loss, (state, metrics) + return new_loss