Skip to content

Commit 59bae08

Browse files
committed
refactor(muon): remove torch.compile from muon_update
Benchmarks show torch.compile provides negligible speedup (0-5%) and is sometimes slower for muon_update, since computation is dominated by matmul which already uses cuBLAS in eager mode. Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
1 parent a886c3d commit 59bae08

File tree

1 file changed

+1
-5
lines changed

1 file changed

+1
-5
lines changed

deepspeed/runtime/zero/muon/original_muon.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,6 @@
2929

3030
import torch
3131
import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
32-
from deepspeed.runtime import compiler
3332
from deepspeed.accelerator import get_accelerator
3433

3534

@@ -135,7 +134,7 @@ def _zeropower_via_gram_newtonschulz(G, steps: int):
135134
NS_METHODS = {"standard", "gram"}
136135

137136

138-
def _muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method="gram"):
137+
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method="gram"):
139138
orig_dtype = grad.dtype
140139
momentum.lerp_(grad, 1 - beta)
141140
update = grad.lerp_(momentum, beta) if nesterov else momentum
@@ -151,9 +150,6 @@ def _muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method
151150
return update
152151

153152

154-
muon_update = compiler.compile()(_muon_update)
155-
156-
157153
class Muon(torch.optim.Optimizer):
158154
"""
159155
Muon - MomentUm Orthogonalized by Newton-schulz

0 commit comments

Comments (0)