-
Notifications
You must be signed in to change notification settings - Fork 24
Expand file tree
/
Copy pathtest_orthogonalized_optimizer.py
More file actions
491 lines (406 loc) · 18 KB
/
test_orthogonalized_optimizer.py
File metadata and controls
491 lines (406 loc) · 18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from absl import flags, logging
from absl.testing import absltest, parameterized
from emerging_optimizers.orthogonalized_optimizers import mop, muon, muon_hyperball, polargrad, scion
from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer
# Command-line flags shared by every test case in this module.
flags.DEFINE_enum(
    name="device",
    default="cpu",
    enum_values=["cpu", "cuda"],
    help="Device to run tests on",
)
flags.DEFINE_integer(
    name="seed",
    default=None,
    help="Random seed for reproducible tests",
)
FLAGS = flags.FLAGS
def setUpModule() -> None:
    """Seed torch's random number generators once for the module when --seed is given."""
    seed = FLAGS.seed
    if seed is None:
        return
    logging.info("Setting random seed to %d", seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
class OrthogonalizedOptimizerTest(parameterized.TestCase):
    """Tests for the base OrthogonalizedOptimizer class."""

    def setUp(self):
        self.device = FLAGS.device

    @parameterized.product(
        weight_decay_method=["decoupled", "independent", "l2"],
        shape=[(5, 7), (33, 65), (127, 257)],
        nesterov=[True, False],
        fp32_matmul_prec=["highest", "medium", "low"],
    )
    def test_smoke(self, weight_decay_method, shape, nesterov, fp32_matmul_prec) -> None:
        """Smoke test one step for each weight decay method, nesterov setting and fp32 matmul precision."""
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        orthogonalized_opt = OrthogonalizedOptimizer(
            [test_param],
            lr=2,
            momentum=0,
            weight_decay=0.5,
            nesterov=nesterov,
            weight_decay_method=weight_decay_method,
            fp32_matmul_prec=fp32_matmul_prec,
        )
        orthogonalized_opt.step()

    @parameterized.parameters(
        {"shape": (5, 7)},
        {"shape": (33, 65)},
        {"shape": (127, 257)},
    )
    def test_orthogonalized_optimizer_core_matches_sgd(self, shape) -> None:
        """Test that OrthogonalizedOptimizer matches SGD when orthogonalization is disabled."""
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        ref_param = nn.Parameter(torch.empty_like(test_param))
        ref_param.data.copy_(test_param.data)
        test_param.grad = torch.randint_like(test_param, -5, 5)
        ref_param.grad = test_param.grad.clone()

        # lr * weight_decay == 1 here. SGD's coupled L2 update p - lr * (g + wd * p) therefore reduces to
        # -lr * g. Decoupled decay gives the same result if it first scales p by (1 - lr * wd).
        orthogonalized_opt = OrthogonalizedOptimizer(
            [test_param],
            lr=2,
            momentum=0,
            nesterov=False,
            weight_decay=0.5,
            weight_decay_method="decoupled",
            fp32_matmul_prec="highest",
        )
        sgd_opt = torch.optim.SGD(
            [ref_param],
            lr=2,
            momentum=0,
            nesterov=False,
            weight_decay=0.5,
        )

        orthogonalized_opt.step()
        sgd_opt.step()

        torch.testing.assert_close(
            test_param.data,
            ref_param.data,
            atol=0,
            rtol=0,
        )

    @parameterized.parameters(
        {"shape": (5, 7)},
        {"shape": (33, 65)},
        {"shape": (127, 257)},
    )
    def test_orthogonalized_optimizer_core_matches_sgd_with_momentum(self, shape) -> None:
        """Test that OrthogonalizedOptimizer matches SGD with momentum over multiple steps."""
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        ref_param = nn.Parameter(torch.empty_like(test_param))
        ref_param.data.copy_(test_param.data)

        # OrthogonalizedOptimizer uses EMA momentum, while torch SGD uses standard momentum. The lr and
        # momentum values are chosen so the two match. Assuming the EMA form buf = m * buf + (1 - m) * g,
        # with m = 0.5 the EMA buffer is exactly half of SGD's buf = m * buf + g. Hence lr=2.0 here
        # corresponds to lr=1.0 for SGD.
        orthogonalized_opt = OrthogonalizedOptimizer(
            [test_param],
            lr=2.0,
            momentum=0.5,
            nesterov=False,
            weight_decay=0.0,
            weight_decay_method="l2",
            fp32_matmul_prec="highest",
        )
        sgd_opt = torch.optim.SGD(
            [ref_param],
            lr=1.0,
            momentum=0.5,
            nesterov=False,
            weight_decay=0.0,
        )

        for _ in range(5):
            test_param.grad = torch.randint_like(test_param, -5, 5)
            ref_param.grad = test_param.grad.clone()
            orthogonalized_opt.step()
            sgd_opt.step()

        torch.testing.assert_close(
            test_param.data,
            ref_param.data,
            atol=0,
            rtol=0,
        )

    @parameterized.parameters(
        {"skip_non_grad_params": False},
        {"skip_non_grad_params": True},
    )
    def test_init_group_skip_non_grad_params(self, skip_non_grad_params) -> None:
        """Test _init_group with skip_non_grad_params flag.

        A parameter with a gradient always gets a momentum buffer of its own shape. A parameter without a
        gradient gets one only when skip_non_grad_params is False.
        """
        param_with_grad = nn.Parameter(torch.randn(5, 7, device=self.device))
        param_without_grad = nn.Parameter(torch.randn(5, 7, device=self.device))
        param_with_grad.grad = torch.randn_like(param_with_grad)

        opt = OrthogonalizedOptimizer(
            [param_with_grad, param_without_grad],
            lr=1.0,
            momentum=0.0,
            nesterov=False,
            weight_decay=0.0,
            weight_decay_method="l2",
            fp32_matmul_prec="highest",
        )
        opt._init_group(opt.param_groups[0], skip_non_grad_params=skip_non_grad_params)

        self.assertIn("momentum_buffer", opt.state[param_with_grad])
        self.assertEqual(opt.state[param_with_grad]["momentum_buffer"].shape, param_with_grad.data.shape)
        self.assertEqual("momentum_buffer" in opt.state[param_without_grad], not skip_non_grad_params)

    def test_split_fn_interleaved(self) -> None:
        """Test a three way interleaved split function.

        With zero-initialized weights and lr=-1, the updated param should match the orthogonalized grads.
        """
        test_param = torch.zeros((6, 7), dtype=torch.float32, device=self.device)
        test_param.grad = torch.empty_like(test_param.data)
        # Row i of the gradient is filled with the value i + 1.
        for i in range(test_param.shape[0]):
            test_param.grad[i] = i + 1

        def dummy_interleaved_split_orth_fn(x: torch.Tensor) -> torch.Tensor:
            # Interleaved split: chunk k holds rows k, k + 3, k + 6, ... of x.
            orth_grad_list = [x[k::3] for k in range(3)]
            # Replace each chunk with its max value and concatenate the chunks back together.
            return torch.cat([torch.empty_like(chunk).fill_(chunk.max()) for chunk in orth_grad_list], dim=0)

        orthogonalized_opt = OrthogonalizedOptimizer(
            [test_param],
            lr=-1,
            momentum=0,
            nesterov=False,
            weight_decay=0.0,
            weight_decay_method="l2",
            fp32_matmul_prec="highest",
            scaled_orthogonalize_fn=dummy_interleaved_split_orth_fn,
        )
        orthogonalized_opt.step()

        # The dummy function must have changed the update. Otherwise the comparison below could pass
        # without the custom function ever being applied.
        self.assertFalse(torch.allclose(test_param, test_param.grad))
        ref_out = dummy_interleaved_split_orth_fn(test_param.grad)
        torch.testing.assert_close(
            test_param,
            ref_out,
            atol=0,
            rtol=0,
        )

    def test_non_2d_param_raises_value_error(self) -> None:
        """Test that OrthogonalizedOptimizer raises ValueError for non-2D parameters."""
        test_param = nn.Parameter(torch.randn(8, dtype=torch.float32, device=self.device))
        test_param.grad = torch.randn_like(test_param)

        # OrthogonalizedOptimizer has no defaults for lr/momentum/weight_decay/nesterov/weight_decay_method/fp32_matmul_prec
        opt = OrthogonalizedOptimizer(
            [test_param],
            lr=1.0,
            momentum=0.0,
            weight_decay=0.0,
            nesterov=False,
            weight_decay_method="l2",
            fp32_matmul_prec="highest",
        )
        with self.assertRaisesRegex(ValueError, "Only 2D"):
            opt.step()
class MuonTest(parameterized.TestCase):
    """Tests for the Muon optimizer class."""

    def setUp(self):
        self.device = FLAGS.device

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        weight_decay_method=["decoupled", "independent", "l2"],
        nesterov=[True, False],
    )
    def test_smoke(self, shape, weight_decay_method, nesterov) -> None:
        """Smoke test Muon optimizer.

        Most functionality of muon is tested in muon_utils. This test only ensures everything runs through
        the optimizer class.
        """
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        muon_opt = muon.Muon([test_param], weight_decay_method=weight_decay_method, nesterov=nesterov)
        muon_opt.step()

    def test_use_syrk_match_without_syrk(self) -> None:
        """Test that a Muon step with use_syrk=True matches the same step with use_syrk=False."""
        shape = (32, 32)
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        # Start the reference from an exact copy of test_param, as the other comparison tests do.
        ref_param = nn.Parameter(torch.empty_like(test_param))
        ref_param.data.copy_(test_param.data)
        test_param.grad = torch.randint_like(test_param, -5, 5)
        ref_param.grad = test_param.grad.clone()

        muon_opt = muon.Muon([test_param], num_ns_steps=1, coefficient_type="simple", use_syrk=True)
        ref_muon_opt = muon.Muon([ref_param], num_ns_steps=1, coefficient_type="simple", use_syrk=False)
        muon_opt.step()
        ref_muon_opt.step()

        torch.testing.assert_close(
            test_param.data,
            ref_param.data,
        )

    def test_use_independent_wd(self) -> None:
        """Test that weight_decay_method="independent" decouples weight decay from the learning rate.

        With lr=0 no gradient update is applied. Independent weight decay should still scale the parameter
        by exactly (1 - weight_decay).
        """
        shape = (32, 32)
        weight_decay = 0.25

        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        # Computed before the step, so this is a snapshot of the decayed initial values.
        expected_param = (1 - weight_decay) * test_param.data

        muon_opt_indep = muon.Muon(
            [test_param],
            lr=0.0,  # Zero learning rate
            weight_decay=weight_decay,
            weight_decay_method="independent",
            momentum=0.0,
        )
        muon_opt_indep.step()

        torch.testing.assert_close(
            test_param,
            expected_param,
            atol=0,
            rtol=0,
        )

    def test_zero_num_ns_steps_raises_value_error(self) -> None:
        """Test that Muon raises ValueError for num_ns_steps < 1."""
        test_param = nn.Parameter(torch.randn(5, 7, dtype=torch.float32, device=self.device))
        with self.assertRaisesRegex(ValueError, "num_ns_steps must be at least 1"):
            muon.Muon([test_param], num_ns_steps=0)

    def test_invalid_scale_mode_raises_value_error(self) -> None:
        """Test that get_muon_scale_factor raises ValueError for invalid mode."""
        with self.assertRaisesRegex(ValueError, "Invalid mode.*invalid_mode"):
            muon.get_muon_scale_factor(10, 10, "invalid_mode")
