
Commit 26ba03c

Ryan McKenna authored and OptaxDev committed

Update microbatching to support size-0 batches. This change is important for JAX privacy, where the batch size is random and possibly empty.
PiperOrigin-RevId: 889281155
1 parent 36d72ac commit 26ba03c
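
For context, a minimal sketch of the behavior this commit enables, mirroring the tests added below (assuming the module is importable as optax.microbatching):

import jax.numpy as jnp
from optax import microbatching

# Sum per-example outputs in microbatches of size 2.
m_fun = microbatching.microbatch(
    lambda x: jnp.sum(x, axis=0),
    argnums=0,
    microbatch_size=2,
    accumulator=microbatching.AccumulationType.SUM,
)

# An empty batch (leading axis of size 0) is now valid input: the loop
# body never runs, and the accumulator's initial value is finalized.
print(m_fun(jnp.zeros((0, 4))))  # [0. 0. 0. 0.]

Per the tests below, MEAN finalizes an empty batch to NaN (a 0/0 average), and CONCAT returns an output with a zero-length leading axis.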

2 files changed: 87 additions & 11 deletions


optax/microbatching/_microbatching.py

Lines changed: 14 additions & 11 deletions
@@ -183,8 +183,6 @@ def _sum() -> Accumulator:
 
 def _mean(num_microbatches: int) -> Accumulator:
   """An Accumulator that computes the mean of microbatched outputs."""
-  if num_microbatches <= 0:
-    raise ValueError(f'{num_microbatches=} must be positive.')
   return _lift(
       Accumulator(
           init=_with_floating_check(jnp.zeros_like),
@@ -230,8 +228,6 @@ def _get_out_sharding(x):
 
 def _concat(num_microbatches: int) -> Accumulator:
   """An Accumulator that concatenates microbatched outputs along the axis 0."""
-  if num_microbatches <= 0:
-    raise ValueError(f'{num_microbatches=} must be positive.')
 
   def init(value):
     shape = (num_microbatches,) + value.shape
@@ -321,6 +317,15 @@ def _reshape_all_args(
   return tuple(new_args), new_kwargs, tuple(batch_sizes)[0]
 
 
+def _take_fn(index: int, axis: int) -> Callable[[jax.Array], jax.Array]:
+  """Returns a function that takes the `index`-th element along the `axis`."""
+  def fun(x):
+    if x.shape[axis] == 0:  # jnp.take doesn't work with zero axis size.
+      return jnp.empty_like(x, shape=x.shape[:axis] + x.shape[axis + 1:])
+    return jnp.take(x, indices=index, axis=axis)
+  return fun
+
+
 def microbatch(
     fun: Function,
     argnums: int | Sequence[int],
@@ -421,13 +426,9 @@ def f(index):
     input_args = list(reshaped_args)
     input_kwargs = dict(reshaped_kwargs)
     for i, ax in zip(argnums, in_axes):
-      input_args[i] = jax.tree.map(
-          functools.partial(jnp.take, indices=index, axis=ax), input_args[i]
-      )
+      input_args[i] = jax.tree.map(_take_fn(index, ax), input_args[i])
     for i, ax in zip(argnames, in_axes[len(argnums) :]):
-      input_kwargs[i] = jax.tree.map(
-          functools.partial(jnp.take, indices=index, axis=ax), input_kwargs[i]
-      )
+      input_kwargs[i] = jax.tree.map(_take_fn(index, ax), input_kwargs[i])
     return fun(*input_args, **input_kwargs)
 
   def body_fun(index, carry):
@@ -436,8 +437,10 @@ def body_fun(index, carry):
   early_stop = num_real_microbatches is not None
   loop_bound = num_real_microbatches if early_stop else num_microbatches
   init_carry = accumulator_.init(jax.eval_shape(f, 0))
-  answer = jax.lax.fori_loop(0, loop_bound, body_fun, init_carry)
+  if num_microbatches == 0:
+    return accumulator_.finalize(init_carry)
 
+  answer = jax.lax.fori_loop(0, loop_bound, body_fun, init_carry)
   return accumulator_.finalize(answer)
 
  return microbatched_fun
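
The subtle case handled by the new _take_fn is a zero-size batch axis: per the comment above, jnp.take doesn't work when the axis has length 0, so the helper short-circuits and returns an empty array with that axis dropped. A standalone sketch of the same pattern for axis 0, in plain JAX (the helper name here is illustrative):

import jax.numpy as jnp

def take_along_axis0(index):
  # Mirrors the _take_fn pattern for axis=0: slice out microbatch
  # `index`, or return an empty slice when the batch axis is empty.
  def fun(x):
    if x.shape[0] == 0:  # shapes are static, so this check is jit-safe
      return jnp.empty_like(x, shape=x.shape[1:])
    return jnp.take(x, indices=index, axis=0)
  return fun

x = jnp.arange(12.0).reshape(3, 4)
print(take_along_axis0(1)(x))                        # row 1 of x
print(take_along_axis0(0)(jnp.zeros((0, 4))).shape)  # (4,)

Because the empty-batch slice keeps the per-microbatch shape, jax.eval_shape(f, 0) in microbatch can still build a well-formed accumulator even when the fori_loop never executes.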

optax/microbatching/_microbatching_test.py

Lines changed: 73 additions & 0 deletions
@@ -333,6 +333,79 @@ def test_micro_grad_early_stopping(self):
     result, _ = grad_fn(1.0, jnp.ones(16))
     test_utils.assert_trees_all_close(result, 12.0)
 
+  def test_zero_batch_size_microbatch(self):
+    def fun(x):
+      return jnp.sum(x, axis=0)
+
+    m_fun = microbatching.microbatch(
+        fun,
+        argnums=0,
+        microbatch_size=2,
+        accumulator=microbatching.AccumulationType.SUM,
+    )
+    res = m_fun(jnp.zeros((0, 4)))
+    self.assertEqual(res.shape, (4,))
+    test_utils.assert_trees_all_close(res, jnp.zeros(4))
+
+  def test_zero_batch_size_microbatch_mean(self):
+    def fun(x):
+      return jnp.mean(x, axis=0)
+
+    m_fun = microbatching.microbatch(
+        fun,
+        argnums=0,
+        microbatch_size=2,
+        accumulator=microbatching.AccumulationType.MEAN,
+    )
+    res = m_fun(jnp.zeros((0, 4)))
+    self.assertEqual(res.shape, (4,))
+    self.assertTrue(jnp.all(jnp.isnan(res)))
+
+  def test_zero_batch_size_microbatch_concat(self):
+    def fun(x):
+      return x * 2
+
+    m_fun = microbatching.microbatch(
+        fun,
+        argnums=0,
+        microbatch_size=2,
+        accumulator=microbatching.AccumulationType.CONCAT,
+    )
+    res = m_fun(jnp.zeros((0, 4)))
+    self.assertEqual(res.shape, (0, 4))
+
+  def test_zero_batch_size_micro_vmap(self):
+    m_vmap = microbatching.micro_vmap(lambda x: x * 2, microbatch_size=2)
+    res = m_vmap(jnp.zeros((0, 4)))
+    self.assertEqual(res.shape, (0, 4))
+
+  def test_zero_batch_size_micro_grad(self):
+    def mean_squared_loss(params, features, targets):
+      preds = features @ params
+      diff = preds - targets
+      return 0.5 * jnp.mean(diff**2)
+
+    grad_fn = microbatching.micro_grad(
+        mean_squared_loss,
+        argnums=0,
+        batch_argnums=(1, 2),
+        transform_fn=lambda x: (x, x**2),
+        metrics_fn=jnp.linalg.norm,
+        keep_batch_dim=True,
+        microbatch_size=1,
+    )
+    params = jnp.zeros(1)
+    features = jnp.zeros((0, 1))
+    targets = jnp.zeros((0,))
+    (grads, squared_grads), aux = grad_fn(params, features, targets)
+
+    self.assertEqual(grads.shape, (1,))
+    test_utils.assert_trees_all_close(grads, jnp.zeros(1))
+    self.assertEqual(squared_grads.shape, (1,))
+    test_utils.assert_trees_all_close(squared_grads, jnp.zeros(1))
+    self.assertEqual(aux.values.shape, (0,))
+    self.assertEqual(aux.metrics.shape, (0,))
+
 
 if __name__ == '__main__':
   absltest.main()
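
Finally, connecting back to the commit message: in differentially private training with Poisson subsampling, each example joins a step's batch independently, so the realized batch size is random and can be zero. A rough sketch of how such an empty batch arises (hypothetical code, not from this repo; the boolean-mask step runs eagerly because the selected size is data-dependent):

import jax
import jax.numpy as jnp

data = jnp.ones((100, 4))
key = jax.random.key(0)

# Poisson subsampling: keep each example independently with probability q.
mask = jax.random.bernoulli(key, p=0.01, shape=(data.shape[0],))
batch = data[mask]  # shape (k, 4) for random k, possibly (0, 4)

# Any microbatched function applied to `batch` must therefore tolerate
# a leading axis of size 0, which is exactly what this commit adds.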
