20 | 20 | """
21 | 21 |
22 | 22 |
| 23 | +import collections
23 | 24 | import functools
24 | 25 | import math
25 | | -from typing import Any, Callable, NamedTuple, Optional, Union, Sequence, Literal
| 26 | +from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, Union
26 | 27 |
27 | 28 | import jax |
28 | 29 | import jax.numpy as jnp |
29 | | - |
30 | 30 | from optax._src import alias |
31 | 31 | from optax._src import base |
32 | 32 | from optax._src import combine |
@@ -394,6 +394,7 @@ def scale_by_muon( |
394 | 394 | 'frobenius', 'spectral', 'aol', 'schatten' |
395 | 395 | ] = 'frobenius', |
396 | 396 | weight_dimension_numbers: WeightDimNumOrFn | None = None, |
| 397 | + vmap_optimization_threshold: int = 0, |
397 | 398 | ) -> base.GradientTransformation: |
398 | 399 | r"""Rescale updates according to the Muon algorithm. |
399 | 400 |
@@ -423,6 +424,9 @@ def scale_by_muon( |
423 | 424 | params of `MuonDimensionNumbers`s, specifying how to reshape the |
424 | 425 | parameters before and after the orthogonalization OR a callable returning |
425 | 426 | such a tree. None implies that all parameters are 2D matrices. |
| 427 | + vmap_optimization_threshold: Parameters with fewer elements than this
| 428 | + threshold are concatenated and orthogonalized together. A value of 2**20
| 429 | + showed good performance.
426 | 430 |
427 | 431 | Returns: |
428 | 432 | A `GradientTransformation` object. |
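For intuition on the new `vmap_optimization_threshold` flag: the comparison is against a parameter's total element count after it has been reshaped to `(batch, reduction, output)`. A rough sketch of which parameters would take the batched path (the shapes below are made up for illustration and do not come from this diff):

```python
import numpy as np

threshold = 2**20  # the value reported to work well

# Hypothetical parameter shapes, for illustration only.
shapes = {'proj_small': (512, 512), 'proj_large': (2048, 2048)}
for name, shape in shapes.items():
  n = int(np.prod(shape))
  route = 'batched with other small params' if n < threshold else 'orthogonalized on its own'
  print(f'{name}: {n} elements -> {route}')
# proj_small: 262144 elements -> batched with other small params
# proj_large: 4194304 elements -> orthogonalized on its own
```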
@@ -496,10 +500,53 @@ def update_fn(updates, state, params=None): |
496 | 500 | else: |
497 | 501 | mu_hat = optax.tree.bias_correction(mu, beta, count_inc) |
498 | 502 | # Apply Newton-schulz orthogonalization. |
499 | | - updates = jax.tree.map( |
500 | | - lambda x, dim_num: orthogonalize_via_newton_schulz( |
501 | | - x, state.ns_coeffs, ns_steps, preconditioning, eps, dim_num), |
502 | | - mu_hat, resolved_weight_dim_nums, is_leaf=_is_weight_dim_nums) |
| 503 | + # In order to better utilize parallel computation, we batch together small |
| 504 | + # updates with the same reduction and output shape. |
| 505 | + # This is done by: |
| 506 | + # 1. Reshape all updates to 3 dimensions: (batch, reduction, output), |
| 507 | + # flattening multiple axes if necessary. |
| 508 | + # 2. Group all small updates with the same (reduction, output) shape. |
| 509 | + # 3. Concatenate all updates within each group on the batch dimension, apply |
| 510 | + # orthogonalization on the concatenated tensors, and then split them back |
| 511 | + # into their original batch sizes. |
| 512 | + # 4. Reshape all updates to their original shapes. |
| 513 | + bucket_updates = collections.defaultdict(list) |
| 514 | +
| 515 | + def bucket_fn(mu_hat_i: jax.Array, dim_nums: MuonDimensionNumbers): |
| 516 | + # Reshape to (batch, reduction, output), put it in the right bucket, and
| 517 | + # return a function that fetches the result and undoes the reshape.
| 518 | + reshape_fn, inverse_fn = _compute_muon_reshape(mu_hat_i, dim_nums) |
| 519 | + mu_hat_i = reshape_fn(mu_hat_i) |
| 520 | + b, i, o = mu_hat_i.shape |
| 521 | + if b * i * o >= vmap_optimization_threshold: |
| 522 | + bucket_id = 'no_opt' |
| 523 | + else: |
| 524 | + bucket_id = (i, o) |
| 525 | + pos = len(bucket_updates[bucket_id]) |
| 526 | + bucket_updates[bucket_id].append(mu_hat_i) |
| 527 | + return lambda: inverse_fn(bucket_updates[bucket_id][pos]) |
| 528 | + |
| 529 | + inverse_fns = jax.tree.map( |
| 530 | + bucket_fn, mu_hat, resolved_weight_dim_nums, is_leaf=_is_weight_dim_nums |
| 531 | + ) |
| 532 | + ortho = functools.partial( |
| 533 | + orthogonalize_via_newton_schulz, |
| 534 | + ns_coeffs=state.ns_coeffs, |
| 535 | + ns_steps=ns_steps, |
| 536 | + preconditioning=preconditioning, |
| 537 | + eps=eps, |
| 538 | + dimension_numbers=MuonDimensionNumbers(reduction_axis=1, output_axis=2), |
| 539 | + ) |
| 540 | + for k, tensors in bucket_updates.items(): |
| 541 | + if k == 'no_opt': |
| 542 | + bucket_updates[k] = [ortho(mu_hat_i) for mu_hat_i in tensors] |
| 543 | + continue |
| 544 | + concatenated_updates = ortho(jnp.concatenate(tensors, axis=0)) |
| 545 | + offsets = [0] |
| 546 | + for t in tensors[:-1]: |
| 547 | + offsets.append(offsets[-1] + t.shape[0]) |
| 548 | + bucket_updates[k] = jnp.split(concatenated_updates, offsets[1:], axis=0) |
| 549 | + updates = jax.tree.map(lambda inverse_fn: inverse_fn(), inverse_fns) |
503 | 550 | if adaptive: |
504 | 551 | # Scale the orthogonalized updates by the dual norm of the original |
505 | 552 | # updates. See https://arxiv.org/abs/2409.20325 for the derivation. |
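The bucketing above packs several small `(batch, reduction, output)` stacks with matching trailing shapes into a single Newton-Schulz call. Below is a minimal, self-contained sketch of the same concat/transform/split bookkeeping, with a dummy `fake_orthogonalize` standing in for the real Newton-Schulz routine (the helper names and shapes are illustrative, not part of this change):

```python
import collections
import jax.numpy as jnp


def fake_orthogonalize(x):
  # Stand-in for Newton-Schulz; acts on a (batch, reduction, output) stack.
  return x / (jnp.linalg.norm(x, axis=(1, 2), keepdims=True) + 1e-7)


def batched_transform(tensors, threshold):
  """Concatenates small (b, r, o) stacks with matching (r, o), transforms once, splits back."""
  buckets = collections.defaultdict(list)
  order = []  # (bucket_id, position) per input, to restore the original order
  for t in tensors:
    b, r, o = t.shape
    bucket_id = 'no_opt' if b * r * o >= threshold else (r, o)
    order.append((bucket_id, len(buckets[bucket_id])))
    buckets[bucket_id].append(t)

  results = {}
  for bucket_id, ts in buckets.items():
    if bucket_id == 'no_opt':
      # Large tensors are transformed individually, as before.
      results[bucket_id] = [fake_orthogonalize(t) for t in ts]
      continue
    # Small tensors sharing (r, o) are stacked, transformed once, and split back.
    stacked = fake_orthogonalize(jnp.concatenate(ts, axis=0))
    offsets, total = [], 0
    for t in ts[:-1]:
      total += t.shape[0]
      offsets.append(total)
    results[bucket_id] = jnp.split(stacked, offsets, axis=0)
  return [results[bucket_id][pos] for bucket_id, pos in order]


xs = [jnp.ones((1, 4, 8)), jnp.ones((2, 4, 8)), jnp.ones((1, 16, 16))]
ys = batched_transform(xs, threshold=2**20)
print([y.shape for y in ys])  # [(1, 4, 8), (2, 4, 8), (1, 16, 16)]
```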
@@ -549,6 +596,7 @@ def muon( |
549 | 596 | adam_learning_rate: base.ScalarOrSchedule | None = None, |
550 | 597 | muon_weight_dimension_numbers: WeightDimNumOrFn | None = None, |
551 | 598 | consistent_rms: jax.typing.ArrayLike | None = None, |
| 599 | + vmap_optimization_threshold: int = 0, |
552 | 600 | ) -> base.GradientTransformation: |
553 | 601 | r"""Muon: Momentum Orthogonalized by Newton-schulz. |
554 | 602 |
@@ -617,6 +665,9 @@ def muon( |
617 | 665 | root mean square (RMS) shape-independent, like AdamW. `0.2` is recommended |
618 | 666 | to match AdamW's empirical RMS. See <https://arxiv.org/abs/2502.16982>. |
619 | 667 | If `None`, uses width scaling `sqrt(max(1, fan_out / fan_in))`. |
| 668 | + vmap_optimization_threshold: Parameters with fewer elements than this
| 669 | + threshold are concatenated and orthogonalized together; the default of 0
| 670 | + disables this batching. A value of 2**20 showed good speedups on TPU.
620 | 671 |
|
621 | 672 | Returns: |
622 | 673 | The corresponding `GradientTransformation`. |
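A minimal usage sketch of the new flag through the `muon` factory. The import path (`optax.contrib.muon` here) and the parameter shapes are assumptions for illustration; only `vmap_optimization_threshold` itself comes from this change:

```python
import jax
import jax.numpy as jnp
import optax

# Hypothetical 2D parameters; shapes are for illustration only.
params = {
    'dense1': jnp.ones((256, 1024)),
    'dense2': jnp.ones((1024, 256)),
}

# Assumes muon is exposed as optax.contrib.muon; adjust to your optax version.
opt = optax.contrib.muon(
    learning_rate=0.02,
    vmap_optimization_threshold=2**20,  # batch small matrices into one NS call
)
state = opt.init(params)

grads = jax.tree.map(jnp.ones_like, params)  # dummy gradients
updates, state = opt.update(grads, state, params)
params = optax.apply_updates(params, updates)
```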
@@ -704,6 +755,7 @@ def muon_weight_dim_nums_fn(params): |
704 | 755 | adaptive=adaptive, |
705 | 756 | preconditioning=preconditioning, |
706 | 757 | weight_dimension_numbers=muon_weight_dim_nums_fn, |
| 758 | + vmap_optimization_threshold=vmap_optimization_threshold, |
707 | 759 | ), |
708 | 760 | scale_by_shape( |
709 | 761 | weight_dimension_numbers=muon_weight_dim_nums_fn, |