class ScionTest(parameterized.TestCase):
    """Tests for the Scion optimizer."""

    def setUp(self):
        self.device = FLAGS.device

    @parameterized.parameters(
        {"shape": (5, 7)},
        {"shape": (33, 65)},
        {"shape": (127, 257)},
    )
    def test_smoke(self, shape) -> None:
        """Run a single Scion step on an integer-valued random parameter and gradient."""
        param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        param.grad = torch.randint_like(param, -5, 5)
        optimizer = scion.Scion([param])
        optimizer.step()

    def test_zero_num_ns_steps_raises_value_error(self) -> None:
        """Constructing Scion with num_ns_steps=0 must raise ValueError (num_ns_steps < 1 is invalid)."""
        param = nn.Parameter(torch.randn(5, 7, dtype=torch.float32, device=self.device))
        expected_message = "num_ns_steps must be at least 1"
        with self.assertRaisesRegex(ValueError, expected_message):
            scion.Scion([param], num_ns_steps=0)
class MopTest(parameterized.TestCase):
    """Tests for the MOP optimizer."""

    def setUp(self):
        self.device = FLAGS.device

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        weight_decay_method=["decoupled", "independent"],
        nesterov=[True, False],
        scale_mode=["spectral", "nuclear_norm"],
        extra_scale_factor=[1.0, 0.2],
    )
    def test_smoke(self, shape, weight_decay_method, nesterov, scale_mode, extra_scale_factor) -> None:
        """Run one MOP step for every combination of weight decay, nesterov and scaling options."""
        param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        param.grad = torch.randint_like(param, -5, 5)

        options = dict(
            weight_decay_method=weight_decay_method,
            nesterov=nesterov,
            scale_mode=scale_mode,
            extra_scale_factor=extra_scale_factor,
        )
        optimizer = mop.MOP([param], **options)
        optimizer.step()
