Skip to content
2 changes: 1 addition & 1 deletion deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,7 +1754,7 @@ def _configure_basic_optimizer(self, model_parameters):
param_groups = []
if muon_params:
accepted_parameters = dict()
for key in ["lr", "momentum", "weight_decay", "muon_lr"]:
for key in ["lr", "momentum", "weight_decay", "muon_lr", "ns_method"]:
if key in optimizer_parameters:
if key == "muon_lr": # muon_lr will override lr
accepted_parameters['lr'] = optimizer_parameters[key]
Expand Down
129 changes: 116 additions & 13 deletions deepspeed/runtime/zero/muon/original_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import torch
import deepspeed.comm as dist # replace torch's distributed package with deepspeed.comm to resolve deepspeed check
from deepspeed.runtime import compiler
from deepspeed.accelerator import get_accelerator


@compiler.compile()
Expand All @@ -45,7 +46,9 @@ def zeropower_via_newtonschulz5(G, steps: int):
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
# Use bf16 when hardware supports it; fp32 otherwise
compute_dtype = torch.bfloat16 if get_accelerator().is_bf16_supported() else torch.float32
X = G.to(compute_dtype)
if G.size(-2) > G.size(-1):
X = X.mT

Expand All @@ -63,13 +66,93 @@ def zeropower_via_newtonschulz5(G, steps: int):


@compiler.compile()
def zeropower_via_gram_newtonschulz(G, steps: int):
    """
    Gram Newton-Schulz iteration for orthogonalization.

    Mathematically equivalent to standard Newton-Schulz but iterates on the
    small square Gram matrix R = X @ X.T (n x n) instead of the full rectangular
    X (n x m). This reduces FLOPs significantly when m >> n (typical for
    transformer weight matrices with aspect ratio ~5).

    Uses fp16 instead of bf16 for better numerical precision at the same
    compute cost. Includes a restart at iteration 2 to maintain stability
    in half-precision.

    Falls back to standard Newton-Schulz for square matrices (n == m)
    where there is no FLOP advantage.

    Reference: https://tridao.me/blog/2026/gram-newton-schulz/

    Args:
        G: tensor with ndim >= 2; leading dimensions (if any) are batch dims.
        steps: number of Newton-Schulz iterations (>= 0).

    Returns:
        Tensor of G's shape whose trailing two dims are (approximately)
        orthogonal, in the fp16/fp32 compute dtype (caller casts back).
    """
    assert G.ndim >= 2
    a, b, c = (3.4445, -4.7750, 2.0315)
    # Use fp16 for better precision than bf16 when hardware supports it; fp32 otherwise
    compute_dtype = torch.float16 if get_accelerator().is_fp16_supported() else torch.float32
    X = G.to(compute_dtype)
    # Work in the wide orientation (rows <= cols); transpose back at the end.
    transposed = G.size(-2) > G.size(-1)
    if transposed:
        X = X.mT

    n, m = X.size(-2), X.size(-1)

    # Normalize so the spectral norm is <= 1, a convergence requirement of NS.
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)

    # For square matrices, no FLOP advantage; use standard iteration
    if m <= n:
        for _ in range(steps):
            A = X @ X.mT
            B = b * A + c * A @ A
            X = a * X + B @ X
        return X.mT if transposed else X

    # Gram NS: iterate on R = X @ X.T (n x n) instead of X (n x m).
    # Q accumulates the product of the per-step polynomials (a*I + Z) so the
    # final orthogonalized matrix is Q @ X.
    R = X @ X.mT
    Q = None
    restart_at = 2  # fold Q back into X once to bound half-precision error growth

    for i in range(steps):
        if i == restart_at and i != 0:
            # Restart: apply accumulated polynomial, recompute the Gram matrix.
            X = Q @ X
            R = X @ X.mT
            Q = None

        Z = b * R + c * R @ R

        if Q is None:
            # First (or first post-restart) step: Q = a*I + Z.
            # Bug fix: Q.diagonal() defaults to dim1=0, dim2=1, which is the
            # wrong diagonal for batched (ndim > 2) inputs; address the
            # trailing matrix dims explicitly.
            Q = Z.clone()
            Q.diagonal(dim1=-2, dim2=-1).add_(a)
        else:
            # Q <- (a*I + Z) @ Q.
            # Bug fix: torch.addmm is 2-D-only and broke the batched case the
            # ndim >= 2 assert admits; the explicit form is batch-safe and is
            # fused by the compiler.compile() wrapper anyway.
            Q = a * Q + Z @ Q

        # R <- (a*I + Z)^2 @ R; skipped on the last step and when the next
        # iteration restarts (it recomputes R from X directly).
        if i < steps - 1 and (i + 1) != restart_at:
            RZ = a * R + Z @ R
            R = a * RZ + Z @ RZ

    # Bug fix: with steps == 0 the loop never runs and Q is None; the original
    # code then crashed on Q @ X. Return the normalized input, matching the
    # steps == 0 behavior of the square-matrix branch above.
    if Q is None:
        return X.mT if transposed else X

    # Undo the initial transpose: (Q @ X).T == X.T @ Q.T.
    return X.mT @ Q.mT if transposed else Q @ X


