Skip to content

Commit d802f0e

Browse files
committed
feat(zero2): add CPU offload support for Muon optimizer
Add Muon optimizer support in the ZeRO Stage 1 & 2 CPU offload path:

1. Partition strategy: Muon param groups now partition by parameter boundaries (a parameter is never split across ranks), padding to a uniform maximum size for all-gather compatibility. The padding overhead ratio is logged.
2. CPU Newton-Schulz: add muon_update_cpu() and zeropower_via_newtonschulz5_cpu() using PyTorch CPU bf16 matmul as the baseline. The architecture allows future replacement with an AMX C++ kernel.
3. CPU offload integration: _apply_muon_update_for_cpu_offload() copies complete gradients to CPU, runs muon_update on CPU (the momentum buffer stays on CPU), and writes the result to the FP32 grad buffer — no extra PCIe transfers.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
1 parent abb88ce commit d802f0e

3 files changed

Lines changed: 274 additions & 2 deletions

File tree

deepspeed/runtime/zero/stage_1_and_2.py

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,20 @@ def _enforce_cpu_offload():
445445
# set model bit16 weight to slices of flattened buffer
446446
self._update_model_bit16_weights(i)
447447

448+
# For Muon param groups, pad flat buffer so partition boundaries
449+
# never split a parameter. Each partition gets a uniform size
450+
# (max_partition_size) suitable for all-gather.
451+
if self._is_muon_param_group(i):
452+
dp_size = dist.get_world_size(group=self.real_dp_process_group[i])
453+
max_ps = self._get_muon_max_partition_size(self.round_robin_bit16_groups[i], dp_size, orig_group_numel)
454+
padded_size = max_ps * dp_size
455+
if padded_size > self.bit16_groups_flat[i].numel():
456+
pad_tensor = torch.zeros(padded_size - self.bit16_groups_flat[i].numel(),
457+
dtype=self.bit16_groups_flat[i].dtype,
458+
device=self.bit16_groups_flat[i].device)
459+
self.bit16_groups_flat[i] = torch.cat([self.bit16_groups_flat[i], pad_tensor])
460+
self._update_model_bit16_weights(i)
461+
448462
# divide the flat weights into near equal partition equal to the data parallel degree
449463
# each process will compute on a different part of the partition
450464
data_parallel_partitions = self.get_data_parallel_partitions(self.bit16_groups_flat[i], i)
@@ -1501,6 +1515,64 @@ def complete_grad_norm_calculation_for_cpu_offload(self, params):
15011515
return torch.tensor(total_norm, device=self.device, dtype=torch.float)
15021516

15031517
############################################################################################
1518+
def _apply_muon_update_for_cpu_offload(self, param):
    """Apply muon_update on CPU for a parameter in the CPU offload path.

    With Muon-aware partitioning (parameters never split across ranks),
    grad_accum is complete for this rank. This method copies it to CPU,
    runs the Muon (Newton-Schulz) update on CPU, writes the result to the
    FP32 grad buffer, and clears the GPU grad_accum.

    Args:
        param: bit16 model parameter whose gradient accumulation finished.

    Returns:
        True if muon_update was applied (caller must skip the normal
        async grad copy); False when the param or the wrapped optimizer
        is not Muon, or no gradient has been accumulated yet.
    """
    # Fast-path rejections: only Muon-flagged params under a Muon optimizer.
    if not getattr(param, 'use_muon', False):
        return False
    if 'muon' not in self.optimizer.__class__.__name__.lower():
        return False

    param_id = self.get_param_id(param)
    # grad_position maps param_id -> [group index, source offset, offset
    # into the fp32 partition buffer, element count]. source_offset is
    # unused here because the whole (unsplit) gradient is consumed.
    [i, source_offset, dest_offset, num_elements] = self.grad_position[param_id]

    grad_accum = self.get_param_gradient_attribute(param)
    if grad_accum is None:
        return False

    # Copy full grad to CPU with original shape for Newton-Schulz.
    # NOTE(review): self.device is presumably the offload device (CPU) in
    # this path -- confirm against the cpu_offload setup.
    grad_cpu = grad_accum.detach().clone().to(device=self.device, dtype=torch.float32)

    # Get or create the momentum buffer for this param group. It is keyed
    # by the group's first optimizer param (presumably the flattened fp32
    # partition tensor -- verify against group construction) and sized to
    # hold every Muon param in this rank's partition, so it persists on
    # CPU across steps.
    flatten_copy = self.optimizer.param_groups[i]['params'][0]
    if 'momentum_buffer' not in self.optimizer.state[flatten_copy]:
        total_size = sum(p.numel() for p in self.params_in_partition[i] if getattr(p, 'use_muon', False))
        self.optimizer.state[flatten_copy]['momentum_buffer'] = torch.zeros(total_size,
                                                                            dtype=torch.float32,
                                                                            device=self.device)

    momentum_flat = self.optimizer.state[flatten_copy]['momentum_buffer']

    # Find this param's offset within the muon momentum buffer: sum the
    # sizes of the Muon params that precede it in the partition order.
    muon_offset = 0
    for p in self.params_in_partition[i]:
        if p is param:
            break
        if getattr(p, 'use_muon', False):
            muon_offset += p.numel()

    # momentum_cpu is a *view* into momentum_flat, reshaped to the param's
    # original shape for the matrix-valued Newton-Schulz iteration.
    momentum_cpu = momentum_flat[muon_offset:muon_offset + param.numel()].view(param.size())

    # Run muon update on CPU.
    # NOTE(review): this assumes muon_update mutates momentum_cpu in place
    # (since momentum_cpu is a view, the write-back below is then a no-op
    # self-copy); if muon_update were out-of-place, the momentum buffer
    # would never advance -- confirm muon_update's contract.
    beta = self.optimizer.param_groups[i].get('momentum', 0.95)
    update = muon_update(grad_cpu.view(param.size()), momentum_cpu, beta=beta)

    # Write updated momentum back to flat buffer
    momentum_flat[muon_offset:muon_offset + param.numel()] = momentum_cpu.view(-1)

    # Write the updated gradient into this param's slice of the CPU FP32
    # grad buffer, converting to the master-grad dtype.
    dest_tensor = self.single_partition_of_fp32_groups[i].grad.view(-1).narrow(0, dest_offset, num_elements)
    dest_tensor.copy_(update.view(-1).to(self.master_weights_and_grads_dtype))

    # Clear the GPU grad_accum since we already consumed it
    self.clear_grad_attribute(param)
    return True
