Skip to content

Commit 7df9c09

Browse files
committed
Address review feedback: norm restoration, validation, tests
- Add norm restoration step after row-wise normalization to match reference NorMuon algorithm (paper Algorithm 1) - Fix ns_coeffs validation readability (double-negative → positive) - Tighten NorMuon vs Muon benchmark comparison (5x → 2x) - Parameterize scale_by_normuon tests across matrix shapes
1 parent 536a7ae commit 7df9c09

3 files changed

Lines changed: 17 additions & 9 deletions

File tree

optax/contrib/_normuon.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def init_fn(params):
113113
f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
114114
)
115115
if ns_coeffs_.ndim == 2:
116-
if not ns_coeffs_.shape[0] <= ns_steps:
116+
if ns_coeffs_.shape[0] > ns_steps:
117117
raise ValueError(f'Not enough coeffs to perform {ns_steps} steps')
118118
ns_coeffs_ = ns_coeffs_[-ns_steps:]
119119

@@ -159,7 +159,12 @@ def _update_nu(o, nu_prev):
159159

160160
# Normalize and apply adaptive scaling.
161161
def _normalize(o, nu_new):
162+
# Restore original norm after row-wise normalization (paper Algorithm 1).
163+
o_norm = jnp.linalg.norm(o, ord='fro')
162164
o_hat = o / (jnp.sqrt(nu_new[..., None]) + eps)
165+
o_hat_norm = jnp.linalg.norm(o_hat, ord='fro')
166+
o_hat = o_hat * (o_norm / (o_hat_norm + eps))
167+
# Apply adaptive scaling.
163168
m_n = math.prod(o.shape[-2:]) if o.ndim >= 2 else o.shape[-1]
164169
frob = jnp.linalg.norm(o_hat, ord='fro')
165170
scale = normuon_scale * jnp.sqrt(m_n) / (frob + eps)

optax/contrib/_normuon_benchmark_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,11 +109,11 @@ def test_normuon_vs_muon_convergence(self):
109109
'NorMuon did not converge.',
110110
)
111111

112-
# NorMuon should not be more than 5x worse than Muon.
112+
# NorMuon should not be more than 2x worse than Muon.
113113
self.assertLess(
114114
normuon_losses[-1],
115-
5.0 * muon_losses[-1],
116-
'NorMuon final loss is more than 5x worse than Muon.',
115+
2.0 * muon_losses[-1],
116+
'NorMuon final loss is more than 2x worse than Muon.',
117117
)
118118

119119
def test_normuon_no_side_effects(self):

optax/contrib/_normuon_test.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,18 +73,21 @@ def loss_fn(params):
7373
final_loss = loss_fn(params)
7474
self.assertLess(final_loss, initial_loss * 0.5)
7575

76-
def test_scale_by_normuon_direct(self):
76+
@parameterized.product(
77+
shape=[(6, 4), (8, 8), (3, 10)],
78+
)
79+
def test_scale_by_normuon_direct(self, shape):
7780
"""Test scale_by_normuon directly on a 2D input."""
7881
key = jax.random.key(3)
79-
params = jax.random.normal(key, (6, 4))
82+
params = jax.random.normal(key, shape)
8083
opt = _normuon.scale_by_normuon()
8184
state = opt.init(params)
82-
grad = jax.random.normal(key, (6, 4))
85+
grad = jax.random.normal(key, shape)
8386
updates, new_state = opt.update(grad, state)
84-
self.assertEqual(updates.shape, (6, 4))
87+
self.assertEqual(updates.shape, shape)
8588
self.assertTrue(jnp.all(jnp.isfinite(updates)))
8689
# Check that nu state has the right shape (rows only).
87-
self.assertEqual(new_state.nu.shape, (6,))
90+
self.assertEqual(new_state.nu.shape, (shape[0],))
8891

8992
@parameterized.named_parameters(
9093
('small', 1e-7),

0 commit comments

Comments
 (0)