@@ -60,19 +60,18 @@ class BatchNorm(StatefulLayer, strict=True):
6060
6161 ```bibtex
6262 @article{DBLP:journals/corr/IoffeS15,
63- author = {Sergey Ioffe and
64- Christian Szegedy},
65- title = {Batch Normalization: Accelerating Deep Network Training
66- by Reducing Internal Covariate Shift},
67- journal = {CoRR},
68- volume = {abs/1502.03167},
69- year = {2015},
70- url = {http://arxiv.org/abs/1502.03167},
71- eprinttype = {arXiv},
72- eprint = {1502.03167},
73- timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
74- biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
75- bibsource = {dblp computer science bibliography, https://dblp.org}
63+ author = {Sergey Ioffe and Christian Szegedy},
64+ title = {Batch Normalization: Accelerating Deep Network Training by Reducing
65+ Internal Covariate Shift},
66+ journal = {CoRR},
67+ volume = {abs/1502.03167},
68+ year = {2015},
69+ url = {http://arxiv.org/abs/1502.03167},
70+ eprinttype = {arXiv},
71+ eprint = {1502.03167},
72+ timestamp = {Mon, 13 Aug 2018 16:47:06 +0200},
73+ biburl = {https://dblp.org/rec/journals/corr/IoffeS15.bib},
74+ bibsource = {dblp computer science bibliography, https://dblp.org}
7675 }
7776 ```
7877 """ # noqa: E501
@@ -85,13 +84,7 @@ class BatchNorm(StatefulLayer, strict=True):
8584 )
8685 batch_counter : None | StateIndex [Int [Array , "" ]]
8786 batch_state_index : (
88- None
89- | StateIndex [
90- tuple [
91- tuple [Float [Array , "input_size" ], Float [Array , "input_size" ]],
92- tuple [Float [Array , "input_size" ], Float [Array , "input_size" ]],
93- ],
94- ]
87+ None | StateIndex [tuple [Float [Array , "input_size" ], Float [Array , "input_size" ]]]
9588 )
9689 axis_name : Hashable | Sequence [Hashable ]
9790 inference : bool
@@ -105,20 +98,19 @@ def __init__(
10598 self ,
10699 input_size : int ,
107100 axis_name : Hashable | Sequence [Hashable ],
108- mode : Literal ["ema" , "batch" , "legacy" ] = "legacy" ,
109101 eps : float = 1e-5 ,
110102 channelwise_affine : bool = True ,
111103 momentum : float = 0.99 ,
112104 inference : bool = False ,
113105 dtype = None ,
106+ mode : Literal ["ema" , "batch" , "legacy" ] = "legacy" ,
114107 ):
115108 """**Arguments:**
116109
117110 - `input_size`: The number of channels in the input array.
118111 - `axis_name`: The name of the batch axis to compute statistics over, as passed
119112 to `axis_name` in `jax.vmap` or `jax.pmap`. Can also be a sequence (e.g. a
120113 tuple or a list) of names, to compute statistics over multiple named axes.
121- - `mode`: The variant of batch norm to use, either 'ema' or 'batch'.
122114 - `eps`: Value added to the denominator for numerical stability.
123115 - `channelwise_affine`: Whether the module has learnable channel-wise affine
124116 parameters.
@@ -133,15 +125,17 @@ def __init__(
133125 if `channelwise_affine` is `True`. Defaults to either
134126 `jax.numpy.float32` or `jax.numpy.float64` depending on whether JAX is in
135127 64-bit mode.
128+ - `mode`: The variant of batch norm to use, either 'ema' or 'batch'.
136129 """
137130 if mode == "legacy" :
138131 mode = "ema"
139132 warnings .warn (
140- "When mode is unspecified it defaults to 'ema'. This can have "
141- "substantial performance impacts, and the user is encouraged to "
142- "consider and pick which mode they need."
133+ "When `eqx.nn.BatchNorm(..., mode=...)` is unspecified it defaults to "
134+ "'ema', for backward compatibility. This typically has a performance "
135+ "impact, and for new code the user is encouraged to use 'batch' "
136+ "instead. See `https://github.com/patrick-kidger/equinox/issues/659`."
143137 )
144- if mode not in ( "ema" , "batch" ) :
138+ if mode not in { "ema" , "batch" } :
145139 raise ValueError ("Invalid mode, must be 'ema' or 'batch'." )
146140 self .mode = mode
147141 dtype = default_floating_dtype () if dtype is None else dtype
@@ -166,11 +160,7 @@ def __init__(
166160 jnp .zeros ((input_size ,), dtype = dtype ),
167161 jnp .ones ((input_size ,), dtype = dtype ),
168162 )
169- init_avg = (
170- jnp .zeros ((input_size ,), dtype = dtype ),
171- jnp .ones ((input_size ,), dtype = dtype ),
172- )
173- self .batch_state_index = StateIndex ((init_hidden , init_avg ))
163+ self .batch_state_index = StateIndex (init_hidden )
174164 self .ema_first_time_index = None
175165 self .ema_state_index = None
176166 self .inference = inference
@@ -212,6 +202,8 @@ def __call__(
212202 A `NameError` if no `vmap`s are placed around this operation, or if this vmap
213203 does not have a matching `axis_name`.
214204 """
205+ del key
206+
215207 if inference is None :
216208 inference = self .inference
217209
@@ -230,61 +222,44 @@ def _norm(y, m, v, w, b):
230222 return out
231223
232224 if self .mode == "ema" :
233- assert (
234- self .ema_first_time_index is not None
235- and self .ema_state_index is not None
236- )
225+ assert self .ema_first_time_index is not None
226+ assert self .ema_state_index is not None
237227 if inference :
238- running_mean , running_var = state .get (self .ema_state_index )
228+ mean , var = state .get (self .ema_state_index )
239229 else :
240230 first_time = state .get (self .ema_first_time_index )
241231 state = state .set (self .ema_first_time_index , jnp .array (False ))
242-
243232 batch_mean , batch_var = jax .vmap (_stats )(x )
244233 running_mean , running_var = state .get (self .ema_state_index )
245234 momentum = self .momentum
246- running_mean = (1 - momentum ) * batch_mean + momentum * running_mean
247- running_var = (1 - momentum ) * batch_var + momentum * running_var
235+ mean = (1 - momentum ) * batch_mean + momentum * running_mean
236+ var = (1 - momentum ) * batch_var + momentum * running_var
248237 # since jnp.array(0) == False
249- running_mean = lax .select (first_time , batch_mean , running_mean )
250- running_var = lax .select (first_time , batch_var , running_var )
251- state = state .set (self .ema_state_index , (running_mean , running_var ))
252-
253- out = jax .vmap (_norm )(x , running_mean , running_var , self .weight , self .bias )
254- return out , state
238+ mean = lax .select (first_time , batch_mean , mean )
239+ var = lax .select (first_time , batch_var , var )
240+ state = state .set (self .ema_state_index , (mean , var ))
255241 else :
256- assert self .batch_state_index is not None and self .batch_counter is not None
242+ assert self .batch_state_index is not None
243+ assert self .batch_counter is not None
244+ counter = state .get (self .batch_counter )
245+ hidden_mean , hidden_var = state .get (self .batch_state_index )
257246 if inference :
258- _ , (mean , var ) = state .get (self .batch_state_index )
259- else :
260- batch_mean , batch_var = jax .vmap (_stats )(x )
261- counter = state .get (self .batch_counter )
262- (hidden_mean , hidden_var ), (running_mean , running_var ) = state .get (
263- self .batch_state_index
264- )
265-
266- decay = self .momentum
267- one = jnp .array (1.0 , dtype = x .dtype )
268-
269- # Update hidden_{mean,var}
270- new_hidden_mean = hidden_mean * decay + batch_mean * (one - decay )
271- new_hidden_var = hidden_var * decay + batch_var * (one - decay )
272-
273247 # Zero-debias approach: average_ = hidden_ / (1 - decay^counter)
274248 # For simplicity we do the minimal version here (no warmup).
249+ scale = 1 - self .momentum ** counter
250+ mean = hidden_mean / scale
251+ var = hidden_var / scale
252+ else :
253+ mean , var = jax .vmap (_stats )(x )
275254 new_counter = counter + 1
276- decay_power = decay ** new_counter
277- new_running_mean = new_hidden_mean / ( one - decay_power )
278- new_running_var = new_hidden_var / ( one - decay_power )
279-
255+ new_hidden_mean = hidden_mean * self . momentum + mean * (
256+ 1 - self . momentum
257+ )
258+ new_hidden_var = hidden_var * self . momentum + var * ( 1 - self . momentum )
280259 state = state .set (self .batch_counter , new_counter )
281- new_state_data = (
282- (new_hidden_mean , new_hidden_var ),
283- (new_running_mean , new_running_var ),
260+ state = state .set (
261+ self .batch_state_index , (new_hidden_mean , new_hidden_var )
284262 )
285- state = state .set (self .batch_state_index , new_state_data )
286-
287- mean , var = (batch_mean , batch_var )
288263
289- out = jax .vmap (_norm )(x , mean , var , self .weight , self .bias )
290- return out , state
264+ out = jax .vmap (_norm )(x , mean , var , self .weight , self .bias )
265+ return out , state
0 commit comments