class MuonHyperballTest(parameterized.TestCase):
    """Tests that MuonHyperball keeps parameter norms fixed across optimizer steps."""

    def setUp(self):
        self.device = FLAGS.device

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
    )
    def test_norm_preservation(self, shape) -> None:
        """Test that MuonHyperball preserves parameter norm after optimizer steps."""
        param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=self.device))
        # Record the norm before the optimizer is constructed.
        expected_norm = torch.tensor(param.norm().item(), device=self.device)

        optimizer = muon_hyperball.MuonHyperball([param], lr=0.01, momentum=0.0, weight_decay=0.0)

        # Step with fresh random gradients and check the norm after every single step.
        for _ in range(5):
            param.grad = torch.randn_like(param)
            optimizer.step()
            torch.testing.assert_close(param.norm(), expected_norm, atol=1e-5, rtol=1e-5)

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        hyperball_radius=[0.5, 1.0, 2.0],
    )
    def test_hyperball_radius_rescales_params(self, shape, hyperball_radius) -> None:
        """Test that the hyperball_radius kwarg rescales parameters to that radius and keeps them there."""
        param = nn.Parameter(torch.randn(shape, dtype=torch.float32, device=self.device))
        optimizer = muon_hyperball.MuonHyperball([param], lr=0.01, hyperball_radius=hyperball_radius)
        target_norm = torch.tensor(hyperball_radius, device=self.device)

        # Construction alone should already have moved the parameter onto the requested radius.
        torch.testing.assert_close(param.norm(), target_norm, atol=1e-5, rtol=1e-5)

        # Each subsequent step must keep the parameter on that radius.
        for _ in range(5):
            param.grad = torch.randn_like(param)
            optimizer.step()
            torch.testing.assert_close(param.norm(), target_norm, atol=1e-5, rtol=1e-5)

    def test_zero_norm_raises_error(self) -> None:
        """An all-zero parameter has no direction to normalize, so construction must raise ValueError."""
        zero_param = nn.Parameter(torch.zeros((5, 7), dtype=torch.float32, device=self.device))
        with self.assertRaises(ValueError):
            muon_hyperball.MuonHyperball([zero_param], lr=0.01)