1575+
15041576
def copy_grads_in_partition(self, param):
15051577
if self.cpu_offload:
15061578

@@ -1512,7 +1584,8 @@ def copy_grads_in_partition(self, param):
15121584

15131585
self.update_offload_overflow_tracker_for_param_grad(param)
15141586

1515-
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
1587+
if not self._apply_muon_update_for_cpu_offload(param):
1588+
self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
15161589

15171590
return
15181591
#print(f"ID {self.get_param_id(param)} grad norm {param.grad.norm()}")
@@ -1826,6 +1899,61 @@ def get_data_parallel_partitions(self, tensor, group_id):
18261899
start = start + partition_size
18271900
return partitions
18281901

1902+
def _is_muon_param_group(self, group_index):
1903+
"""Check if a parameter group uses the Muon optimizer."""
1904+
params = self.round_robin_bit16_groups[group_index]
1905+
return (params and getattr(params[0], 'use_muon', False)
1906+
and 'muon' in self.optimizer.__class__.__name__.lower())
1907+
1908+
def _get_muon_max_partition_size(self, tensor_list, dp, total_num_elements):
    """Compute the uniform (padded) partition size for Muon partitioning.

    Tensors are handed out sequentially to exactly ``dp`` partitions; a
    partition is closed as soon as it reaches the average size, provided
    later partitions still remain to be filled, so no tensor is ever split.
    All partitions are then padded up to the largest one so that all-gather
    operates on equal-sized chunks.

    Returns:
        max_partition_size: the uniform, alignment-rounded partition size.
    """
    avg_size = total_num_elements / dp
    sizes = []
    filled = 0

    for tensor in tensor_list:
        numel = tensor.numel()
        assert numel <= total_num_elements, (f"Muon parameter with {numel} elements exceeds total "
                                             f"{total_num_elements} elements.")
        # Close the current partition once it has reached the average size
        # and at least one more partition remains to be filled.
        if filled >= avg_size and len(sizes) < dp - 1:
            sizes.append(filled)
            filled = 0
        filled += numel

    if filled > 0:
        sizes.append(filled)
    # Any leftover partitions (more ranks than needed) stay empty.
    sizes.extend([0] * (dp - len(sizes)))

    assert len(sizes) == dp

    max_partition_size = max(sizes)
    # Align to nccl_start_alignment_factor to guarantee 4-byte partition boundaries
    alignment = self.nccl_start_alignment_factor * dp
    max_partition_size = -(-max_partition_size // alignment) * alignment
    total_padded = max_partition_size * dp
    if total_num_elements > 0:
        padding_ratio = (total_padded - total_num_elements) / total_num_elements
    else:
        padding_ratio = 0
    if dist.get_rank() == 0:
        logger.info(f"Muon partition: max_partition_size={max_partition_size}, "
                    f"total_elements={total_num_elements}, "
                    f"total_padded={total_padded}, "
                    f"padding_ratio={padding_ratio:.4f}")

    return max_partition_size
1956+
18291957
def get_partition_info(self, tensor_list, partition_size, partition_id):
18301958
params_in_partition = []
18311959
params_not_in_partition = []

docs/_pages/config-json.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ toc_label: "Contents"
3939
| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, **OneBitLamb**, and **Muon** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). | `"Adam"` |
4040
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). | `{"lr": 0.001, "eps": 1e-8}` |
4141

42-
Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.
42+
Muon optimizer is supported with ZeRO Stage 1, 2, and 3, including CPU offload (`offload_optimizer`) for all stages. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.
4343

