Commit a4080cf

Add NorMuon optimizer implementation and related tests
This commit introduces the NorMuon optimizer, which extends the Muon algorithm with neuron-wise normalization of update magnitudes. The core functionality lives in `_normuon.py`, with a dedicated test suite in `_normuon_test.py` to validate its behavior. References to the new optimizer have also been added to the documentation and to `optax/contrib/__init__.py`.

Key changes:
- New optimizer: NorMuon with associated scaling and state management.
- Tests for NorMuon functionality and behavior.
- Documentation updates to include NorMuon in the API reference.

PiperOrigin-RevId: [insert-rev-id-here]
1 parent 2225b90 commit a4080cf
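
The commit message describes NorMuon as Muon plus neuron-wise normalization of update magnitudes, but the diff of `_normuon.py` is not shown on this page. The sketch below is therefore only an illustration of what row-wise normalization could look like, not the code in this commit. It assumes each row of a 2-D update corresponds to one neuron and keeps a running mean-square per row. The helper name `normalize_rows` and the parameters `beta2` and `eps` are hypothetical, and the Muon orthogonalization step and any learning-rate rescaling are left out.

```python
import math


def normalize_rows(update, second_moment, beta2=0.95, eps=1e-8):
    """Row-wise (per-neuron) normalization of a 2-D update.

    Hypothetical helper for illustration only; not the code in `_normuon.py`.
    `update` is a list of rows, one row per output neuron (an assumption).
    `second_moment` holds one running mean-square value per row.
    """
    normalized = []
    new_moment = []
    for row, v_prev in zip(update, second_moment):
        # Mean of squared entries for this row.
        mean_sq = sum(x * x for x in row) / len(row)
        # Exponential moving average of the per-row mean square.
        v = beta2 * v_prev + (1.0 - beta2) * mean_sq
        new_moment.append(v)
        # Divide the row by the root of its running mean square.
        normalized.append([x / (math.sqrt(v) + eps) for x in row])
    return normalized, new_moment


# Two rows whose magnitudes differ by a factor of 100.
update = [[3.0, 4.0], [0.03, 0.04]]
state = [0.0, 0.0]
# beta2=0.0 uses only the current step, which makes the effect easy to see.
normalized, state = normalize_rows(update, state, beta2=0.0)
```

With `beta2=0.0`, both rows come out with roughly the same root-mean-square magnitude even though the inputs differ by two orders of magnitude, which is the intuition behind normalizing per neuron rather than per matrix.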

5 files changed

Lines changed: 498 additions & 0 deletions

docs/api/contrib.rst

Lines changed: 8 additions & 0 deletions
@@ -34,6 +34,8 @@ are not supported by the main library.
 MomoAdamState
 muon
 MuonState
+normuon
+NorMuonState
 prodigy
 ProdigyState
 sam
@@ -127,6 +129,12 @@ Muon
 .. autofunction:: scale_by_muon
 .. autoclass:: MuonState
 
+NorMuon
+~~~~~~~
+.. autofunction:: normuon
+.. autofunction:: scale_by_normuon
+.. autoclass:: NorMuonState
+
 Prodigy
 ~~~~~~~
 .. autofunction:: prodigy

optax/contrib/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -54,6 +54,9 @@
 from optax.contrib._muon import MuonDimensionNumbers
 from optax.contrib._muon import MuonState
 from optax.contrib._muon import scale_by_muon
+from optax.contrib._normuon import normuon
+from optax.contrib._normuon import NorMuonState
+from optax.contrib._normuon import scale_by_normuon
 from optax.contrib._privacy import differentially_private_aggregate
 from optax.contrib._privacy import DifferentiallyPrivateAggregateState
 from optax.contrib._privacy import dpsgd

optax/contrib/_common_test.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
 {'opt_name': 'momo', 'opt_kwargs': {'learning_rate': 1e-1}},
 {'opt_name': 'momo_adam', 'opt_kwargs': {'learning_rate': 1e-1}},
 {'opt_name': 'muon', 'opt_kwargs': {'learning_rate': 1e-2}},
+{'opt_name': 'normuon', 'opt_kwargs': {'learning_rate': 1e-2}},
 {'opt_name': 'prodigy', 'opt_kwargs': {'learning_rate': 1e-1}},
 {
     'opt_name': 'schedule_free_sgd',
