Skip to content

Commit 6dd6e69

Browse files
Updates to BatchNorm:
- Now tracking only the running statistics, not the zero-debiased statistics. These are handled at inference time instead. - Standarised bibtex formatting. - Moved `mode` argument to the end for backward compatibility.
1 parent 8a1c546 commit 6dd6e69

2 files changed

Lines changed: 54 additions & 79 deletions

File tree

equinox/nn/_batch_norm.py

Lines changed: 48 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,18 @@ class BatchNorm(StatefulLayer, strict=True):
6060
6161
```bibtex
6262
@article{DBLP:journals/corr/IoffeS15,
63-
author = {Sergey Ioffe and
64-
Christian Szegedy},
65-
title = {Batch Normalization: Accelerating Deep Network Training
66-
by Reducing Internal Covariate Shift},
67-
journal = {CoRR},
68-
volume = {abs/1502.03167},
69-
year = {2015},
70-
url = {http://arxiv.org/abs/1502.03167},
71-
eprinttype = {arXiv},
72-
eprint = {1502.03167},
73-
timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
74-
biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
75-
bibsource = {dblp computer science bibliography, https://dblp.org}
63+
author = {Sergey Ioffe and Christian Szegedy},
64+
title = {Batch Normalization: Accelerating Deep Network Training by Reducing
65+
Internal Covariate Shift},
66+
journal = {CoRR},
67+
volume = {abs/1502.03167},
68+
year = {2015},
69+
url = {http://arxiv.org/abs/1502.03167},
70+
eprinttype = {arXiv},
71+
eprint = {1502.03167},
72+
timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
73+
biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
74+
bibsource = {dblp computer science bibliography, https://dblp.org}
7675
}
7776
```
7877
""" # noqa: E501
@@ -85,13 +84,7 @@ class BatchNorm(StatefulLayer, strict=True):
8584
)
8685
batch_counter: None | StateIndex[Int[Array, ""]]
8786
batch_state_index: (
88-
None
89-
| StateIndex[
90-
tuple[
91-
tuple[Float[Array, "input_size"], Float[Array, "input_size"]],
92-
tuple[Float[Array, "input_size"], Float[Array, "input_size"]],
93-
],
94-
]
87+
None | StateIndex[tuple[Float[Array, "input_size"], Float[Array, "input_size"]]]
9588
)
9689
axis_name: Hashable | Sequence[Hashable]
9790
inference: bool
@@ -105,20 +98,19 @@ def __init__(
10598
self,
10699
input_size: int,
107100
axis_name: Hashable | Sequence[Hashable],
108-
mode: Literal["ema", "batch", "legacy"] = "legacy",
109101
eps: float = 1e-5,
110102
channelwise_affine: bool = True,
111103
momentum: float = 0.99,
112104
inference: bool = False,
113105
dtype=None,
106+
mode: Literal["ema", "batch", "legacy"] = "legacy",
114107
):
115108
"""**Arguments:**
116109
117110
- `input_size`: The number of channels in the input array.
118111
- `axis_name`: The name of the batch axis to compute statistics over, as passed
119112
to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a
120113
tuple or a list) of names, to compute statistics over multiple named axes.
121-
- `mode`: The variant of batch norm to use, either 'ema' or 'batch'.
122114
- `eps`: Value added to the denominator for numerical stability.
123115
- `channelwise_affine`: Whether the module has learnable channel-wise affine
124116
parameters.
@@ -133,15 +125,17 @@ def __init__(
133125
if `channelwise_affine` is `True`. Defaults to either
134126
`jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in
135127
64-bit mode.
128+
- `mode`: The variant of batch norm to use, either 'ema' or 'batch'.
136129
"""
137130
if mode == "legacy":
138131
mode = "ema"
139132
warnings.warn(
140-
"When mode is unspecified it defaults to 'ema'. This can have "
141-
"substantial performance impacts, and the user is encouraged to "
142-
"consider and pick which mode they need."
133+
"When `eqx.nn.BatchNorm(..., mode=...)` is unspecified it defaults to "
134+
"'ema', for backward compatibility. This typically has a performance "
135+
"impact, and for new code the user is encouraged to use 'batch' "
136+
"instead. See `https://github.com/patrick-kidger/equinox/issues/659`."
143137
)
144-
if mode not in ("ema", "batch"):
138+
if mode not in {"ema", "batch"}:
145139
raise ValueError("Invalid mode, must be 'ema' or 'batch'.")
146140
self.mode = mode
147141
dtype = default_floating_dtype() if dtype is None else dtype
@@ -166,11 +160,7 @@ def __init__(
166160
jnp.zeros((input_size,), dtype=dtype),
167161
jnp.ones((input_size,), dtype=dtype),
168162
)
169-
init_avg = (
170-
jnp.zeros((input_size,), dtype=dtype),
171-
jnp.ones((input_size,), dtype=dtype),
172-
)
173-
self.batch_state_index = StateIndex((init_hidden, init_avg))
163+
self.batch_state_index = StateIndex(init_hidden)
174164
self.ema_first_time_index = None
175165
self.ema_state_index = None
176166
self.inference = inference
@@ -212,6 +202,8 @@ def __call__(
212202
A `NameError` if no `vmap`s are placed around this operation, or if this vmap
213203
does not have a matching `axis_name`.
214204
"""
205+
del key
206+
215207
if inference is None:
216208
inference = self.inference
217209

@@ -230,61 +222,44 @@ def _norm(y, m, v, w, b):
230222
return out
231223

232224
if self.mode == "ema":
233-
assert (
234-
self.ema_first_time_index is not None
235-
and self.ema_state_index is not None
236-
)
225+
assert self.ema_first_time_index is not None
226+
assert self.ema_state_index is not None
237227
if inference:
238-
running_mean, running_var = state.get(self.ema_state_index)
228+
mean, var = state.get(self.ema_state_index)
239229
else:
240230
first_time = state.get(self.ema_first_time_index)
241231
state = state.set(self.ema_first_time_index, jnp.array(False))
242-
243232
batch_mean, batch_var = jax.vmap(_stats)(x)
244233
running_mean, running_var = state.get(self.ema_state_index)
245234
momentum = self.momentum
246-
running_mean = (1 - momentum) * batch_mean + momentum * running_mean
247-
running_var = (1 - momentum) * batch_var + momentum * running_var
235+
mean = (1 - momentum) * batch_mean + momentum * running_mean
236+
var = (1 - momentum) * batch_var + momentum * running_var
248237
# since jnp.array(0) == False
249-
running_mean = lax.select(first_time, batch_mean, running_mean)
250-
running_var = lax.select(first_time, batch_var, running_var)
251-
state = state.set(self.ema_state_index, (running_mean, running_var))
252-
253-
out = jax.vmap(_norm)(x, running_mean, running_var, self.weight, self.bias)
254-
return out, state
238+
mean = lax.select(first_time, batch_mean, mean)
239+
var = lax.select(first_time, batch_var, var)
240+
state = state.set(self.ema_state_index, (mean, var))
255241
else:
256-
assert self.batch_state_index is not None and self.batch_counter is not None
242+
assert self.batch_state_index is not None
243+
assert self.batch_counter is not None
244+
counter = state.get(self.batch_counter)
245+
hidden_mean, hidden_var = state.get(self.batch_state_index)
257246
if inference:
258-
_, (mean, var) = state.get(self.batch_state_index)
259-
else:
260-
batch_mean, batch_var = jax.vmap(_stats)(x)
261-
counter = state.get(self.batch_counter)
262-
(hidden_mean, hidden_var), (running_mean, running_var) = state.get(
263-
self.batch_state_index
264-
)
265-
266-
decay = self.momentum
267-
one = jnp.array(1.0, dtype=x.dtype)
268-
269-
# Update hidden_{mean,var}
270-
new_hidden_mean = hidden_mean * decay + batch_mean * (one - decay)
271-
new_hidden_var = hidden_var * decay + batch_var * (one - decay)
272-
273247
# Zero-debias approach: average_ = hidden_ / (1 - decay^counter)
274248
# For simplicity we do the minimal version here (no warmup).
249+
scale = 1 - self.momentum**counter
250+
mean = hidden_mean / scale
251+
var = hidden_var / scale
252+
else:
253+
mean, var = jax.vmap(_stats)(x)
275254
new_counter = counter + 1
276-
decay_power = decay**new_counter
277-
new_running_mean = new_hidden_mean / (one - decay_power)
278-
new_running_var = new_hidden_var / (one - decay_power)
279-
255+
new_hidden_mean = hidden_mean * self.momentum + mean * (
256+
1 - self.momentum
257+
)
258+
new_hidden_var = hidden_var * self.momentum + var * (1 - self.momentum)
280259
state = state.set(self.batch_counter, new_counter)
281-
new_state_data = (
282-
(new_hidden_mean, new_hidden_var),
283-
(new_running_mean, new_running_var),
260+
state = state.set(
261+
self.batch_state_index, (new_hidden_mean, new_hidden_var)
284262
)
285-
state = state.set(self.batch_state_index, new_state_data)
286-
287-
mean, var = (batch_mean, batch_var)
288263

289-
out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias)
290-
return out, state
264+
out = jax.vmap(_norm)(x, mean, var, self.weight, self.bias)
265+
return out, state

tests/test_nn.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -970,7 +970,7 @@ def test_batch_norm(getkey, mode):
970970
running_mean, running_var = state.get(bn.ema_state_index)
971971
else:
972972
assert bn.batch_state_index is not None
973-
_, (running_mean, running_var) = state.get(bn.batch_state_index)
973+
running_mean, running_var = state.get(bn.batch_state_index)
974974
assert running_mean.shape == (5,)
975975
assert running_var.shape == (5,)
976976

@@ -996,7 +996,7 @@ def test_batch_norm(getkey, mode):
996996
running_mean, running_var = state.get(bn.ema_state_index)
997997
else:
998998
assert bn.batch_state_index is not None
999-
_, (running_mean, running_var) = state.get(bn.batch_state_index)
999+
running_mean, running_var = state.get(bn.batch_state_index)
10001000
assert running_mean.shape == (10, 5)
10011001
assert running_var.shape == (10, 5)
10021002

@@ -1015,7 +1015,7 @@ def test_batch_norm(getkey, mode):
10151015
running_mean, running_var = out_vvstate.get(vvbn.ema_state_index)
10161016
else:
10171017
assert vvbn.batch_state_index is not None
1018-
_, (running_mean, running_var) = out_vvstate.get(vvbn.batch_state_index)
1018+
running_mean, running_var = out_vvstate.get(vvbn.batch_state_index)
10191019
assert running_mean.shape == (6,)
10201020
assert running_var.shape == (6,)
10211021

@@ -1038,14 +1038,14 @@ def test_batch_norm(getkey, mode):
10381038
running_mean, running_var = state.get(bn.ema_state_index)
10391039
else:
10401040
assert bn.batch_state_index is not None
1041-
_, (running_mean, running_var) = state.get(bn.batch_state_index)
1041+
running_mean, running_var = state.get(bn.batch_state_index)
10421042
out, state = vbn(3 * x1 + 10, state)
10431043
if mode == "ema":
10441044
assert bn.ema_state_index is not None
10451045
running_mean2, running_var2 = state.get(bn.ema_state_index)
10461046
else:
10471047
assert bn.batch_state_index is not None
1048-
_, (running_mean2, running_var2) = state.get(bn.batch_state_index)
1048+
running_mean2, running_var2 = state.get(bn.batch_state_index)
10491049
assert not jnp.allclose(running_mean, running_mean2)
10501050
assert not jnp.allclose(running_var, running_var2)
10511051

@@ -1059,7 +1059,7 @@ def test_batch_norm(getkey, mode):
10591059
running_mean3, running_var3 = state.get(bn.ema_state_index)
10601060
else:
10611061
assert bn.batch_state_index is not None
1062-
_, (running_mean3, running_var3) = state.get(bn.batch_state_index)
1062+
running_mean3, running_var3 = state.get(bn.batch_state_index)
10631063
assert jnp.array_equal(running_mean2, running_mean3)
10641064
assert jnp.array_equal(running_var2, running_var3)
10651065

0 commit comments

Comments
 (0)