4444
Example of <i>**optimizer**</i> with Adam
4545

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
3+
# DeepSpeed Team
4+
5+
import deepspeed
6+
import torch
7+
import pytest
8+
9+
from unit.common import DistributedTest
10+
from unit.simple_model import SimpleModel
11+
from deepspeed.accelerator import get_accelerator
12+
13+
if torch.half not in get_accelerator().supported_dtypes():
14+
pytest.skip(f"fp16 not supported", allow_module_level=True)
15+
16+
17+
@pytest.mark.parametrize('zero_stage', [2])
class TestMuonCPUOffload(DistributedTest):

    def test_momentum_buffer_on_cpu(self, zero_stage):
        """Verify Muon CPU offload creates momentum buffer on CPU.

        This is the key invariant: after a training step with CPU offload,
        the Muon momentum buffer must reside on CPU (not GPU), confirming
        that muon_update ran on CPU and no GPU memory is wasted.
        """
        hidden_dim = 32
        batch_size = 8

        ds_config = {
            "train_batch_size": batch_size,
            "optimizer": {
                "type": "muon",
                "params": {
                    "lr": 0.01
                }
            },
            "fp16": {
                "enabled": True
            },
            "zero_optimization": {
                "stage": zero_stage,
                "reduce_scatter": False,
                "offload_optimizer": {
                    "device": "cpu",
                    "pin_memory": True,
                },
            },
        }

        net = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
        engine, optimizer, _, _ = deepspeed.initialize(config=ds_config,
                                                       model=net,
                                                       model_parameters=net.parameters(),
                                                       dist_init_required=False)

        # One full forward/backward/step so optimizer state materializes.
        inputs = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
        labels = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
        engine.backward(engine(inputs, labels))
        engine.step()

        # Muon momentum buffer must exist and be on CPU.
        # If muon_update was silently skipped, momentum_buffer would not be created.
        flat_param = optimizer.optimizer.param_groups[0]['params'][0]
        state = optimizer.optimizer.state[flat_param]
        assert 'momentum_buffer' in state, ("momentum_buffer not found in optimizer state. "
                                            "muon_update was not called in the CPU offload path.")
        assert state['momentum_buffer'].device.type == 'cpu', (
            f"Momentum buffer is on {state['momentum_buffer'].device}, expected CPU")
72+
73+
74+
@pytest.mark.parametrize('zero_stage', [2])
class TestMuonCPUOffloadCosim(DistributedTest):

    def test_cosim_offload_vs_no_offload(self, zero_stage):
        """Verify CPU offload produces results consistent with GPU path.

        With the same random seed, offload and non-offload should produce
        close parameters. If muon_update is skipped or wrong in either path,
        the results diverge significantly.
        """
        hidden_dim = 32
        batch_size = 8

        def train(offload):
            # Same seed for both runs so model init and data draws match.
            torch.manual_seed(42)
            config_dict = {
                "train_batch_size": batch_size,
                "optimizer": {
                    "type": "muon",
                    "params": {
                        "lr": 0.01
                    }
                },
                "fp16": {
                    "enabled": True
                },
                "zero_optimization": {
                    "stage": zero_stage,
                    "reduce_scatter": False,
                },
            }
            if offload:
                config_dict["zero_optimization"]["offload_optimizer"] = {
                    "device": "cpu",
                    "pin_memory": True,
                }

            model = SimpleModel(hidden_dim=hidden_dim, nlayers=5)
            engine, _, _, _ = deepspeed.initialize(
                config=config_dict,
                model=model,
                model_parameters=model.parameters(),
                dist_init_required=False,
            )

            for _ in range(3):
                x = torch.randn(batch_size, hidden_dim, device=engine.device, dtype=torch.half)
                y = torch.randint(0, hidden_dim, (batch_size, ), device=engine.device)
                loss = engine(x, y)
                engine.backward(loss)
                engine.step()

            # Snapshot final parameters on CPU in fp32 for comparison.
            return {n: p.clone().detach().float().cpu() for n, p in model.named_parameters()}

        params_offload = train(offload=True)
        params_no_offload = train(offload=False)

        for name in params_offload:
            p_off = params_offload[name]
            p_no = params_no_offload[name]
            # Both paths should produce the same NaN pattern.
            # BUG FIX: the previous check compared the union of the two NaN
            # masks against p_off's mask alone, which only verified that
            # p_no's NaNs were a *subset* of p_off's; extra NaNs in p_off
            # went undetected. Require exact mask equality instead.
            assert p_off.isnan().equal(p_no.isnan()), (f"{name}: NaN pattern differs between offload and non-offload. "
                                                       "muon_update produced different results.")
            # On non-NaN elements, cosine similarity should be very high
            valid = ~p_off.isnan()
            if valid.sum() > 0:
                cos_sim = torch.nn.functional.cosine_similarity(p_off[valid].unsqueeze(0),
                                                                p_no[valid].unsqueeze(0)).item()
                assert cos_sim > 0.99, (f"{name}: cosine similarity {cos_sim:.4f} between offload and "
                                        f"non-offload is too low, indicating muon_update results diverge.")

0 commit comments

Comments
 (0)