
Commit a0b4ce6

Author: OptaxDev
Enable performance gain by batching small update tensors.
By setting vmap_optimization_threshold=2**20, the optimizer batches update tensors smaller than 2**20 elements and orthogonalizes them in parallel, improving accelerator utilization.

PiperOrigin-RevId: 878942519
1 parent 2225b90 commit a0b4ce6
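
For orientation, the new flag is exposed on optax.contrib.muon (and scale_by_muon), so enabling the batching from user code would look roughly like the sketch below; the parameter shapes, learning rate, and stand-in gradients are illustrative and not taken from the commit.

import jax
import jax.numpy as jnp
import optax

# Hypothetical parameter tree; any pytree of 2D weights works with Muon.
params = {'w1': jnp.zeros((512, 512)), 'w2': jnp.zeros((64, 64))}

opt = optax.contrib.muon(
    learning_rate=1e-3,
    # Updates with fewer than 2**20 elements are concatenated and
    # orthogonalized together (the value this commit reports works well).
    vmap_optimization_threshold=2**20,
)
state = opt.init(params)
grads = jax.tree.map(jnp.ones_like, params)  # stand-in gradients
updates, state = opt.update(grads, state, params=params)
params = optax.apply_updates(params, updates)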

2 files changed

Lines changed: 85 additions & 6 deletions


optax/contrib/_muon.py

Lines changed: 58 additions & 6 deletions
@@ -20,13 +20,13 @@
 """
 
 
+import collections
 import functools
 import math
-from typing import Any, Callable, NamedTuple, Optional, Union, Sequence, Literal
+from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, Union
 
 import jax
 import jax.numpy as jnp
-
 from optax._src import alias
 from optax._src import base
 from optax._src import combine

@@ -394,6 +394,7 @@ def scale_by_muon(
         'frobenius', 'spectral', 'aol', 'schatten'
     ] = 'frobenius',
     weight_dimension_numbers: WeightDimNumOrFn | None = None,
+    vmap_optimization_threshold: int = 0,
 ) -> base.GradientTransformation:
   r"""Rescale updates according to the Muon algorithm.
 

@@ -423,6 +424,9 @@ def scale_by_muon(
       params of `MuonDimensionNumbers`s, specifying how to reshape the
       parameters before and after the orthogonalization OR a callable returning
       such a tree. None implies that all parameters are 2D matrices.
+    vmap_optimization_threshold: Parameters smaller than this threshold will be
+      concatenated and orthogonalized together. A value of 2**20 showed good
+      performance.
 
   Returns:
     A `GradientTransformation` object.

@@ -496,10 +500,53 @@ def update_fn(updates, state, params=None):
     else:
       mu_hat = optax.tree.bias_correction(mu, beta, count_inc)
     # Apply Newton-schulz orthogonalization.
-    updates = jax.tree.map(
-        lambda x, dim_num: orthogonalize_via_newton_schulz(
-            x, state.ns_coeffs, ns_steps, preconditioning, eps, dim_num),
-        mu_hat, resolved_weight_dim_nums, is_leaf=_is_weight_dim_nums)
+    # In order to better utilize parallel computation, we batch together small
+    # updates with the same reduction and output shape.
+    # This is done by:
+    # 1. Reshape all updates to 3 dimensions: (batch, reduction, output),
+    #    flattening multiple axes if necessary.
+    # 2. Group all small updates with the same (reduction, output) shape.
+    # 3. Concatenate all updates within each group on the batch dimension, apply
+    #    orthogonalization on the concatenated tensors, and then split them back
+    #    into their original batch sizes.
+    # 4. Reshape all updates to their original shapes.
+    bucket_updates = collections.defaultdict(list)
+
+    def bucket_fn(mu_hat_i: jax.Array, dim_nums: MuonDimensionNumbers):
+      # Reshape to (batch, reduction, output), put in the right bucket, and
+      # return a fetch+inverse reshape function.
+      reshape_fn, inverse_fn = _compute_muon_reshape(mu_hat_i, dim_nums)
+      mu_hat_i = reshape_fn(mu_hat_i)
+      b, i, o = mu_hat_i.shape
+      if b * i * o >= vmap_optimization_threshold:
+        bucket_id = 'no_opt'
+      else:
+        bucket_id = (i, o)
+      pos = len(bucket_updates[bucket_id])
+      bucket_updates[bucket_id].append(mu_hat_i)
+      return lambda: inverse_fn(bucket_updates[bucket_id][pos])
+
+    inverse_fns = jax.tree.map(
+        bucket_fn, mu_hat, resolved_weight_dim_nums, is_leaf=_is_weight_dim_nums
+    )
+    ortho = functools.partial(
+        orthogonalize_via_newton_schulz,
+        ns_coeffs=state.ns_coeffs,
+        ns_steps=ns_steps,
+        preconditioning=preconditioning,
+        eps=eps,
+        dimension_numbers=MuonDimensionNumbers(reduction_axis=1, output_axis=2),
+    )
+    for k, tensors in bucket_updates.items():
+      if k == 'no_opt':
+        bucket_updates[k] = [ortho(mu_hat_i) for mu_hat_i in tensors]
+        continue
+      concatenated_updates = ortho(jnp.concatenate(tensors, axis=0))
+      offsets = [0]
+      for t in tensors[:-1]:
+        offsets.append(offsets[-1] + t.shape[0])
+      bucket_updates[k] = jnp.split(concatenated_updates, offsets[1:], axis=0)
+    updates = jax.tree.map(lambda inverse_fn: inverse_fn(), inverse_fns)
     if adaptive:
       # Scale the orthogonalized updates by the dual norm of the original
       # updates. See https://arxiv.org/abs/2409.20325 for the derivation.

@@ -549,6 +596,7 @@ def muon(
     adam_learning_rate: base.ScalarOrSchedule | None = None,
     muon_weight_dimension_numbers: WeightDimNumOrFn | None = None,
     consistent_rms: jax.typing.ArrayLike | None = None,
+    vmap_optimization_threshold: int = 0,
 ) -> base.GradientTransformation:
   r"""Muon: Momentum Orthogonalized by Newton-schulz.
 

@@ -617,6 +665,9 @@ def muon(
       root mean square (RMS) shape-independent, like AdamW. `0.2` is recommended
       to match AdamW's empirical RMS. See <https://arxiv.org/abs/2502.16982>.
       If `None`, uses width scaling `sqrt(max(1, fan_out / fan_in))`.
+    vmap_optimization_threshold: If set, parameters smaller than this threshold
+      will be concatenated and orthogonalized together. A value of 2**20 showed
+      good speedups on TPU.
 
   Returns:
     The corresponding `GradientTransformation`.

@@ -704,6 +755,7 @@ def muon_weight_dim_nums_fn(params):
          adaptive=adaptive,
          preconditioning=preconditioning,
          weight_dimension_numbers=muon_weight_dim_nums_fn,
+          vmap_optimization_threshold=vmap_optimization_threshold,
      ),
      scale_by_shape(
          weight_dimension_numbers=muon_weight_dim_nums_fn,
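
The four-step comment block added in update_fn is the core of the change: small updates that share a (reduction, output) shape are concatenated on the batch axis, orthogonalized in one call, and split back. Below is a minimal standalone sketch of that grouping pattern; grouped_apply, batched_normalize, the threshold, and the toy shapes are hypothetical stand-ins (the real code reshapes each update to (batch, reduction, output) via _compute_muon_reshape and applies Newton-Schulz orthogonalization rather than a normalization).

import collections
import jax
import jax.numpy as jnp

def grouped_apply(tensors, batched_fn, threshold=2**20):
  # Group equally-shaped small matrices, apply batched_fn once per group,
  # and return results in the original order.
  buckets = collections.defaultdict(list)  # shape (or 'no_opt') -> [(index, matrix)]
  for idx, t in enumerate(tensors):
    key = 'no_opt' if t.size >= threshold else t.shape
    buckets[key].append((idx, t))
  results = [None] * len(tensors)
  for key, items in buckets.items():
    if key == 'no_opt':
      for idx, t in items:  # large tensors are processed one at a time
        results[idx] = batched_fn(t[None])[0]
      continue
    stacked = jnp.stack([t for _, t in items])  # (batch, reduction, output)
    outs = batched_fn(stacked)                  # one batched call per group
    for (idx, _), out in zip(items, outs):
      results[idx] = out
  return results

# Hypothetical per-matrix transform, vmapped over the leading batch axis.
batched_normalize = jax.vmap(lambda m: m / (jnp.linalg.norm(m) + 1e-8))
mats = [jax.random.normal(jax.random.key(i), (50, 50)) for i in range(4)]
outs = grouped_apply(mats, batched_normalize, threshold=2**20)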

optax/contrib/_muon_test.py

Lines changed: 27 additions & 0 deletions
@@ -378,6 +378,33 @@ def test_muon_orthogonalization_modes(self, preconditioning, shape):
     self.assertLess(max_s, 2.0, msg=f'Max singular value {max_s} too high')
     self.assertGreater(min_s, 0.1, msg=f'Min singular value {min_s} too low')
 
+  def test_vmap_optimization(self):
+    """Tests that VMap optimization is not affecting the results."""
+    params = {
+        'w1': jax.random.normal(jax.random.key(1), (100, 100)),
+        'w2': jax.random.normal(jax.random.key(2), (100, 100)),
+        'w3': jax.random.normal(jax.random.key(3), (50, 50)),
+        'w4': jax.random.normal(jax.random.key(4), (50, 50)),
+    }
+
+    results = []
+    for threshold in [0, 5000, 1000000]:
+      opt = _muon.muon(
+          learning_rate=1.0,
+          weight_decay=0.0,
+          vmap_optimization_threshold=threshold,
+      )
+      state = opt.init(params)
+      updates, _ = opt.update(params, state, params=params)
+      results.append(updates)
+
+    test_utils.assert_trees_all_close(
+        results[0], results[1], rtol=1e-2, atol=1e-2
+    )
+    test_utils.assert_trees_all_close(
+        results[0], results[2], rtol=1e-2, atol=1e-2
+    )
+
   def test_aol_numerical_difference(self):
     """Ensures that AOL=True produces different updates than Standard Muon."""
     params = {'w': jnp.eye(8) * 2.0}
