Skip to content

Commit c058864

Browse files
committed
feat(zero2): add CPU offload support for Muon optimizer
Add Muon optimizer support in ZeRO Stage 1&2 CPU offload path by: 1. Partition strategy: Muon param groups now partition by parameter boundaries (never split a param across ranks), padding to uniform max size for all-gather compatibility. Logs padding overhead ratio. 2. CPU Newton-Schulz: Add muon_update_cpu() and zeropower_via_newtonschulz5_cpu() using PyTorch CPU bf16 matmul as baseline. Architecture allows future replacement with AMX C++ kernel. 3. CPU offload integration: _apply_muon_update_for_cpu_offload() copies complete gradients to CPU, runs muon_update on CPU (momentum buffer stays on CPU), writes result to FP32 grad buffer. No extra PCIe transfers. Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
1 parent abb88ce commit c058864

File tree

3 files changed

+211
-2
lines changed

3 files changed

+211
-2
lines changed

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1501,6 +1501,70 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
15011501
return torch.tensor(total_norm, device=self.device, dtype=torch.float)
15021502

15031503
############################################################################################
1504+
def _apply_muon_update_for_cpu_offload(self, param):
1505+
"""Apply muon_update on CPU for a parameter in the CPU offload path.
1506+
1507+
For Muon parameters (use_muon=True), copies the complete gradient to
1508+
CPU, runs Newton-Schulz orthogonalization on CPU, then writes only the
1509+
partition slice back to the CPU FP32 grad buffer. Cross-boundary
1510+
parameters are redundantly processed by each involved rank with the
1511+
full gradient, matching the non-offload path behavior in get_flat_partition.
1512+
1513+
Returns True if muon_update was applied (caller should skip the normal
1514+
copy for this param).
1515+
"""
1516+
if not getattr(param, 'use_muon', False):
1517+
return False
1518+
if 'muon' not in self.optimizer.__class__.__name__.lower():
1519+
return False
1520+
1521+
param_id = self.get_param_id(param)
1522+
[i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]
1523+
1524+
grad_accum = self.get_param_gradient_attribute(param)
1525+
if grad_accum is None:
1526+
return False
1527+
1528+
grad_cpu = grad_accum.detach().clone().to(device=self.device, dtype=torch.float32)
1529+
1530+
flatten_copy = self.optimizer.param_groups[i]['params'][0]
1531+
if "momentum_buffer" not in self.optimizer.state[flatten_copy]:
1532+
total_size = sum(p.numel() for p in self.params_in_partition[i])
1533+
self.optimizer.state[flatten_copy]["momentum_buffer"] = torch.zeros(total_size,
1534+
dtype=torch.float32,
1535+
device=self.device)
1536+
1537+
momentum_flat = self.optimizer.state[flatten_copy]["momentum_buffer"]
1538+
1539+
muon_offset = 0
1540+
for p in self.params_in_partition[i]:
1541+
if p is param:
1542+
break
1543+
muon_offset += p.numel()
1544+
1545+
momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size())
1546+
1547+
beta = self.optimizer.param_groups[i].get('momentum', 0.95)
1548+
update = muon_update(grad_cpu.view(param.size()), momentum_cpu, beta=beta)
1549+
1550+
momentum_flat[muon_offset:muon_offset + param.numel()] = momentum_cpu.view(-1)
1551+
1552+
# Write only the partition slice of the update to CPU FP32 grad buffer
1553+
tensor_offset = 0
1554+
actual_num_elements = param.numel()
1555+
if source_offset > 0:
1556+
tensor_offset = source_offset
1557+
actual_num_elements = param.numel() - tensor_offset
1558+
if actual_num_elements > num_elements:
1559+
actual_num_elements = num_elements
1560+
1561+
dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, actual_num_elements)
1562+
update_slice = update.view(-1).narrow(0, tensor_offset, actual_num_elements)
1563+
dest_tensor.copy_(update_slice.to(self.master_weights_and_grads_dtype))
1564+
1565+
self.clear_grad_attribute(param)
1566+
return True
1567+
15041568
def copy_grads_in_partition(self, param):
15051569
if self.cpu_offload:
15061570

@@ -1512,7 +1576,8 @@ def copy_grads_in_partition(self, param):
15121576

15131577
self.update_offload_overflow_tracker_for_param_grad(param)
15141578

1515-
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
1579+
if not self._apply_muon_update_for_cpu_offload(param):
1580+
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
15161581

15171582
return
15181583
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")

