diff --git a/enn/losses/base.py b/enn/losses/base.py index c027729..6934b99 100644 --- a/enn/losses/base.py +++ b/enn/losses/base.py @@ -83,7 +83,7 @@ def loss_fn(enn: base.EpistemicNetwork[base.Input, base.Output], # index. We choose to average the state across epistemic indices and # then perform basic error checking to make sure the shape is unchanged. new_state = jax.tree_map(batch_mean, new_state) - jax.tree_multimap( + jax.tree_map( lambda x, y: chex.assert_equal_shape([x, y]), new_state, state) mean_metrics = jax.tree_map(batch_mean, metrics) diff --git a/enn/metrics/calibration.py b/enn/metrics/calibration.py index 26b24fb..6d9b6e9 100644 --- a/enn/metrics/calibration.py +++ b/enn/metrics/calibration.py @@ -103,7 +103,7 @@ def __call__( state = self._get_init_stats() # Update state - new_stats = jax.tree_multimap(jnp.add, state.extra, batch_stats) + new_stats = jax.tree_map(jnp.add, state.extra, batch_stats) new_count = state.count + 1 new_value = _map_stats_to_ece(new_stats) return metrics_base.MetricsState( diff --git a/enn/networks/ensembles.py b/enn/networks/ensembles.py index 9f016b5..c8bd597 100644 --- a/enn/networks/ensembles.py +++ b/enn/networks/ensembles.py @@ -81,7 +81,7 @@ def apply(params: hk.Params, states: hk.State, inputs: chex.Array, sub_params = jax.tree_map(particle_selector, params) sub_states = jax.tree_map(particle_selector, states) out, new_sub_states = model.apply(sub_params, sub_states, inputs) - new_states = jax.tree_multimap( + new_states = jax.tree_map( lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states) return out, new_states diff --git a/enn/networks/hypermodels.py b/enn/networks/hypermodels.py index 3b30850..8fe408f 100644 --- a/enn/networks/hypermodels.py +++ b/enn/networks/hypermodels.py @@ -156,7 +156,7 @@ def hyper_fn(inputs: chex.Array, flat_output = jax.tree_map(lambda layer: layer(hyper_index), final_layers) # Reshape this flattened output to the original base shapes (unflatten) - generated_params = jax.tree_multimap(jnp.reshape, flat_output, base_shapes) + generated_params = jax.tree_map(jnp.reshape, flat_output, base_shapes) if scale: # Scale the generated params such that expected variance of the raw