# Valid values for the Muon `ns_method` hyperparameter (validated when
# building param groups).
NS_METHODS = {"gram", "standard"}


@compiler.compile()
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True, ns_method="gram"):
    """Produce the orthogonalized Muon update for one parameter.

    Updates `momentum` in place as an EMA of the gradient; when `nesterov`
    is set, `grad` is also mutated in place by the lookahead lerp. The
    selected Newton-Schulz routine orthogonalizes the update, which is then
    rescaled by the max(1, rows/cols)**0.5 aspect-ratio factor and cast back
    to the gradient's original dtype.

    Args:
        grad: parameter gradient (4-D conv filters are flattened to 2-D).
        momentum: persistent momentum buffer, modified in place.
        beta: momentum coefficient.
        ns_steps: number of Newton-Schulz iterations.
        nesterov: apply Nesterov-style lookahead on the gradient.
        ns_method: "gram" selects Gram Newton-Schulz; any other value falls
            back to the standard iteration (values are validated upstream
            against NS_METHODS).
    """
    orig_dtype = grad.dtype
    momentum.lerp_(grad, 1 - beta)
    if nesterov:
        update = grad.lerp_(momentum, beta)
    else:
        update = momentum
    if update.ndim == 4:  # conv filters: collapse trailing dims to 2-D
        update = update.view(len(update), -1)
    orthogonalize = zeropower_via_gram_newtonschulz if ns_method == "gram" else zeropower_via_newtonschulz5
    update = orthogonalize(update, steps=ns_steps)
    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
    if update.dtype != orig_dtype:
        update = update.to(orig_dtype)
    return update


Expand All @@ -93,10 +176,12 @@ class Muon(torch.optim.Optimizer):
lr: The learning rate, in units of spectral norm per update.
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
ns_method: Newton-Schulz method. "gram" (default) uses Gram NS for ~2x speedup
on rectangular matrices. "standard" uses the original iteration.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"):
    """Build the distributed Muon optimizer.

    Args:
        params: non-empty list of torch.nn.Parameter.
        lr: learning rate, in units of spectral norm per update.
        weight_decay: AdamW-style decoupled weight decay.
        momentum: momentum coefficient for the Muon update.
        ns_method: Newton-Schulz variant, "gram" or "standard".
    """
    defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method)
    assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
    # Sort parameters by size, largest first — presumably to give each rank a
    # deterministic, balanced order for the distributed step; verify against step().
    params = sorted(params, key=lambda x: x.size(), reverse=True)
    super().__init__(params, defaults)
Expand All @@ -122,7 +207,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
Expand All @@ -136,8 +224,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
Muon variant for usage in non-distributed settings.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95, ns_method="gram"):
    """Build the non-distributed Muon optimizer.

    Args:
        params: iterable of parameters (or param-group dicts).
        lr: learning rate, in units of spectral norm per update.
        weight_decay: AdamW-style decoupled weight decay.
        momentum: momentum coefficient for the Muon update.
        ns_method: Newton-Schulz variant, "gram" or "standard".
    """
    defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, ns_method=ns_method)
    super().__init__(params, defaults)

@torch.no_grad()
Expand All @@ -156,7 +244,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])