docs/_pages/config-json.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ toc_label: "Contents"
3939
| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, **OneBitLamb**, and **Muon** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
4040
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` |
4141

42-
Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.
42+
Muon optimizer is supported with ZeRO Stage 1, 2, and 3, including CPU offload (`offload_optimizer`) for all stages. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.
4343

4444
Example of <i>**optimizer**</i> with Adam
4545

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# DeepSpeed Team
4+
5+
import deepspeed
6+
import torch
7+
import pytest
8+
9+
from unit.common import DistributedTest
10+
from unit.simple_model import SimpleModel
11+
from deepspeed.accelerator import get_accelerator
12+
13+
if torch.half not in get_accelerator().supported_dtypes():
14+
pytest.skip(f"fp16 not supported", allow_module_level=True)
15+
16+
17+
@pytest.mark.parametrize('zero_stage', [2])
18+
class TestMuonCPUOffload(DistributedTest):
19+
20+
def test_momentum_buffer_on_cpu(self, zero_stage):
21+
"""Verify Muon CPU offload creates momentum buffer on CPU.
22+
23+
This is the key invariant: after a training step with CPU offload,
24+
the Muon momentum buffer must reside on CPU (not GPU), confirming
25+
that muon_update ran on CPU and no GPU memory is wasted.
26+
"""
27+
hidden_dim = 32
28+
batch_size = 8
29+
config_dict = {
30+
"train_batch_size": batch_size,
31+
"optimizer": {
32+
"type": "muon",
33+
"params": {
34+
"lr": 0.01
35+
}
36+
},
37+
"fp16": {
38+
"enabled": True
39+
},
40+
"zero_optimization": {
41+
"stage": zero_stage,
42+
"reduce_scatter": False,
43+
"offload_optimizer": {
44+
"device": "cpu",
45+
"pin_memory": True,
46+
},
47+
},
48+
}
49+
50+
model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
51+
engine, optimizer, _, _ = deepspeed.initialize(
52+
config=config_dict,
53+
model=model,
54+
model_parameters=model.parameters(),
55+
dist_init_required=False,
56+
)
57+
58+
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
59+
y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
60+
loss = engine(x, y)
61+
engine.backward(loss)
62+
engine.step()
63+
64+
# Muon momentum buffer must exist and be on CPU.
65+
# If muon_update was silently skipped, momentum_buffer would not be created.
66+
flatten_copy = optimizer.optimizer.param_groups[0]['params'][0]
67+
state = optimizer.optimizer.state[flatten_copy]
68+
assert 'momentum_buffer' in state, ("momentum_buffer not found in optimizer state. "
69+
"muon_update was not called in the CPU offload path.")
70+
assert state['momentum_buffer'].device.type == 'cpu', (
71+
f"Momentum buffer is on {state['momentum_buffer'].device}, expected CPU")
72+
73+
74+
@pytest.mark.parametrize('zero_stage', [2])
75+
class TestMuonCPUOffloadCosim(DistributedTest):
76+
77+
def test_cosim_offload_vs_no_offload(self, zero_stage):
78+
"""Verify CPU offload produces results consistent with GPU path.
79+
80+
With the same random seed, offload and non-offload should produce
81+
close parameters. If muon_update is skipped or wrong in either path,
82+
the results diverge significantly.
83+
"""
84+
hidden_dim = 32
85+
batch_size = 8
86+
87+
def train(offload):
88+
torch.manual_seed(42)
89+
config_dict = {
90+
"train_batch_size": batch_size,
91+
"optimizer": {
92+
"type": "muon",
93+
"params": {
94+
"lr": 0.01
95+
}
96+
},
97+
"fp16": {
98+
"enabled": True
99+
},
100+
"zero_optimization": {
101+
"stage": zero_stage,
102+
"reduce_scatter": False,
103+
},
104+
}
105+
if offload:
106+
config_dict["zero_optimization"]["offload_optimizer"] = {
107+
"device": "cpu",
108+
"pin_memory": True,
109+
}
110+
111+
model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
112+
engine, _, _, _ = deepspeed.initialize(
113+
config=config_dict,
114+
model=model,
115+
model_parameters=model.parameters(),
116+
dist_init_required=False,
117+
)
118+
119+
for _ in range(3):
120+
x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
121+
y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
122+
loss = engine(x, y)
123+
engine.backward(loss)
124+
engine.step()
125+
126+
return {n: p.clone().detach().float().cpu() for n, p in model.named_parameters()}
127+
128+
params_offload = train(offload=True)
129+
params_no_offload = train(offload=False)
130+
131+
for name in params_offload:
132+
p_off = params_offload[name]
133+
p_no = params_no_offload[name]
134+
# Both paths should produce the same NaN pattern
135+
nan_mask = p_off.isnan() | p_no.isnan()
136+
assert nan_mask.equal(p_off.isnan()), (f"{name}: NaN pattern differs between offload and non-offload. "
137+
"muon_update produced different results.")
138+
# On non-NaN elements, cosine similarity should be very high
139+
valid = ~nan_mask
140+
if valid.sum() > 0:
141+
cos_sim = torch.nn.functional.cosine_similarity(p_off[valid].unsqueeze(0),
142+
p_no[valid].unsqueeze(0)).item()
143+
assert cos_sim > 0.99, (f"{name}: cosine similarity {cos_sim:.4f} between offload and "
144+
f"non-offload is too low, indicating muon_update results diverge.")

0 commit comments

Comments
 (0)