
feat(zero2): add CPU offload support for Muon optimizer #7939

Open
delock wants to merge 2 commits into deepspeedai:master from delock:gma/muon_cpuoffload

Conversation

@delock (Collaborator)

@delock delock commented Mar 31, 2026

Add Muon optimizer support to the ZeRO Stage 1 & 2 CPU offload path:
momentum is stored in CPU memory, while the Newton-Schulz algorithm runs on GPU.

This PR completes the missing piece in ZeRO-2 and gives CPU offload the same numerical behavior as non-offloaded ZeRO-2.

Note that this PR also contains code from PR #7953 and is intended to be merged after #7953.
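For context, enabling this path would combine a Muon optimizer section with ZeRO-2 optimizer offload. A hypothetical config sketch as a Python dict (the optimizer type string and Muon parameter names are assumptions, not confirmed by this PR):

```python
# Hypothetical DeepSpeed config sketch; the "Muon" type string and its
# params are illustrative assumptions.
ds_config = {
    "train_batch_size": 8,
    "optimizer": {
        "type": "Muon",                      # assumed optimizer type string
        "params": {"lr": 0.02, "momentum": 0.95},
    },
    "zero_optimization": {
        "stage": 2,
        "offload_optimizer": {"device": "cpu"},  # momentum buffer lives on CPU
    },
}
```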

@delock delock marked this pull request as draft March 31, 2026 07:02

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 54364fbe9a

ℹ️ About Codex in GitHub

Codex has been enabled to automatically review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

When you sign up for Codex through ChatGPT, Codex can also answer questions or update the PR, like "@codex address that feedback".

pad_tensor = torch.zeros(padded_size - self.bit16_groups_flat[i].numel(),
                         dtype=self.bit16_groups_flat[i].dtype,
                         device=self.bit16_groups_flat[i].device)
self.bit16_groups_flat[i] = torch.cat([self.bit16_groups_flat[i], pad_tensor])

P1: Insert per-partition padding before Muon equal split

Appending a single padding block at the tail does not guarantee parameter-boundary partitioning: when an earlier partition is smaller than max_partition_size (e.g., sizes [4,5,1] for dp=3), get_data_parallel_partitions() still cuts at fixed max_partition_size offsets and splits a parameter across ranks. That breaks the new CPU-offload Muon path, which assumes unsplit parameters and writes a full update.view(-1) into a partition slice computed from grad_position, leading to shape mismatch or incorrect updates when source_offset != 0.
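The failure mode can be reproduced with a toy model of fixed-offset partitioning in plain Python (the sizes match the reviewer's example; `get_data_parallel_partitions` itself slices the flat buffer at multiples of the partition size):

```python
# Toy reproduction: parameter sizes [4, 5, 1] with dp=3 and tail-only padding.
param_sizes = [4, 5, 1]
dp = 3
max_partition_size = max(param_sizes)        # 5
padded_total = max_partition_size * dp       # 15-element flat buffer after tail pad

# Fixed cuts at multiples of max_partition_size, as in
# get_data_parallel_partitions:
cuts = [(r * max_partition_size, (r + 1) * max_partition_size) for r in range(dp)]

# Parameter start offsets in the flat buffer (packed back-to-back).
starts = [0, 4, 9]
for size, start in zip(param_sizes, starts):
    end = start + size
    ranks = sorted({r for r, (lo, hi) in enumerate(cuts) if lo < end and start < hi})
    print(f"param size {size} at offset {start} -> ranks {ranks}")
# The size-5 parameter at offset 4 spans ranks 0 and 1: tail padding alone
# cannot keep every parameter inside a single partition.
```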


if self._is_muon_param_group(i):
    dp_size = dist.get_world_size(group=self.real_dp_process_group[i])
    max_ps = self._get_muon_max_partition_size(self.round_robin_bit16_groups[i], dp_size, orig_group_numel)
    padded_size = max_ps * dp_size

P1: Keep Muon partition size aligned for NCCL boundaries

max_partition_size is used directly to set padded_size, but it is not rounded to the existing NCCL start-alignment factor. If max_partition_size is odd with fp16/bf16 tensors, partition starts after rank 0 become 2-byte shifted and fail the existing 4-byte alignment assertion in the same initialization flow. This makes valid Muon configurations crash depending on parameter shapes.


@delock delock force-pushed the gma/muon_cpuoffload branch 2 times, most recently from d802f0e to c058864 Compare March 31, 2026 10:07
@delock delock force-pushed the gma/muon_cpuoffload branch from 59bae08 to 0a4d98e Compare April 9, 2026 06:23
delock added 2 commits April 8, 2026 23:43
Add Gram Newton-Schulz (NS) orthogonalization as an alternative to the
standard NS iteration. Gram NS iterates on the small square Gram matrix
R = X @ X.T (n x n) instead of the full rectangular X (n x m), reducing
FLOPs significantly for typical transformer weight matrices (aspect
ratio ~5). Uses fp16 with a mid-iteration restart for numerical stability.

Key changes:
- Add _zeropower_via_gram_newtonschulz with addmm fusion optimizations
- Add ns_method parameter ("gram" or "standard") defaulting to "gram"
- Use accelerator API for compute dtype selection instead of hardcoded bf16
- Remove torch.compile decorator from NS functions (no measurable benefit)
- Thread ns_method through ZeRO Stage 1/2/3, engine whitelist, and docs

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
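The Gram-matrix trick described in this commit can be sketched independently of DeepSpeed: run each quintic polynomial step on the small n x n Gram matrix, accumulate a transform T, and multiply X only once at the end. This is an illustrative sketch using the standard Muon quintic coefficients; it omits the fp16 mid-iteration restart mentioned above:

```python
import torch

def gram_newton_schulz(X: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
    """Orthogonalize X via Newton-Schulz on the Gram matrix A = X @ X.T.

    Mathematically equivalent to iterating X <- aX + b(XX^T)X + c(XX^T)^2 X,
    but each step costs O(n^3) on the small n x n side instead of O(n^2 m).
    """
    a, b, c = 3.4445, -4.7750, 2.0315           # standard Muon quintic coefficients
    transpose = X.size(0) > X.size(1)
    if transpose:                                # keep the short side as rows
        X = X.T
    X = X / (X.norm() + eps)                     # bound spectral norm by 1
    A = X @ X.T                                  # small n x n Gram matrix
    I = torch.eye(A.size(0), dtype=X.dtype, device=X.device)
    T = I.clone()
    for _ in range(steps):
        P = a * I + b * A + c * (A @ A)          # quintic step as a polynomial in A
        T = P @ T                                # accumulate the transform
        A = P @ A @ P                            # P symmetric: Gram matrix of P @ X_k
    Y = T @ X                                    # single n x m multiply at the end
    return Y.T if transpose else Y
```

Because P is a symmetric polynomial in A, the transformed Gram matrix is P A Pᵀ = P A P, so the iteration never touches the full m dimension until the final multiply.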
Enable Muon optimizer with ZeRO Stage 2 CPU offload. The Newton-Schulz
orthogonalization always runs on GPU for performance (momentum is
temporarily moved to GPU), while the momentum buffer stays on CPU to
save GPU memory.

The _apply_muon_update_for_cpu_offload method intercepts the gradient
copy path in copy_grads_in_partition to apply muon_update before
writing to the CPU FP32 grad buffer. Cross-boundary parameters are
handled by processing the full gradient on each involved rank.

Includes cosimulation test verifying offload vs non-offload produce
consistent results.

Signed-off-by: Ma, Guokai <guokai.ma@gmail.com>
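The offload dance this commit describes can be sketched as follows (helper names are hypothetical; `newton_schulz` stands in for the orthogonalization step, and the device is whatever accelerator holds the gradient):

```python
import torch

def apply_muon_update_cpu_offload(grad: torch.Tensor,
                                  momentum_cpu: torch.Tensor,
                                  beta: float,
                                  newton_schulz) -> torch.Tensor:
    """Sketch: momentum persists on CPU; Newton-Schulz runs on grad's device."""
    # Temporary device-side scratch copy of the momentum buffer.
    m = momentum_cpu.to(grad.device, copy=True, non_blocking=True)
    m.mul_(beta).add_(grad)                     # momentum update on device
    update = newton_schulz(m)                   # NS orthogonalization on device
    momentum_cpu.copy_(m, non_blocking=True)    # persist momentum back on CPU
    return update                               # written into the fp32 grad buffer
```

With `copy=True` the scratch buffer is created even when gradient and momentum share a device, so the same flow works on CPU-only machines.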
@delock delock force-pushed the gma/muon_cpuoffload branch from 0a4d98e to 456b565 Compare April 9, 2026 06:44
@delock delock changed the title [DRAFT] feat(zero2): add CPU offload support for Muon optimizer feat(zero2): add CPU offload support for Muon optimizer Apr 9, 2026
@delock delock marked this pull request as ready for review April 9, 2026 07:38

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: 456b5650b1


    Q = Z.clone()
    Q.diagonal().add_(a)
else:
    Q = torch.addmm(Q, Z, Q, beta=a, alpha=1.0)

P1: Support batched tensors in Gram Newton-Schulz update

The new default ns_method="gram" regresses Muon's stated batched input support (grad.ndim >= 2): this path uses torch.addmm, which only accepts 2D inputs, so a Muon parameter with shape like (B, N, M) will now fail at runtime in _zeropower_via_gram_newtonschulz. Previously, the standard Newton-Schulz implementation used batched matmuls and handled these shapes, so this commit introduces a crash for valid prior inputs unless users manually switch to ns_method="standard".
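A minimal reproduction of the 2-D restriction, together with the batched counterpart (`torch.baddbmm`) that the Gram path could dispatch to when `grad.ndim == 3` (a sketch of one possible fix, not the PR's code):

```python
import torch

a = 3.4445
Q2 = torch.randn(4, 4)
Z2 = torch.randn(4, 4)
out2 = torch.addmm(Q2, Z2, Q2, beta=a, alpha=1.0)   # fine: addmm is 2-D only

Q3 = torch.randn(2, 4, 4)                           # batched Muon parameter (B, N, M)
Z3 = torch.randn(2, 4, 4)
try:
    torch.addmm(Q3, Z3, Q3, beta=a, alpha=1.0)      # raises for 3-D inputs
except RuntimeError as e:
    print("addmm rejects 3-D inputs:", type(e).__name__)

# Batched equivalent: same fused multiply-accumulate, one matmul per batch.
out3 = torch.baddbmm(Q3, Z3, Q3, beta=a, alpha=1.0)
print(out3.shape)                                   # torch.Size([2, 4, 4])
```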


Comment on lines 1587 to +1590
  self.update_offload_overflow_tracker_for_param_grad(param)

- self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)
+ if not self._apply_muon_update_for_cpu_offload(param):
+     self.async_inplace_copy_grad_to_fp32_buffer_from_gpu(param)

P2: Track Muon overflow after CPU-offload update

In the CPU-offload path, overflow tracking is performed before _apply_muon_update_for_cpu_offload(param), but has_overflow() for offload relies only on self.local_overflow. Because no post-muon_update inf/nan check is done, numerical failures introduced by the Newton-Schulz step can bypass loss-scaling overflow handling and still be applied to optimizer state. This diverges from the non-offload path, where overflow is checked on the Muon-transformed gradients.
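A hedged sketch of the post-update check (helper and attribute names are hypothetical; the point is to inspect the Newton-Schulz output before it reaches the CPU fp32 buffer, mirroring how the non-offload path checks transformed gradients):

```python
import torch

def update_overflow_after_muon(update: torch.Tensor, state: dict) -> None:
    """Mark local overflow if the Newton-Schulz output contains inf/nan.

    Sketch only: in the real optimizer this would set self.local_overflow so
    has_overflow() trips loss scaling before the update lands in optimizer
    state.
    """
    if not torch.isfinite(update).all():
        state["local_overflow"] = True
```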


    raise ValueError(f"All Muon parameter groups must have the same momentum (beta). "
                     f"Found {self.muon_beta} and {group_beta}.")
self.muon_beta = group_beta
self.muon_ns_method = param_group.get('ns_method', 'gram')

P2: Preserve ns_method per Muon param group in ZeRO-3

ZeRO-3 stores ns_method in a single self.muon_ns_method while iterating all Muon param groups, so later groups overwrite earlier values. _muon_update_grads_in_place then applies that one method to every Muon subgroup, which silently ignores per-group configuration when users provide multiple Muon groups (a pattern already handled for momentum via explicit consistency checks).
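One way to preserve per-group settings, sketched in plain Python (the `use_muon` group flag and the collection attribute are illustrative names, not the PR's actual code):

```python
# Sketch: record ns_method per Muon param group instead of a single shared
# field, mirroring the explicit per-group handling already done for momentum.
def collect_muon_ns_methods(param_groups):
    """Return the ns_method for each Muon group, keyed by group index."""
    methods = {}
    for i, group in enumerate(param_groups):
        if group.get("use_muon", False):          # hypothetical Muon group flag
            methods[i] = group.get("ns_method", "gram")
    return methods

groups = [
    {"use_muon": True, "ns_method": "standard"},
    {"use_muon": False},
    {"use_muon": True},                           # falls back to the "gram" default
]
print(collect_muon_ns_methods(groups))            # {0: 'standard', 2: 'gram'}
```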

