feat(zero2): add CPU offload support for Muon optimizer
Add Muon optimizer support to the ZeRO Stage 1 & 2 CPU offload path by:
1. Partition strategy: Muon param groups now partition by parameter
   boundaries (a parameter is never split across ranks), padding each
   rank's shard to a uniform max size for all-gather compatibility and
   logging the padding overhead ratio (first sketch below).
2. CPU Newton-Schulz: add muon_update_cpu() and
   zeropower_via_newtonschulz5_cpu() using PyTorch CPU bf16 matmul as a
   baseline; the architecture allows future replacement with an AMX C++
   kernel (second sketch below).
3. CPU offload integration: _apply_muon_update_for_cpu_offload() copies
   complete gradients to CPU, runs the Muon update there (the momentum
   buffer stays on CPU), and writes the result to the FP32 grad buffer,
   adding no extra PCIe transfers (third sketch below).
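The commit does not show the partitioning code here; as a rough illustration of item 1, a boundary-respecting split with padding to a uniform shard size might look like the sketch below. The function name and the greedy largest-first placement are assumptions for illustration, not DeepSpeed's actual implementation.

```python
import torch
import torch.nn.functional as F

def partition_params_by_boundary(params, world_size):
    """Greedily assign whole tensors to ranks, then pad every rank's flat
    shard to the largest shard so all-gather sees uniform sizes."""
    buckets = [[] for _ in range(world_size)]
    sizes = [0] * world_size
    # Largest-first greedy placement keeps shard sizes roughly balanced.
    for p in sorted(params, key=lambda t: t.numel(), reverse=True):
        rank = sizes.index(min(sizes))
        buckets[rank].append(p)
        sizes[rank] += p.numel()
    max_size = max(sizes)
    shards = []
    for rank in range(world_size):
        flat = (torch.cat([p.reshape(-1) for p in buckets[rank]])
                if buckets[rank] else torch.empty(0))
        shards.append(F.pad(flat, (0, max_size - flat.numel())))
    # Mirror the commit's logging of how much padding the split costs.
    overhead = sum(max_size - s for s in sizes) / max(sum(sizes), 1)
    print(f"padding overhead ratio: {overhead:.2%}")
    return shards

# Three matrices over two ranks; no tensor is split across ranks.
shards = partition_params_by_boundary(
    [torch.randn(4, 4), torch.randn(2, 4), torch.randn(3, 4)], world_size=2)
```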
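Item 2 can be pictured with the published quintic Newton-Schulz iteration that Muon uses, run in bf16 on CPU. The helper name matches the commit, but the body below is a generic sketch of that standard iteration, not the committed code:

```python
import torch

def zeropower_via_newtonschulz5_cpu(G: torch.Tensor, steps: int = 5,
                                    eps: float = 1e-7) -> torch.Tensor:
    """Approximate the orthogonal factor of a 2-D gradient with bf16
    matmuls on CPU (the published quintic Newton-Schulz iteration)."""
    a, b, c = 3.4445, -4.7750, 2.0315   # standard Muon coefficients
    X = G.to(device="cpu", dtype=torch.float32)
    X = X / (X.norm() + eps)            # scale so the iteration converges
    X = X.bfloat16()                    # iterate in bf16, as on GPU
    transposed = X.size(0) > X.size(1)
    if transposed:                      # iterate on the wide orientation
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * (A @ A)
        X = a * X + B @ X
    if transposed:
        X = X.mT
    return X.to(G.dtype)

ortho = zeropower_via_newtonschulz5_cpu(torch.randn(1024, 512))
```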
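Item 3's flow, sketched end to end, reusing zeropower_via_newtonschulz5_cpu from the previous sketch. The function signature, the momentum coefficient, and the flat FP32 buffer handling are assumptions about the surrounding ZeRO code, not the actual helper:

```python
import torch

def apply_muon_update_for_cpu_offload(gpu_grad, cpu_momentum_buf,
                                      fp32_grad_buffer, beta=0.95):
    """Copy the full gradient to CPU once, apply the Muon update there,
    and leave the result in the CPU-resident FP32 gradient buffer."""
    cpu_grad = gpu_grad.to("cpu", dtype=torch.float32)  # single D2H copy
    # Momentum accumulation never leaves the CPU.
    cpu_momentum_buf.lerp_(cpu_grad, 1 - beta)
    update = zeropower_via_newtonschulz5_cpu(cpu_momentum_buf)
    # Shape-dependent scale Muon applies to non-square matrices.
    update = update.float() * max(1.0, cpu_grad.size(-2) / cpu_grad.size(-1)) ** 0.5
    # Overwriting the FP32 grad buffer (already on CPU) means the optimizer
    # step that follows incurs no additional PCIe traffic.
    fp32_grad_buffer.copy_(update.reshape(-1))

grad = torch.randn(256, 128)            # stand-in for a GPU gradient
momentum = torch.zeros_like(grad)
fp32_buffer = torch.zeros(grad.numel())
apply_muon_update_for_cpu_offload(grad, momentum, fp32_buffer)
```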
Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
docs/_pages/config-json.md (1 addition, 1 deletion)
@@ -39,7 +39,7 @@ toc_label: "Contents"
| type | The optimizer name. DeepSpeed natively supports **Adam**, **AdamW**, **OneBitAdam**, **Lamb**, **OneBitLamb**, and **Muon** optimizers (See [here](https://deepspeed.readthedocs.io/en/latest/optimizers.html) for details) and will import other optimizers from [torch](https://pytorch.org/docs/stable/optim.html). |`"Adam"`|
| params | Dictionary of parameters to instantiate optimizer. The parameter names must match the optimizer constructor signature (e.g., for [Adam](https://pytorch.org/docs/stable/optim.html#torch.optim.Adam)). |`{"lr": 0.001, "eps": 1e-8}`|
-Muon optimizer is supported with ZeRO Stage 1, 2, and 3. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.
+Muon optimizer is supported with ZeRO Stage 1, 2, and 3, including CPU offload (`offload_optimizer`) for all stages. To use Muon, set the optimizer name to `Muon`. The parameters applied for Muon are automatically determined by the matrix shape and name. For ZeRO Stage 3 with NVMe offloading, set `save_muon_momentum_buffer_in_memory` to `true` under `zero_optimization` to keep the Muon momentum buffer in GPU/CPU memory instead of swapping to NVMe.
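For reference, a minimal config sketch exercising the documented setup, written as the Python dict one might pass to deepspeed.initialize(). Only the keys named in the paragraph above come from the docs; the hyperparameter values are placeholders:

```python
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Muon",                 # per-group parameters are chosen
        "params": {"lr": 2e-2},         # by matrix shape and name
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},  # CPU offload path added here
        # For ZeRO-3 with NVMe offload, keep the Muon momentum buffer in
        # GPU/CPU memory instead of swapping it to NVMe:
        # "save_muon_momentum_buffer_in_memory": True,
    },
}
```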