diff --git a/enn/losses/utils.py b/enn/losses/utils.py
index cf87c22..431b4ca 100644
--- a/enn/losses/utils.py
+++ b/enn/losses/utils.py
@@ -68,7 +68,20 @@ def noisy_loss(
 def add_l2_weight_decay(
     loss_fn: _LossFn,
     scale: Union[float, Callable[[hk.Params], hk.Params]],
-    predicate: Optional[PredicateFn] = None
+    predicate: Optional[PredicateFn] = None,
+    regularize_towards_random_weights: bool = False,
 ) -> _LossFn:
-  """Adds scale * l2 weight decay to an existing loss function."""
+  """Adds scale * l2 weight decay to an existing loss function.
+
+  Args:
+    loss_fn: the loss function to wrap.
+    scale: numeric weight-decay scale, or a callable applied to the params
+      before the l2 penalty is computed.
+    predicate: optional predicate forwarded to `l2_weights_with_predicate`.
+    regularize_towards_random_weights: if True, the decay is computed on
+      `params` minus a standard-normal random sample instead of on `params`.
+
+  Returns:
+    A loss function returning `loss + decay`, with both recorded in metrics.
+  """
   try:  # Scale is numeric.
@@ -79,14 +92,39 @@ def add_l2_weight_decay(
 
   def new_loss(
       enn: base.EpistemicNetwork[base.Input, base.Output],
-      params: hk.Params, state: hk.State, batch: base.Data,
-      key: chex.PRNGKey) -> base.LossOutput:
-    loss, (state, metrics) = loss_fn(enn, params, state, batch, key)
+      params: hk.Params,
+      state: hk.State,
+      batch: base.Data,
+      key: chex.PRNGKey,
+  ) -> base.LossOutput:
+    """Returns the wrapped loss plus l2 decay, recording both in metrics."""
+    if regularize_towards_random_weights:
+      loss_key, params_key = jax.random.split(key)
+    else:
+      # Pass `key` through unchanged so the default path keeps the random
+      # stream it had before this option existed.
+      loss_key = key
+    loss, (state, metrics) = loss_fn(enn, params, state, batch, loss_key)
+    if regularize_towards_random_weights:
+      # NOTE(review): the random target is re-drawn from `key` on every call;
+      # confirm that is intended rather than a fixed anchor.
+      leaves, treedef = jax.tree_util.tree_flatten(params)
+      # One key per leaf: a single shared key would give every leaf of the
+      # same shape an identical sample.
+      leaf_keys = jax.random.split(params_key, len(leaves))
+      # Sample in each leaf's own dtype (assumed floating point) so that
+      # low-precision params are not promoted by the subtraction.
+      shifted = [
+          p - jax.random.normal(k, p.shape, p.dtype)
+          for p, k in zip(leaves, leaf_keys)
+      ]
+      params = jax.tree_util.tree_unflatten(treedef, shifted)
     decay = l2_weights_with_predicate(scale_fn(params), predicate)
-    total_loss = loss + decay
+    total_loss = loss + decay
     metrics['decay'] = decay
     metrics['raw_loss'] = loss
     return total_loss, (state, metrics)
+
   return new_loss
 
 