class PolarGradTest(parameterized.TestCase):
    """Tests for the PolarGrad optimizer."""

    def setUp(self):
        self.device = FLAGS.device

    @parameterized.product(
        shape=[(5, 7), (33, 65), (127, 257)],
        extra_scale_factor=[1.0, 0.2],
    )
    def test_smoke(self, shape, extra_scale_factor) -> None:
        """Smoke test that a single PolarGrad step runs."""
        test_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        test_param.grad = torch.randint_like(test_param, -5, 5)

        polargrad_opt = polargrad.PolarGrad(
            [test_param],
            extra_scale_factor=extra_scale_factor,
        )
        polargrad_opt.step()

    @parameterized.product(
        shape=[(4, 8), (16, 16), (32, 64), (13, 17)],
        extra_scale_factor=[0.25, 0.125],
    )
    def test_orthogonalize_fn_matches_ref(self, shape, extra_scale_factor) -> None:
        """Test scaled_orthogonalize_fn against a reference when Newton-Schulz iterations are skipped.

        The reference normalizes the gradient by its Frobenius norm, then scales it by the inner product
        <normalized_grad, grad> and by extra_scale_factor.
        """
        dummy_param = nn.Parameter(torch.randint(-5, 5, shape, dtype=torch.float32, device=self.device))
        dummy_grad = torch.full(shape, 0.5, dtype=torch.float32, device=self.device)

        # Set num_ns_steps to 0 to skip Newton-Schulz iterations and only normalize the input gradient.
        polargrad_opt = polargrad.PolarGrad([dummy_param], num_ns_steps=0, extra_scale_factor=extra_scale_factor)

        norm_grad = torch.nn.functional.normalize(dummy_grad, p=2, dim=(-2, -1), eps=1e-7)
        # Assert normalization took effect. The input is a constant 0.5 matrix whose Frobenius norm is not 1
        # for any of the tested shapes. A correctly normalized result must therefore differ from the input
        # and have unit Frobenius norm.
        self.assertFalse(torch.equal(norm_grad, dummy_grad))
        torch.testing.assert_close(norm_grad.norm(), torch.tensor(1.0, device=self.device))

        ref_scale = (norm_grad * dummy_grad).sum()
        ref_out = norm_grad * ref_scale * extra_scale_factor

        test_out = polargrad_opt.scaled_orthogonalize_fn(dummy_grad)
        torch.testing.assert_close(
            ref_out,
            test_out,
            atol=0,
            rtol=0,
        )

    def test_negative_num_ns_steps_raises_value_error(self) -> None:
        """Test that PolarGrad raises ValueError for negative num_ns_steps."""
        test_param = nn.Parameter(torch.randn(5, 7, dtype=torch.float32, device=self.device))
        with self.assertRaisesRegex(ValueError, "num_ns_steps must be positive"):
            polargrad.PolarGrad([test_param], num_ns_steps=-1)
if __name__ == "__main__":
    # Use absltest's runner (not unittest.main) so the absl flags defined in this module (--device, --seed)
    # are parsed before the tests run.
    absltest.main()