Expand Down Expand Up @@ -208,7 +299,10 @@ def __init__(self, param_groups):
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
group["ns_method"] = group.get("ns_method", "gram")
assert group[
"ns_method"] in NS_METHODS, f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}"
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon", "ns_method"])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
Expand Down Expand Up @@ -240,7 +334,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()],
Expand Down Expand Up @@ -277,7 +374,10 @@ def __init__(self, param_groups):
group["lr"] = group.get("lr", 0.02)
group["momentum"] = group.get("momentum", 0.95)
group["weight_decay"] = group.get("weight_decay", 0)
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon"])
group["ns_method"] = group.get("ns_method", "gram")
assert group[
"ns_method"] in NS_METHODS, f"ns_method must be one of {NS_METHODS}, got {group['ns_method']}"
assert set(group.keys()) == set(["params", "lr", "momentum", "weight_decay", "use_muon", "ns_method"])
else:
# defaults
group["lr"] = group.get("lr", 3e-4)
Expand All @@ -304,7 +404,10 @@ def step(self, closure=None):
state = self.state[p]
if len(state) == 0:
state["momentum_buffer"] = torch.zeros_like(p)
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
update = muon_update(p.grad,
state["momentum_buffer"],
beta=group["momentum"],
ns_method=group.get("ns_method", "gram"))
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
else:
Expand Down
4 changes: 3 additions & 1 deletion deepspeed/runtime/zero/stage3.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,6 +759,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
if self.use_muon:
self.sub_groups_using_muon = []
self.muon_beta = None
self.muon_ns_method = None
for idx, param_group in enumerate(fp16_param_groups):
if getattr(param_group['params'][0], 'use_muon', False):
self.sub_groups_using_muon.extend([True] * len(param_groups[idx]))
Expand All @@ -767,6 +768,7 @@ def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
f"Found {self.muon_beta} and {group_beta}.")
self.muon_beta = group_beta
self.muon_ns_method = param_group.get('ns_method', 'gram')
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Preserve per-group ns_method in ZeRO-3 Muon updates

The ZeRO-3 setup stores ns_method in a single optimizer-wide field that is overwritten for each Muon param group, and _apply_distributed_muon_update later uses that single value for all Muon subgroups. If a user configures multiple use_muon=True groups with different ns_method values, earlier groups silently run with the last group's method, producing incorrect optimizer behavior and invalid experiment comparisons. This should either enforce one shared ns_method (like momentum) or track/apply ns_method per subgroup.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ns_method actually is decided by ns_method field in json and cannot diverge.

