From 8a1c546afc74f0b3147a3fd721f384d706666876 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Fri, 7 Feb 2025 14:39:28 -0800 Subject: [PATCH 1/3] batch batch norm --- .github/workflows/run_tests.yml | 2 +- equinox/nn/_batch_norm.py | 187 ++++++++++++++++++++++++++------ tests/test_nn.py | 57 +++++++--- tests/test_serialisation.py | 8 +- tests/test_stateful.py | 2 +- 5 files changed, 203 insertions(+), 53 deletions(-) diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 0934f36b..806ae013 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -27,7 +27,7 @@ jobs: python -m pip install -r ./tests/requirements.txt - name: Checks with pre-commit - uses: pre-commit/action@v2.0.3 + uses: pre-commit/action@v3.0.1 - name: Test with pytest run: | diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index eec94a73..8279ce40 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -1,9 +1,11 @@ +import warnings from collections.abc import Hashable, Sequence +from typing import Literal import jax import jax.lax as lax import jax.numpy as jnp -from jaxtyping import Array, Bool, Float, PRNGKeyArray +from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray from .._misc import default_floating_dtype from .._module import field @@ -40,25 +42,70 @@ class BatchNorm(StatefulLayer, strict=True): statistics updated. During inference then just the running statistics are used. Whether the model is in training or inference mode should be toggled using [`equinox.nn.inference_mode`][]. + + With `mode = "batch"` during training the batch mean and variance are used + for normalization. For inference the exponential running mean and unbiased + variance are used for normalization. This is in line with how other machine + learning packages (e.g. PyTorch, flax, haiku) implement batch norm. + + With `mode = "ema"` exponential running means and variances are kept. During + training the batch statistics are used to fill in the running statistics until + they are populated. During inference the running statistics are used for + normalization. + + ??? cite + + [Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift](https://arxiv.org/abs/1502.03167) + + ```bibtex + @article{DBLP:journals/corr/IoffeS15, + author = {Sergey Ioffe and + Christian Szegedy}, + title = {Batch Normalization: Accelerating Deep Network Training + by Reducing Internal Covariate Shift}, + journal = {CoRR}, + volume = {abs/1502.03167}, + year = {2015}, + url = {http://arxiv.org/abs/1502.03167}, + eprinttype = {arXiv}, + eprint = {1502.03167}, + timestamp = {Mon, 13 Aug 2018 16:47:06 +0200}, + biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} + } + ``` """ # noqa: E501 weight: Float[Array, "input_size"] | None bias: Float[Array, "input_size"] | None - first_time_index: StateIndex[Bool[Array, ""]] - state_index: StateIndex[ - tuple[Float[Array, "input_size"], Float[Array, "input_size"]] - ] + ema_first_time_index: None | StateIndex[Bool[Array, ""]] + ema_state_index: ( + None | StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]] + ) + batch_counter: None | StateIndex[Int[Array, ""]] + batch_state_index: ( + None + | StateIndex[ + tuple[ + tuple[Float[Array, "input_size"], Float[Array, "input_size"]], + tuple[Float[Array, "input_size"], Float[Array, "input_size"]], + ], + ] + ) axis_name: Hashable | Sequence[Hashable] inference: bool input_size: int = field(static=True) eps: float = field(static=True) channelwise_affine: bool = field(static=True) momentum: float = field(static=True) + mode: Literal["ema", "batch"] = field(static=True) def __init__( self, input_size: int, axis_name: Hashable | Sequence[Hashable], + mode: Literal["ema", "batch", "legacy"] = "legacy", eps: float = 1e-5, channelwise_affine: bool = True, momentum: float = 0.99, @@ -71,6 +118,7 @@ def __init__( - `axis_name`: The name of the batch axis to compute statistics over, as passed to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes. + - `mode`: The variant of batch norm to use, either 'ema' or 'batch'. - `eps`: Value added to the denominator for numerical stability. - `channelwise_affine`: Whether the module has learnable channel-wise affine parameters. @@ -86,6 +134,16 @@ def __init__( `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode. """ + if mode == "legacy": + mode = "ema" + warnings.warn( + "When mode is unspecified it defaults to 'ema'. This can have " + "substantial performance impacts, and the user is encouraged to " + "consider and pick which mode they need." + ) + if mode not in ("ema", "batch"): + raise ValueError("Invalid mode, must be 'ema' or 'batch'.") + self.mode = mode dtype = default_floating_dtype() if dtype is None else dtype if channelwise_affine: self.weight = jnp.ones((input_size,), dtype=dtype) @@ -93,12 +151,28 @@ def __init__( else: self.weight = None self.bias = None - self.first_time_index = StateIndex(jnp.array(True)) - init_buffers = ( - jnp.empty((input_size,), dtype=dtype), - jnp.empty((input_size,), dtype=dtype), - ) - self.state_index = StateIndex(init_buffers) + if mode == "ema": + self.ema_first_time_index = StateIndex(jnp.array(True)) + init_buffers = ( + jnp.empty((input_size,), dtype=dtype), + jnp.empty((input_size,), dtype=dtype), + ) + self.ema_state_index = StateIndex(init_buffers) + self.batch_counter = None + self.batch_state_index = None + else: + self.batch_counter = StateIndex(jnp.array(0)) + init_hidden = ( + jnp.zeros((input_size,), dtype=dtype), + jnp.ones((input_size,), dtype=dtype), + ) + init_avg = ( + jnp.zeros((input_size,), dtype=dtype), + jnp.ones((input_size,), dtype=dtype), + ) + self.batch_state_index = StateIndex((init_hidden, init_avg)) + self.ema_first_time_index = None + self.ema_state_index = None self.inference = inference self.axis_name = axis_name self.input_size = input_size @@ -138,32 +212,16 @@ def __call__( A `NameError` if no `vmap`s are placed around this operation, or if this vmap does not have a matching `axis_name`. """ - if inference is None: inference = self.inference - if inference: - running_mean, running_var = state.get(self.state_index) - else: - def _stats(y): - mean = jnp.mean(y) - mean = lax.pmean(mean, self.axis_name) - var = jnp.mean((y - mean) * jnp.conj(y - mean)) - var = lax.pmean(var, self.axis_name) - var = jnp.maximum(0.0, var) - return mean, var - - first_time = state.get(self.first_time_index) - state = state.set(self.first_time_index, jnp.array(False)) - - batch_mean, batch_var = jax.vmap(_stats)(x) - running_mean, running_var = state.get(self.state_index) - momentum = self.momentum - running_mean = (1 - momentum) * batch_mean + momentum * running_mean - running_var = (1 - momentum) * batch_var + momentum * running_var - running_mean = lax.select(first_time, batch_mean, running_mean) - running_var = lax.select(first_time, batch_var, running_var) - state = state.set(self.state_index, (running_mean, running_var)) + def _stats(y): + mean = jnp.mean(y) + mean = lax.pmean(mean, self.axis_name) + var = jnp.mean((y - mean) * jnp.conj(y - mean)) + var = lax.pmean(var, self.axis_name) + var = jnp.maximum(0.0, var) + return mean, var def _norm(y, m, v, w, b): out = (y - m) / jnp.sqrt(v + self.eps) @@ -171,5 +229,62 @@ def _norm(y, m, v, w, b): out = out * w + b return out - out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias) - return out, state + if self.mode == "ema": + assert ( + self.ema_first_time_index is not None + and self.ema_state_index is not None + ) + if inference: + running_mean, running_var = state.get(self.ema_state_index) + else: + first_time = state.get(self.ema_first_time_index) + state = state.set(self.ema_first_time_index, jnp.array(False)) + + batch_mean, batch_var = jax.vmap(_stats)(x) + running_mean, running_var = state.get(self.ema_state_index) + momentum = self.momentum + running_mean = (1 - momentum) * batch_mean + momentum * running_mean + running_var = (1 - momentum) * batch_var + momentum * running_var + # since jnp.array(0) == False + running_mean = lax.select(first_time, batch_mean, running_mean) + running_var = lax.select(first_time, batch_var, running_var) + state = state.set(self.ema_state_index, (running_mean, running_var)) + + out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias) + return out, state + else: + assert self.batch_state_index is not None and self.batch_counter is not None + if inference: + _, (mean, var) = state.get(self.batch_state_index) + else: + batch_mean, batch_var = jax.vmap(_stats)(x) + counter = state.get(self.batch_counter) + (hidden_mean, hidden_var), (running_mean, running_var) = state.get( + self.batch_state_index + ) + + decay = self.momentum + one = jnp.array(1.0, dtype=x.dtype) + + # Update hidden_{mean,var} + new_hidden_mean = hidden_mean * decay + batch_mean * (one - decay) + new_hidden_var = hidden_var * decay + batch_var * (one - decay) + + # Zero-debias approach: average_ = hidden_ / (1 - decay^counter) + # For simplicity we do the minimal version here (no warmup). + new_counter = counter + 1 + decay_power = decay**new_counter + new_running_mean = new_hidden_mean / (one - decay_power) + new_running_var = new_hidden_var / (one - decay_power) + + state = state.set(self.batch_counter, new_counter) + new_state_data = ( + (new_hidden_mean, new_hidden_var), + (new_running_mean, new_running_var), + ) + state = state.set(self.batch_state_index, new_state_data) + + mean, var = (batch_mean, batch_var) + + out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias) + return out, state diff --git a/tests/test_nn.py b/tests/test_nn.py index a4b766cd..f255b5ca 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -142,7 +142,7 @@ def test_sequential(getkey): [ eqx.nn.Linear(2, 4, key=getkey()), eqx.nn.Linear(4, 1, key=getkey()), - eqx.nn.BatchNorm(1, axis_name="batch"), + eqx.nn.BatchNorm(1, axis_name="batch", mode="ema"), eqx.nn.Linear(1, 3, key=getkey()), ] ) @@ -176,7 +176,7 @@ def make(): inner_seq = eqx.nn.Sequential( [ eqx.nn.Linear(2, 4, key=getkey()), - eqx.nn.BatchNorm(4, axis_name="batch") + eqx.nn.BatchNorm(4, axis_name="batch", mode="ema") if inner_stateful else eqx.nn.Identity(), eqx.nn.Linear(4, 3, key=getkey()), @@ -186,7 +186,7 @@ def make(): [ eqx.nn.Linear(5, 2, key=getkey()), inner_seq, - eqx.nn.BatchNorm(3, axis_name="batch") + eqx.nn.BatchNorm(3, axis_name="batch", mode="ema") if outer_stateful else eqx.nn.Identity(), eqx.nn.Linear(3, 6, key=getkey()), @@ -949,7 +949,8 @@ def test_group_norm(getkey): gn = eqx.nn.GroupNorm(groups=4, channels=None, channelwise_affine=True) -def test_batch_norm(getkey): +@pytest.mark.parametrize("mode", ("ema", "batch")) +def test_batch_norm(getkey, mode): x0 = jrandom.uniform(getkey(), (5,)) x1 = jrandom.uniform(getkey(), (10, 5)) x2 = jrandom.uniform(getkey(), (10, 5, 6)) @@ -957,14 +958,19 @@ def test_batch_norm(getkey): # Test that it works with a single vmap'd axis_name - bn = eqx.nn.BatchNorm(5, "batch") + bn = eqx.nn.BatchNorm(5, "batch", mode=mode) state = eqx.nn.State(bn) vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) for x in (x1, x2, x3): out, state = vbn(x, state) assert out.shape == x.shape - running_mean, running_var = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean, running_var = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + _, (running_mean, running_var) = state.get(bn.batch_state_index) assert running_mean.shape == (5,) assert running_var.shape == (5,) @@ -985,13 +991,18 @@ def test_batch_norm(getkey): in_axes=(0, None), )(x2, state) assert out.shape == x2.shape - running_mean, running_var = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean, running_var = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + _, (running_mean, running_var) = state.get(bn.batch_state_index) assert running_mean.shape == (10, 5) assert running_var.shape == (10, 5) # Test that it handles multiple axis_names - vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2")) + vvbn = eqx.nn.BatchNorm(6, ("batch1", "batch2"), mode=mode) vvstate = eqx.nn.State(vvbn) for axis_name in ("batch1", "batch2"): vvbn = jax.vmap( @@ -999,14 +1010,19 @@ def test_batch_norm(getkey): ) out, out_vvstate = vvbn(x2, vvstate) assert out.shape == x2.shape - running_mean, running_var = out_vvstate.get(vvbn.state_index) + if mode == "ema": + assert vvbn.ema_state_index is not None + running_mean, running_var = out_vvstate.get(vvbn.ema_state_index) + else: + assert vvbn.batch_state_index is not None + _, (running_mean, running_var) = out_vvstate.get(vvbn.batch_state_index) assert running_mean.shape == (6,) assert running_var.shape == (6,) # Test that it normalises x1alt = jrandom.normal(jrandom.PRNGKey(5678), (10, 5)) # avoid flakey test - bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False) + bn = eqx.nn.BatchNorm(5, "batch", channelwise_affine=False, mode=mode) state = eqx.nn.State(bn) vbn = jax.vmap(bn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) out, state = vbn(x1alt, state) @@ -1017,9 +1033,19 @@ def test_batch_norm(getkey): # Test that the statistics update during training out, state = vbn(x1, state) - running_mean, running_var = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean, running_var = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + _, (running_mean, running_var) = state.get(bn.batch_state_index) out, state = vbn(3 * x1 + 10, state) - running_mean2, running_var2 = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean2, running_var2 = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + _, (running_mean2, running_var2) = state.get(bn.batch_state_index) assert not jnp.allclose(running_mean, running_mean2) assert not jnp.allclose(running_var, running_var2) @@ -1028,7 +1054,12 @@ def test_batch_norm(getkey): ibn = eqx.nn.inference_mode(bn, value=True) vibn = jax.vmap(ibn, axis_name="batch", in_axes=(0, None), out_axes=(0, None)) out, state = vibn(4 * x1 + 20, state) - running_mean3, running_var3 = state.get(bn.state_index) + if mode == "ema": + assert bn.ema_state_index is not None + running_mean3, running_var3 = state.get(bn.ema_state_index) + else: + assert bn.batch_state_index is not None + _, (running_mean3, running_var3) = state.get(bn.batch_state_index) assert jnp.array_equal(running_mean2, running_mean3) assert jnp.array_equal(running_var2, running_var3) diff --git a/tests/test_serialisation.py b/tests/test_serialisation.py index e89c243a..44b214ba 100644 --- a/tests/test_serialisation.py +++ b/tests/test_serialisation.py @@ -239,12 +239,16 @@ class Model(eqx.Module): norm1: eqx.nn.BatchNorm norm2: eqx.nn.BatchNorm - model = Model(eqx.nn.BatchNorm(3, "hi"), eqx.nn.BatchNorm(4, "bye")) + model = Model( + eqx.nn.BatchNorm(3, "hi", mode="ema"), eqx.nn.BatchNorm(4, "bye", mode="ema") + ) state = eqx.nn.State(model) eqx.tree_serialise_leaves(tmp_path, (model, state)) - model2 = Model(eqx.nn.BatchNorm(3, "hi"), eqx.nn.BatchNorm(4, "bye")) + model2 = Model( + eqx.nn.BatchNorm(3, "hi", mode="ema"), eqx.nn.BatchNorm(4, "bye", mode="ema") + ) state2 = eqx.nn.State(model2) eqx.tree_deserialise_leaves(tmp_path, (model2, state2)) diff --git a/tests/test_stateful.py b/tests/test_stateful.py index d7bc632a..cf57b5cd 100644 --- a/tests/test_stateful.py +++ b/tests/test_stateful.py @@ -7,7 +7,7 @@ def test_delete_init_state(): - model = eqx.nn.BatchNorm(3, "batch") + model = eqx.nn.BatchNorm(3, "batch", mode="ema") eqx.nn.State(model) model2 = eqx.nn.delete_init_state(model) From 6dd6e695976fbd47cdbbf2e33f1e55cf6c45c636 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 14 Apr 2025 22:10:28 +0200 Subject: [PATCH 2/3] Updates to BatchNorm: - Now tracking only the running statistics, not the zero-debiased statistics. These are handled at inference time instead. - Standarised bibtex formatting. - Moved `mode` argument to the end for backward compatibility. --- equinox/nn/_batch_norm.py | 121 +++++++++++++++----------------------- tests/test_nn.py | 12 ++-- 2 files changed, 54 insertions(+), 79 deletions(-) diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index 8279ce40..bb60b816 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -60,19 +60,18 @@ class BatchNorm(StatefulLayer, strict=True): ```bibtex @article{DBLP:journals/corr/IoffeS15, - author = {Sergey Ioffe and - Christian Szegedy}, - title = {Batch Normalization: Accelerating Deep Network Training - by Reducing Internal Covariate Shift}, - journal = {CoRR}, - volume = {abs/1502.03167}, - year = {2015}, - url = {http://arxiv.org/abs/1502.03167}, - eprinttype = {arXiv}, - eprint = {1502.03167}, - timestamp = {Mon, 13 Aug 2018 16:47:06 +0200}, - biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib}, - bibsource = {dblp computer science bibliography, https://dblp.org} + author = {Sergey Ioffe and Christian Szegedy}, + title = {Batch Normalization: Accelerating Deep Network Training by Reducing + Internal Covariate Shift}, + journal = {CoRR}, + volume = {abs/1502.03167}, + year = {2015}, + url = {http://arxiv.org/abs/1502.03167}, + eprinttype = {arXiv}, + eprint = {1502.03167}, + timestamp = {Mon, 13 Aug 2018 16:47:06 +0200}, + biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib}, + bibsource = {dblp computer science bibliography, https://dblp.org} } ``` """ # noqa: E501 @@ -85,13 +84,7 @@ class BatchNorm(StatefulLayer, strict=True): ) batch_counter: None | StateIndex[Int[Array, ""]] batch_state_index: ( - None - | StateIndex[ - tuple[ - tuple[Float[Array, "input_size"], Float[Array, "input_size"]], - tuple[Float[Array, "input_size"], Float[Array, "input_size"]], - ], - ] + None | StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]] ) axis_name: Hashable | Sequence[Hashable] inference: bool @@ -105,12 +98,12 @@ def __init__( self, input_size: int, axis_name: Hashable | Sequence[Hashable], - mode: Literal["ema", "batch", "legacy"] = "legacy", eps: float = 1e-5, channelwise_affine: bool = True, momentum: float = 0.99, inference: bool = False, dtype=None, + mode: Literal["ema", "batch", "legacy"] = "legacy", ): """**Arguments:** @@ -118,7 +111,6 @@ def __init__( - `axis_name`: The name of the batch axis to compute statistics over, as passed to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a tuple or a list) of names, to compute statistics over multiple named axes. - - `mode`: The variant of batch norm to use, either 'ema' or 'batch'. - `eps`: Value added to the denominator for numerical stability. - `channelwise_affine`: Whether the module has learnable channel-wise affine parameters. @@ -133,15 +125,17 @@ def __init__( if `channelwise_affine` is `True`. Defaults to either `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in 64-bit mode. + - `mode`: The variant of batch norm to use, either 'ema' or 'batch'. """ if mode == "legacy": mode = "ema" warnings.warn( - "When mode is unspecified it defaults to 'ema'. This can have " - "substantial performance impacts, and the user is encouraged to " - "consider and pick which mode they need." + "When `eqx.nn.BatchNorm(..., mode=...)` is unspecified it defaults to " + "'ema', for backward compatibility. This typically has a performance " + "impact, and for new code the user is encouraged to use 'batch' " + "instead. See `https://github.com/patrick-kidger/equinox/issues/659`." ) - if mode not in ("ema", "batch"): + if mode not in {"ema", "batch"}: raise ValueError("Invalid mode, must be 'ema' or 'batch'.") self.mode = mode dtype = default_floating_dtype() if dtype is None else dtype @@ -166,11 +160,7 @@ def __init__( jnp.zeros((input_size,), dtype=dtype), jnp.ones((input_size,), dtype=dtype), ) - init_avg = ( - jnp.zeros((input_size,), dtype=dtype), - jnp.ones((input_size,), dtype=dtype), - ) - self.batch_state_index = StateIndex((init_hidden, init_avg)) + self.batch_state_index = StateIndex(init_hidden) self.ema_first_time_index = None self.ema_state_index = None self.inference = inference @@ -212,6 +202,8 @@ def __call__( A `NameError` if no `vmap`s are placed around this operation, or if this vmap does not have a matching `axis_name`. """ + del key + if inference is None: inference = self.inference @@ -230,61 +222,44 @@ def _norm(y, m, v, w, b): return out if self.mode == "ema": - assert ( - self.ema_first_time_index is not None - and self.ema_state_index is not None - ) + assert self.ema_first_time_index is not None + assert self.ema_state_index is not None if inference: - running_mean, running_var = state.get(self.ema_state_index) + mean, var = state.get(self.ema_state_index) else: first_time = state.get(self.ema_first_time_index) state = state.set(self.ema_first_time_index, jnp.array(False)) - batch_mean, batch_var = jax.vmap(_stats)(x) running_mean, running_var = state.get(self.ema_state_index) momentum = self.momentum - running_mean = (1 - momentum) * batch_mean + momentum * running_mean - running_var = (1 - momentum) * batch_var + momentum * running_var + mean = (1 - momentum) * batch_mean + momentum * running_mean + var = (1 - momentum) * batch_var + momentum * running_var # since jnp.array(0) == False - running_mean = lax.select(first_time, batch_mean, running_mean) - running_var = lax.select(first_time, batch_var, running_var) - state = state.set(self.ema_state_index, (running_mean, running_var)) - - out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias) - return out, state + mean = lax.select(first_time, batch_mean, mean) + var = lax.select(first_time, batch_var, var) + state = state.set(self.ema_state_index, (mean, var)) else: - assert self.batch_state_index is not None and self.batch_counter is not None + assert self.batch_state_index is not None + assert self.batch_counter is not None + counter = state.get(self.batch_counter) + hidden_mean, hidden_var = state.get(self.batch_state_index) if inference: - _, (mean, var) = state.get(self.batch_state_index) - else: - batch_mean, batch_var = jax.vmap(_stats)(x) - counter = state.get(self.batch_counter) - (hidden_mean, hidden_var), (running_mean, running_var) = state.get( - self.batch_state_index - ) - - decay = self.momentum - one = jnp.array(1.0, dtype=x.dtype) - - # Update hidden_{mean,var} - new_hidden_mean = hidden_mean * decay + batch_mean * (one - decay) - new_hidden_var = hidden_var * decay + batch_var * (one - decay) - # Zero-debias approach: average_ = hidden_ / (1 - decay^counter) # For simplicity we do the minimal version here (no warmup). + scale = 1 - self.momentum**counter + mean = hidden_mean / scale + var = hidden_var / scale + else: + mean, var = jax.vmap(_stats)(x) new_counter = counter + 1 - decay_power = decay**new_counter - new_running_mean = new_hidden_mean / (one - decay_power) - new_running_var = new_hidden_var / (one - decay_power) - + new_hidden_mean = hidden_mean * self.momentum + mean * ( + 1 - self.momentum + ) + new_hidden_var = hidden_var * self.momentum + var * (1 - self.momentum) state = state.set(self.batch_counter, new_counter) - new_state_data = ( - (new_hidden_mean, new_hidden_var), - (new_running_mean, new_running_var), + state = state.set( + self.batch_state_index, (new_hidden_mean, new_hidden_var) ) - state = state.set(self.batch_state_index, new_state_data) - - mean, var = (batch_mean, batch_var) - out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias) - return out, state + out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias) + return out, state diff --git a/tests/test_nn.py b/tests/test_nn.py index f255b5ca..210f39f2 100644 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -970,7 +970,7 @@ def test_batch_norm(getkey, mode): running_mean, running_var = state.get(bn.ema_state_index) else: assert bn.batch_state_index is not None - _, (running_mean, running_var) = state.get(bn.batch_state_index) + running_mean, running_var = state.get(bn.batch_state_index) assert running_mean.shape == (5,) assert running_var.shape == (5,) @@ -996,7 +996,7 @@ def test_batch_norm(getkey, mode): running_mean, running_var = state.get(bn.ema_state_index) else: assert bn.batch_state_index is not None - _, (running_mean, running_var) = state.get(bn.batch_state_index) + running_mean, running_var = state.get(bn.batch_state_index) assert running_mean.shape == (10, 5) assert running_var.shape == (10, 5) @@ -1015,7 +1015,7 @@ def test_batch_norm(getkey, mode): running_mean, running_var = out_vvstate.get(vvbn.ema_state_index) else: assert vvbn.batch_state_index is not None - _, (running_mean, running_var) = out_vvstate.get(vvbn.batch_state_index) + running_mean, running_var = out_vvstate.get(vvbn.batch_state_index) assert running_mean.shape == (6,) assert running_var.shape == (6,) @@ -1038,14 +1038,14 @@ def test_batch_norm(getkey, mode): running_mean, running_var = state.get(bn.ema_state_index) else: assert bn.batch_state_index is not None - _, (running_mean, running_var) = state.get(bn.batch_state_index) + running_mean, running_var = state.get(bn.batch_state_index) out, state = vbn(3 * x1 + 10, state) if mode == "ema": assert bn.ema_state_index is not None running_mean2, running_var2 = state.get(bn.ema_state_index) else: assert bn.batch_state_index is not None - _, (running_mean2, running_var2) = state.get(bn.batch_state_index) + running_mean2, running_var2 = state.get(bn.batch_state_index) assert not jnp.allclose(running_mean, running_mean2) assert not jnp.allclose(running_var, running_var2) @@ -1059,7 +1059,7 @@ def test_batch_norm(getkey, mode): running_mean3, running_var3 = state.get(bn.ema_state_index) else: assert bn.batch_state_index is not None - _, (running_mean3, running_var3) = state.get(bn.batch_state_index) + running_mean3, running_var3 = state.get(bn.batch_state_index) assert jnp.array_equal(running_mean2, running_mean3) assert jnp.array_equal(running_var2, running_var3) From 7dd5efb5674bcc656414afcabb5ea2306668292a Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Mon, 14 Apr 2025 13:44:13 -0700 Subject: [PATCH 3/3] comment --- equinox/nn/_batch_norm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/equinox/nn/_batch_norm.py b/equinox/nn/_batch_norm.py index bb60b816..35ee733e 100644 --- a/equinox/nn/_batch_norm.py +++ b/equinox/nn/_batch_norm.py @@ -244,7 +244,7 @@ def _norm(y, m, v, w, b): counter = state.get(self.batch_counter) hidden_mean, hidden_var = state.get(self.batch_state_index) if inference: - # Zero-debias approach: average_ = hidden_ / (1 - decay^counter) + # Zero-debias approach: mean = hidden_mean / (1 - momentum^counter) # For simplicity we do the minimal version here (no warmup). scale = 1 - self.momentum**counter mean = hidden_mean / scale