else:
self.sub_groups_using_muon.extend([False] * len(param_groups[idx]))
# bookkeeping related to param groups
Expand Down Expand Up @@ -1515,7 +1517,7 @@ def _apply_distributed_muon_update(self, communication_data_type: torch.dtype, b
param = params[base_i + rank]
g = param.grad
m = gathered_momentums_pad[base_i + rank]
update = muon_update(g, m, beta=self.muon_beta)
update = muon_update(g, m, beta=self.muon_beta, ns_method=getattr(self, 'muon_ns_method', 'gram'))
g.data.copy_(update, non_blocking=False)
grad_handle = dist.all_gather(grads_pad[base_i:base_i + world_sz],
grads_pad[base_i + rank],
Expand Down
6 changes: 5 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2003,7 +2003,11 @@ def get_flat_partition(self,
assert tensor.ndim > 1, f"if use muon, then tensor dim > 1, got {tensor.size()}"
buffer = torch.narrow(self.optimizer.state[flatten_copy]["momentum_buffer"], 0, buffer_idx,
tensor.numel()).view(tensor.size())
grad_accum = muon_update(grad_accum, buffer, self.optimizer.param_groups[param_group_idx]['momentum'])
ns_method = self.optimizer.param_groups[param_group_idx].get('ns_method', 'gram')
grad_accum = muon_update(grad_accum,
buffer,
self.optimizer.param_groups[param_group_idx]['momentum'],
ns_method=ns_method)
tensor = grad_accum
num_elements = tensor.numel()
buffer_idx += num_elements
Expand Down
14 changes: 13 additions & 1 deletion docs/_pages/config-json.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,17 @@ toc_label: "Contents"

Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.

Muon supports the following params:

| "params" key | Description | Default |
| -------------- | -------------------------------------------------------------------------------------------------------------------- | --------- |
| lr | Learning rate for all parameters. Overridden by `muon_lr` / `adam_lr` if set. | 0.001 |
| momentum | Momentum coefficient for the Muon update. | 0.95 |
| weight\_decay | Weight decay (AdamW-style). | 0.0 |
| muon\_lr | Learning rate override for Muon parameters. Defaults to `lr` if not set. | - |
| adam\_lr | Learning rate override for non-Muon (Adam) parameters. Defaults to `lr` if not set. | - |
| ns\_method | Newton-Schulz orthogonalization method: `"gram"` for Gram NS (~2x faster on rectangular matrices), `"standard"` for the original iteration. Use `"standard"` to fall back if you encounter convergence issues. | `"gram"` |

Example of <i>**optimizer**</i> with Adam

```json
Expand Down Expand Up @@ -73,7 +84,8 @@ If not set, muon_lr will default to lr.
"lr": 0.001,
"momentum": 0.9,
"weight_decay": 0.0,
"muon_lr": 0.001
"muon_lr": 0.001,
"ns_method": "gram"
}
},
"zero_optimization": {
Expand Down
91 changes: 91 additions & 0 deletions tests/unit/ops/muon/test_muon.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,94 @@ def test(self, optimizer_type, zero_stage, lr, hidden_dim, nlayer, offload_optim
after_training = [p.clone().cpu() for p in model.parameters()]
for initial, final in zip(initial_params, after_training):
assert not torch.equal(initial.cpu(), final.cpu()), "Parameters should have been updated during training"


class TestGramNewtonSchulz(DistributedTest):
"""Test Gram Newton-Schulz integration with Muon optimizer."""

world_size = 2
reuse_dist_env = True

@pytest.mark.parametrize('ns_method', ['gram', 'standard'])
@pytest.mark.parametrize('zero_stage', [1, 2])
def test_ns_method_training(self, ns_method, zero_stage):
    """Verify both ns_method values work end-to-end with DeepSpeed."""
    hidden_dim = 64
    batch_size = 8
    # Minimal fp16 config: Muon optimizer with the parametrized
    # Newton-Schulz method under ZeRO stage 1 or 2.
    config_dict = {
        "train_batch_size": batch_size,
        "optimizer": {
            "type": "muon",
            "params": {
                "lr": 0.01,
                "ns_method": ns_method,
            }
        },
        "gradient_clipping": 1.0,
        "fp16": {
            "enabled": True,
        },
        "zero_optimization": {
            "stage": zero_stage,
            "reduce_scatter": False,
        },
    }

    model = SimpleModel(hidden_dim=hidden_dim, nlayers=3)
    # Snapshot initial parameters (on CPU) to compare after training.
    initial_params = [p.clone().cpu() for p in model.parameters()]
    engine, optimizer, _, _ = deepspeed.initialize(
        config=config_dict,
        model=model,
        model_parameters=model.parameters(),
        dist_init_required=False,
    )

    # A few forward/backward/step iterations on random data — we only check
    # that the optimizer runs and actually moves the weights.
    for _ in range(3):
        x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
        y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
        loss = engine(x, y)
        engine.backward(loss)
        engine.step()

    after_training = [p.clone().cpu() for p in model.parameters()]
    for initial, final in zip(initial_params, after_training):
        assert not torch.equal(initial, final), "Parameters should have been updated"

@pytest.mark.parametrize('ns_method', ['gram', 'standard'])
def test_ns_method_stage3(self, ns_method):
"""Verify ns_method works with ZeRO Stage 3."""
hidden_dim = 64
batch_size = 8
config_dict = {
"train_batch_size": batch_size,
"optimizer": {
"type": "muon",
"params": {
"lr": 0.01,
"ns_method": ns_method,
}
},
"gradient_clipping": 1.0,
"fp16": {
"enabled": True,
},
"zero_optimization": {
"stage": 3,
"reduce_scatter": False,
},
}

model = SimpleModel(hidden_dim=hidden_dim, nlayers=3)
engine, optimizer, _, _ = deepspeed.initialize(
config=config_dict,
model=model,
model_parameters=model.parameters(),
dist_init_required=False,
)

for _ in range(3):
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
loss = engine(x, y)
engine.backward(loss)
engine.step()
Loading