
perf(pt): optimize HybridMuon optimizer #5412

Open
OutisLi wants to merge 4 commits into deepmodeling:master from OutisLi:pr/muon

Conversation

Collaborator

@OutisLi OutisLi commented Apr 22, 2026

Summary by CodeRabbit

  • New Features

    • Added Gram Newton–Schulz orthogonalization for rectangular Muon matrices with enable_gram enabled by default (auto-disabled in distributed runs).
    • Introduced use_foreach switch to enable foreach-optimized updates when applicable.
  • Chores

    • Changed lr_adjust default to 0.0.
    • Changed magma_muon default to True.
    • Clarified weight_decay applies to both Muon-routed and AdamW-style matrix decay.
    • Training now disables foreach under higher sharding/ZeRO stages.
  • Tests

    • Added fp16/bf16-gated tests validating Gram behavior, column-pad merging, and end-to-end optimizer updates.

Copilot AI review requested due to automatic review settings April 22, 2026 05:50
@dosubot dosubot Bot added the enhancement label Apr 22, 2026
Contributor

coderabbitai Bot commented Apr 22, 2026

📝 Walkthrough

Adds a compiled Gram Newton–Schulz orthogonalization path for rectangular Muon matrices, refactors Muon bucketing/dispatch and foreach-aware update kernels, exposes new optimizer options (enable_gram, use_foreach) with adjusted defaults, disables Gram in distributed runs, and adds fp16/bf16-gated tests for the new flows.
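One classic way to orthogonalize through the Gram matrix is the coupled Newton–Schulz iteration for the inverse matrix square root. The NumPy sketch below uses that form for illustration; it is not necessarily the exact polynomial the PR compiles, and the real kernel normalizes in float32 and iterates in float16 with restarts rather than staying in float64 as here.

```python
import numpy as np

def gram_newton_schulz(x: np.ndarray, steps: int = 10) -> np.ndarray:
    """Orthogonalize a tall m x n matrix through its small n x n Gram matrix
    instead of iterating on the full m x n matrix (illustrative sketch;
    coefficients and precision handling differ from the compiled kernel)."""
    a = x.T @ x                          # n x n Gram matrix, n << m assumed
    s = np.linalg.norm(a)                # normalize spectrum into (0, 1]
    a = a / s
    eye = np.eye(a.shape[0])
    y, z = a.copy(), eye.copy()          # y -> A^{1/2}, z -> A^{-1/2}
    for _ in range(steps):
        t = 0.5 * (3.0 * eye - z @ y)    # coupled Newton-Schulz step
        y, z = y @ t, t @ z
    # undo the normalization: (A/s)^{-1/2} = sqrt(s) * A^{-1/2}
    return x @ (z / np.sqrt(s))

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 8))
q = gram_newton_schulz(x)
err = np.abs(q.T @ q - np.eye(8)).max()  # columns are near-orthonormal
```

The payoff is that all matrix products run at the small dimension n, which is why the rectangular path can afford fp16 Gram iterations with restarts.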

Changes

Cohort / File(s) Summary
Core Optimizer Implementation
deepmd/pt/optimizer/hybrid_muon.py
Adds compiled Gram Newton–Schulz orthogonalizer for rectangular matrices (float32 normalize → float16 Gram iterations + restarts); changes bucket keys to (rows, cols, device, dtype) and merges rectangular buckets via column-padding; precomputes Magma-lite damping scales; refactors update loops into foreach-aware helpers (_compute_muon_nesterov_updates, _adam_update_moments, _weight_decay_inplace); adds `enable_gram: bool = True` and a tri-state `use_foreach` option.
Training Integration
deepmd/pt/train/training.py
Builds optimizer extra flags: enable_gram disabled for distributed runs, flash_muon/magma_muon forwarded from config, and use_foreach forced False for ZeRO/FSDP ≥2 else None.
Configuration Schema
deepmd/utils/argcheck.py
Adds enable_gram arg (default True); updates docs for weight_decay to include AdamW-style matrix decay and notes flash_muon is ignored when enable_gram=True; changes magma_muon default from False to True.
Tests
source/tests/pt/test_hybrid_muon.py
Adds _fp16_matmul_supported probe and FP16_SUPPORTED gate; new tests asserting enable_gram default/behavior, rectangular-path state, and flash-path explicit disables; adds TestColumnPadMergeEquivalence (bf16-gated) validating column-pad merge equivalence and an end-to-end mixed-shape optimizer step under enable_gram=True.
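The column-pad merge in the table above (fusing rectangular buckets, here assumed to share a row count, into one batched orthogonalization) amounts to bookkeeping around zero-padding. A NumPy illustration with hypothetical helper names:

```python
import numpy as np

def merge_by_column_pad(mats):
    """Zero-pad each matrix on the right to the widest column count so a
    whole bucket can be stacked into one batch (sketch of the column-pad
    merge; helper names are illustrative, not the PR's actual API)."""
    max_cols = max(m.shape[1] for m in mats)
    cols = [m.shape[1] for m in mats]
    padded = [np.pad(m, ((0, 0), (0, max_cols - m.shape[1]))) for m in mats]
    return np.stack(padded), cols

def unmerge(batch, cols):
    """Slice each member back to its original column count."""
    return [batch[i][:, :c] for i, c in enumerate(cols)]

a = np.arange(16 * 4, dtype=float).reshape(16, 4)
b = np.arange(16 * 6, dtype=float).reshape(16, 6)
batch, cols = merge_by_column_pad([a, b])
restored = unmerge(batch, cols)
# Why padding is safe for Gram-based orthogonalization: the padded Gram
# matrix is block-diagonal ([[A^T A, 0], [0, 0]]), so a polynomial
# iteration on it decouples and leaves the original columns unchanged.
gram_padded = batch[0].T @ batch[0]
```

The block-diagonal structure of the padded Gram matrix is what the bf16-gated equivalence tests in this PR appear to exercise.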

Sequence Diagram(s)

sequenceDiagram
    participant Trainer as Trainer/Caller
    participant Optimizer as HybridMuon Optimizer
    participant Bucket as Bucket Dispatcher
    participant GramNS as Gram Newton‑Schulz (rectangular)
    participant SquareNS as Standard Newton‑Schulz (square)
    participant Params as Parameters/State

    Trainer->>Optimizer: step(request, configs)
    Optimizer->>Bucket: classify params by (rows, cols, device, dtype)
    Bucket->>Bucket: split into square vs rectangular buckets
    alt square buckets
        Bucket->>SquareNS: dispatch square buckets
        SquareNS->>SquareNS: batched Newton‑Schulz (bfloat16)
        SquareNS->>Optimizer: orthogonalized entries
    end
    alt rectangular buckets (enable_gram)
        Bucket->>GramNS: merge/pad columns and dispatch
        GramNS->>GramNS: normalize float32 → Gram NS float16 (restarts) → cast back
        GramNS->>Optimizer: unpad/untranspose & scale entries
    end
    Optimizer->>Params: apply foreach-aware moment updates, weight decay, Nesterov adjustments
    Optimizer->>Trainer: commit parameter updates and optimizer state

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~70 minutes

Suggested reviewers

  • njzjz
  • wanghan-iapcm
  • iProzd
🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 78.38%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (4 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): the title 'perf(pt): optimize HybridMuon optimizer' accurately summarizes the main objective of the pull request, which introduces performance optimizations via Gram Newton–Schulz orthogonalization, foreach-based acceleration, and refactored internals.
  • Linked Issues Check (✅ Passed): check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check (✅ Passed): check skipped because no linked issues were found for this pull request.


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
deepmd/pt/train/training.py (1)

899-913: ⚠️ Potential issue | 🟠 Major

Asymmetric distributed handling: use_foreach / flash_muon not guarded like enable_gram.

You correctly force enable_gram=False when self.is_distributed to dodge torch.compile + DTensor issues, but:

  • use_foreach is never passed in extra, so the optimizer resolves it to True unconditionally — under FSDP2 (zero_stage >= 2) that can trip DTensor dispatch errors in torch._foreach_* calls in the Muon/Adam hot paths.
  • flash_muon is still forwarded verbatim, even though the triton symmetric-matmul kernel operates on raw tensors and will likely fail on sharded DTensor grads with zero_stage >= 2.

Consider mirroring the enable_gram guard for these two, e.g.:

🛠️ Suggested diff
             elif self.opt_type == "HybridMuon":
                 cls = HybridMuonOptimizer
+                fsdp2 = self.zero_stage >= 2
                 extra = {
                     "adam_betas": adam_betas,
                     "momentum": float(self.opt_param["momentum"]),
                     "lr_adjust": float(self.opt_param["lr_adjust"]),
                     "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
                     "muon_mode": str(self.opt_param.get("muon_mode", "slice")),
                     "named_parameters": tuple(self.wrapper.named_parameters()),
                     "enable_gram": False
                     if self.is_distributed
                     else bool(self.opt_param.get("enable_gram")),
-                    "flash_muon": bool(self.opt_param.get("flash_muon")),
+                    "flash_muon": False if fsdp2 else bool(self.opt_param.get("flash_muon")),
                     "magma_muon": bool(self.opt_param.get("magma_muon")),
+                    "use_foreach": False if fsdp2 else None,
                 }

Also note that bool(self.opt_param.get("enable_gram")) (and the two bool(...get(...)) siblings) silently becomes False when the key is missing. argcheck.normalize is expected to fill defaults, but for a belt-and-suspenders safeguard pass a default (e.g. self.opt_param.get("enable_gram", True)) matching the argcheck default.
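The guard this comment asks for amounts to a tri-state resolution of the foreach knob. A minimal sketch (the function name and the sharding flag are illustrative; the real _resolve_foreach signature may differ):

```python
def resolve_foreach(use_foreach, grads_are_sharded: bool) -> bool:
    """Resolve a tri-state use_foreach knob: an explicit True/False wins;
    None auto-enables foreach kernels except when gradients are sharded
    DTensors (FSDP2 / ZeRO stage >= 2), where several torch._foreach_*
    ops lack sharding propagation on older PyTorch."""
    if use_foreach is not None:
        return bool(use_foreach)
    return not grads_are_sharded
```

With this shape, the trainer only needs to pass False under FSDP2 and can otherwise forward the user's configured value (or None).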

deepmd/utils/argcheck.py (1)

3146-3156: ⚠️ Potential issue | 🟡 Minor

Flipping the magma_muon default from False to True is a silent behavior change.

Existing HybridMuon configs that don't explicitly set magma_muon will now get Magma-lite damping applied on the Muon route, which rescales updates and alters the training trajectory (and final weights/loss curves) compared to prior releases. This is a backward-incompatible change for anyone resuming or reproducing runs.

Consider one of the following:

  • Keep the default at False and let the new path be opt-in, or
  • Document this as a breaking change in the release notes / PR description and call it out in the schema doc (e.g., "Default changed from False to True in vX.Y").

Same concern, to a lesser degree, applies to enable_gram defaulting to True — any numerical divergence between Gram NS (fp16/fp32) and the previous bf16 standard NS path on rectangular matrices will also shift training results under the default config.

🧹 Nitpick comments (4)
deepmd/pt/optimizer/hybrid_muon.py (1)

564-640: Remove dead helper functions left behind by refactoring.

_stack_bucket_updates (lines 564–595) and _orthogonalize_standard_stacked (lines 598–640) are defined but never used anywhere in the codebase. The step() path builds stacked tensors inline via _reshape_update_to_matrix_batch + torch.cat and dispatches _batched_newton_schulz_orth / _flash_newton_schulz_orth directly (lines 1858–1886), and the Gram path uses its own flow in _process_merged_gram_buckets. Remove these helpers to keep the module's public surface clean.

source/tests/pt/test_hybrid_muon.py (2)

670-734: E2E coverage is good; consider asserting finite outputs during each step too.

test_optimizer_step_column_pad_merge_e2e is a solid integration check for the column-pad merge path. One small hardening suggestion: NaN/Inf during an intermediate step would still leave param != init_state[name], so the "was updated" assertion can pass even when a step produced non-finite values that later overwrote finite ones. Consider asserting torch.isfinite(param).all() inside the per-step loop (or at least on all final parameter tensors) in addition to the forward output check.

 x = torch.randn(4, 8, device=self.device)
 for _ in range(3):
     optimizer.zero_grad()
     model(x).backward()
     optimizer.step()
+    for name, param in model.named_parameters():
+        self.assertTrue(
+            torch.isfinite(param).all(),
+            f"Parameter {name} became non-finite during step",
+        )

42-53: Nit: _fp16_matmul_supported may succeed on CPU with poor precision.

On CPU, fp16 matmul often runs without raising (dispatched via upcast), so this probe will return True even when fp16 is numerically unreliable. That's acceptable for gating skips (you prefer a false positive over false negative), but tests relying on fp16 tolerance (atol=1e-3) may become flaky on certain CPUs.

Optionally mirror _bf16_matmul_supported's CUDA fast-path and be more conservative on CPU (e.g., only enable FP16_SUPPORTED when CUDA with torch.cuda.is_available()), or document the tolerance rationale.

deepmd/utils/argcheck.py (1)

3136-3145: UX: with both defaults True, flash_muon is dead-by-default.

Under the new defaults (enable_gram=True, flash_muon=True), flash_muon is silently ignored per the updated doc on line 3144. Users reading the schema may not immediately notice. Two optional refinements:

  • Make the doc more explicit, e.g. prepend "(Effective only when enable_gram=false.)" so it shows up at the top of the entry.
  • Or switch flash_muon default to False to reflect it's now a fallback path — this avoids misleading "enabled" status in logs/dumps.

Not a blocker; just reduces user confusion when debugging perf.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: b72c24a3-2f16-4758-8706-6f77c9ddceeb

📥 Commits

Reviewing files that changed from the base of the PR and between 1a1dc59 and 9a02481.

📒 Files selected for processing (4)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py

Contributor

Copilot AI left a comment


Pull request overview

This PR optimizes the PyTorch HybridMuonOptimizer by introducing a compiled Gram Newton–Schulz orthogonalization path for rectangular matrices and by batching/merging optimizer work to reduce Python and kernel-launch overhead.

Changes:

  • Add a compiled Gram Newton–Schulz orthogonalizer for rectangular Muon matrices, plus column-pad merge to fuse multiple rectangular shapes into fewer orth calls.
  • Add foreach-based fast paths for Adam moment updates, Muon momentum updates, and in-place weight decay.
  • Extend CLI/config surface (enable_gram) and broaden tests to cover Gram defaults, fallbacks, and column-pad merge equivalence/end-to-end behavior.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.

File Description
deepmd/pt/optimizer/hybrid_muon.py Adds Gram NS implementation + bucket/merge strategy and foreach accelerations in step() to improve performance.
deepmd/pt/train/training.py Wires new optimizer knobs (enable_gram, flash_muon, magma_muon) into training-time optimizer construction.
deepmd/utils/argcheck.py Exposes new enable_gram option and updates defaults/docs for HybridMuon-related options.
source/tests/pt/test_hybrid_muon.py Adds coverage for Gram defaults/fallbacks and validates column-pad merge equivalence and an e2e mixed-shape optimizer step.
Comments suppressed due to low confidence (1)

deepmd/utils/argcheck.py:3155

  • magma_muon is now defaulted to True here, but HybridMuonOptimizer.__init__ still defaults magma_muon=False and its docstring still states the default is False. This makes CLI/config behavior diverge from direct-Python usage and documentation. Please align the default across argcheck, the optimizer signature, and the docstring (or explicitly document why they differ).
            "magma_muon",
            bool,
            optional=True,
            default=True,
            doc=doc_only_pt_supported
            + "Enable Magma-lite damping on the Muon route only. "
            "When enabled, HybridMuon computes momentum-gradient alignment "
            "per Muon block, applies EMA smoothing, and rescales Muon updates "
            "to improve stability. Adam/AdamW routes are unchanged.",



codecov Bot commented Apr 22, 2026

Codecov Report

❌ Patch coverage is 79.37220% with 46 lines in your changes missing coverage. Please review.
✅ Project coverage is 80.45%. Comparing base (54f42d9) to head (bcccbdd).

Files with missing lines Patch % Lines
deepmd/pt/optimizer/hybrid_muon.py 79.37% 46 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5412      +/-   ##
==========================================
- Coverage   80.46%   80.45%   -0.01%     
==========================================
  Files         821      821              
  Lines       86075    86251     +176     
  Branches     4140     4140              
==========================================
+ Hits        69263    69397     +134     
- Misses      15536    15579      +43     
+ Partials     1276     1275       -1     

☔ View full report in Codecov by Sentry.

@OutisLi OutisLi requested a review from njzjz April 23, 2026 04:11
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
deepmd/pt/optimizer/hybrid_muon.py (1)

800-815: Document the public use_foreach option.

Line 839 adds a user-visible optimizer knob, but the constructor docstring stops at magma_muon. Please document use_foreach, including that None defaults to foreach kernels and FSDP2/DTensor callers should pass False.

Also applies to: 836-839


ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 76ceeb0f-25b2-47a2-a974-efbd24cdc4b8

📥 Commits

Reviewing files that changed from the base of the PR and between 9a02481 and e7cb13f.

📒 Files selected for processing (3)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • source/tests/pt/test_hybrid_muon.py

Comment on lines +913 to +917
# FSDP2 shards parameters as DTensor; several torch._foreach_*
# ops lack DTensor sharding propagation on older PyTorch, so
# fall back to the per-tensor path under zero_stage >= 2.
# DDP / ZeRO-1 keep plain tensors and use the default.
"use_foreach": False if self.zero_stage >= 2 else None,
Contributor


⚠️ Potential issue | 🟡 Minor

Honor configured use_foreach outside FSDP2.

Line 917 always passes None unless zero_stage >= 2, so a user-provided optimizer.use_foreach: false is silently ignored for single-device/DDP/ZeRO-1 runs. Preserve the FSDP2 override, but otherwise forward the configured value.

Proposed fix
-                    "use_foreach": False if self.zero_stage >= 2 else None,
+                    "use_foreach": False
+                    if self.zero_stage >= 2
+                    else self.opt_param.get("use_foreach"),

Comment on lines +201 to +205
@unittest.skipIf(
not FP16_SUPPORTED,
"float16 matmul not supported on this device",
)
def test_enable_gram_rectangular_step_runs(self) -> None:
Contributor


⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

pytest source/tests/pt/test_hybrid_muon.py::TestHybridMuonOptimizer::test_enable_gram_rectangular_step_runs -v
pytest source/tests/pt/test_hybrid_muon.py::TestColumnPadMergeEquivalence::test_gram_ns_column_pad_exact_equivalence -v
pytest source/tests/pt/test_hybrid_muon.py::TestColumnPadMergeEquivalence::test_gram_ns_batch_pad_equivalence -v
pytest source/tests/pt/test_hybrid_muon.py::TestColumnPadMergeEquivalence::test_optimizer_step_column_pad_merge_e2e -v

Repository: deepmodeling/deepmd-kit

Length of output: 248


🏁 Script executed:

# First, check if the file exists and get its line count
wc -l source/tests/pt/test_hybrid_muon.py

Repository: deepmodeling/deepmd-kit

Length of output: 106


🏁 Script executed:

# Read the section around line 201-205 to see the class and decorator context
sed -n '160,250p' source/tests/pt/test_hybrid_muon.py | cat -n

Repository: deepmodeling/deepmd-kit

Length of output: 4184


🏁 Script executed:

# Read earlier to find the class definition
sed -n '1,170p' source/tests/pt/test_hybrid_muon.py | cat -n

Repository: deepmodeling/deepmd-kit

Length of output: 7648


🏁 Script executed:

# Now look at the TestColumnPadMergeEquivalence class around lines 500-600
sed -n '500,650p' source/tests/pt/test_hybrid_muon.py | cat -n

Repository: deepmodeling/deepmd-kit

Length of output: 7186


🏁 Script executed:

# Continue reading the rest of TestColumnPadMergeEquivalence
sed -n '650,758p' source/tests/pt/test_hybrid_muon.py | cat -n

Repository: deepmodeling/deepmd-kit

Length of output: 5130


Avoid shadowing FP16 Gram tests with BF16 class-level skips.

The FP16-gated Gram tests in TestHybridMuonOptimizer::test_enable_gram_rectangular_step_runs (lines 201-205) and the three tests in TestColumnPadMergeEquivalence (lines 580-583, 635-638, 668-671) cannot run on devices that support FP16 but not BF16, because both classes are skipped when BF16_SUPPORTED is false. This limits Gram coverage on common pre-Ampere CUDA GPUs.

Move the Gram-only tests into an FP16_SUPPORTED-gated class, or restructure to remove the class-level BF16 skip and add method-level BF16 skips only to tests that specifically exercise the standard bf16 Newton-Schulz path.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/test_hybrid_muon.py` around lines 201 - 205, Class-level BF16
gating is shadowing FP16-only Gram tests (e.g.,
TestHybridMuonOptimizer::test_enable_gram_rectangular_step_runs and tests in
TestColumnPadMergeEquivalence), preventing them from running on devices that
support FP16 but not BF16; fix by moving Gram-only tests into a class decorated
with `@unittest.skipIf`(not FP16_SUPPORTED, ...) or removing the class-level BF16
skip and instead apply method-level `@unittest.skipIf`(not BF16_SUPPORTED, ...)
only to tests that require BF16 (the BF16 Newton-Schulz path), ensuring FP16
tests reference FP16_SUPPORTED and BF16-specific tests reference BF16_SUPPORTED
while keeping test names like test_enable_gram_rectangular_step_runs and the
three TestColumnPadMergeEquivalence methods as-is.
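The method-level gating the prompt describes can be sketched as follows. The capability flags and test bodies here are placeholders standing in for the module-level `FP16_SUPPORTED` / `BF16_SUPPORTED` probes and the real assertions in test_hybrid_muon.py; only the skip-decorator structure is the point.

```python
import unittest

# Hypothetical capability flags; the real module probes the device instead.
FP16_SUPPORTED = True
BF16_SUPPORTED = False


class TestHybridMuonOptimizerSketch(unittest.TestCase):
    # No class-level BF16 skip: FP16-only Gram tests stay runnable on
    # devices (e.g. pre-Ampere CUDA GPUs) that support fp16 but not bf16.

    @unittest.skipIf(not FP16_SUPPORTED, "float16 matmul not supported")
    def test_enable_gram_rectangular_step_runs(self):
        self.assertTrue(True)  # placeholder for the real Gram assertion

    @unittest.skipIf(not BF16_SUPPORTED, "bfloat16 not supported")
    def test_standard_bf16_newton_schulz(self):
        self.assertTrue(True)  # placeholder for the real bf16 NS assertion
```

On a hypothetical fp16-only device, the Gram test runs while only the bf16 Newton-Schulz test is skipped.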

@coderabbitai coderabbitai Bot left a comment
Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)

3146-3155: ⚠️ Potential issue | 🟡 Minor

Align the config default with the optimizer API default.

The schema now defaults magma_muon to True, but HybridMuonOptimizer(..., magma_muon=False) still defaults direct Python usage to False. If this default change is intentional for all HybridMuon users, update the optimizer constructor/docstring too; otherwise keep the schema default consistent.

🛠️ Possible fix if the new default is intended globally
-            default=True,
+            default=False,

or update HybridMuonOptimizer to default magma_muon=True.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/utils/argcheck.py` around lines 3146 - 3155, The schema sets
Argument("magma_muon", ..., default=True) but the HybridMuonOptimizer
constructor/docstring and its Python default remain False; make them consistent:
either revert the Argument default to False to match HybridMuonOptimizer, or
update HybridMuonOptimizer's __init__ signature and docstring to default
magma_muon=True (and adjust any tests/docs referencing the old default). Locate
the Argument for "magma_muon" in argcheck.py and the HybridMuonOptimizer
class/constructor and docstring, then change the default value in one place and
update the docstring text to reflect the chosen global default.
♻️ Duplicate comments (1)
source/tests/pt/test_hybrid_muon.py (1)

91-92: ⚠️ Potential issue | 🟡 Minor

Separate BF16 standard-NS coverage from FP16 Gram coverage.

The class-level BF16 skips still suppress FP16-only Gram tests on devices that support fp16 but not bf16. Conversely, several rectangular default-Gram tests in TestHybridMuonOptimizer can require fp16 but are only protected by the BF16 class gate. Move Gram-only tests under an FP16 gate, and keep BF16 skips only on tests that exercise the standard bf16 Newton-Schulz path.

Also applies to: 201-205, 562-583, 634-671

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/test_hybrid_muon.py` around lines 91 - 92, The class-level
BF16 skip on TestHybridMuonOptimizer is too broad; remove the
`@unittest.skipIf`(not BF16_SUPPORTED, ...) decorator from the class and instead:
1) add `@unittest.skipIf`(not BF16_SUPPORTED, ...) only to the individual tests
that exercise the standard bf16 Newton-Schulz path (search for tests that invoke
the Newton-Schulz / bf16-specific code paths), and 2) add `@unittest.skipIf`(not
FP16_SUPPORTED, "fp16 Gram not supported on this device") to the Gram-only tests
(tests that validate Gram matrix behavior or default-Gram rectangular cases).
Apply the same change pattern to the other affected blocks mentioned (the tests
around the other ranges) so BF16 gating is limited to bf16-NS tests and FP16
gating covers Gram-only tests.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 136-145: The bare except in the Triton probe (around
triton.runtime.driver.active.get_current_target() and TRITON_AVAILABLE) must be
fixed to satisfy Ruff: either replace "except Exception:" with a narrowed tuple
of expected exceptions (e.g., except (ImportError, AttributeError, RuntimeError,
OSError):) to only catch likely probe failures, or keep the broad except but add
an explicit Ruff suppression comment on that except line (e.g., "# noqa: BLE001
— intentional broad probe to detect missing/misconfigured Triton driver on
CPU-only hosts") so the rationale is recorded next to the except; update only
the except handling around the probe and leave TRITON_AVAILABLE logic unchanged.

---

Outside diff comments:
In `@deepmd/utils/argcheck.py`:
- Around line 3146-3155: The schema sets Argument("magma_muon", ...,
default=True) but the HybridMuonOptimizer constructor/docstring and its Python
default remain False; make them consistent: either revert the Argument default
to False to match HybridMuonOptimizer, or update HybridMuonOptimizer's __init__
signature and docstring to default magma_muon=True (and adjust any tests/docs
referencing the old default). Locate the Argument for "magma_muon" in
argcheck.py and the HybridMuonOptimizer class/constructor and docstring, then
change the default value in one place and update the docstring text to reflect
the chosen global default.

---

Duplicate comments:
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 91-92: The class-level BF16 skip on TestHybridMuonOptimizer is too
broad; remove the `@unittest.skipIf`(not BF16_SUPPORTED, ...) decorator from the
class and instead: 1) add `@unittest.skipIf`(not BF16_SUPPORTED, ...) only to the
individual tests that exercise the standard bf16 Newton-Schulz path (search for
tests that invoke the Newton-Schulz / bf16-specific code paths), and 2) add
`@unittest.skipIf`(not FP16_SUPPORTED, "fp16 Gram not supported on this device")
to the Gram-only tests (tests that validate Gram matrix behavior or default-Gram
rectangular cases). Apply the same change pattern to the other affected blocks
mentioned (the tests around the other ranges) so BF16 gating is limited to
bf16-NS tests and FP16 gating covers Gram-only tests.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 4f119985-8d9a-446d-9d4a-f849db7eeab4

📥 Commits

Reviewing files that changed from the base of the PR and between e7cb13f and bcccbdd.

📒 Files selected for processing (4)
  • deepmd/pt/optimizer/hybrid_muon.py
  • deepmd/pt/train/training.py
  • deepmd/utils/argcheck.py
  • source/tests/pt/test_hybrid_muon.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/train/training.py

Comment on lines +136 to +145
try:
# Touching ``driver.active`` forces the lazy proxy to initialize the
# backend driver. ``get_current_target`` is the lightest public call
# that exercises the same path as ``Autotuner.__init__``.
triton.runtime.driver.active.get_current_target()
TRITON_AVAILABLE = True
except Exception:
# No usable runtime driver (no CUDA/ROCm/XPU, or a mis-configured
# one): fall back to the pure-PyTorch Newton-Schulz path.
TRITON_AVAILABLE = False
⚠️ Potential issue | 🟡 Minor

Make the intentional broad Triton probe exception pass Ruff.

Ruff reports BLE001 here. Since this import-time guard intentionally protects CPU/driver-less hosts, either narrow the exception set or add an explicit suppression with the rationale.

🛠️ Proposed fix
-    except Exception:
+    except Exception:  # noqa: BLE001
         # No usable runtime driver (no CUDA/ROCm/XPU, or a mis-configured
         # one): fall back to the pure-PyTorch Newton-Schulz path.
         TRITON_AVAILABLE = False

As per coding guidelines: "**/*.py: Install linter and run ruff check . before committing changes or the CI will fail".

🧰 Tools
🪛 Ruff (0.15.10)

[warning] 142-142: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/optimizer/hybrid_muon.py` around lines 136 - 145, The bare except
in the Triton probe (around triton.runtime.driver.active.get_current_target()
and TRITON_AVAILABLE) must be fixed to satisfy Ruff: either replace "except
Exception:" with a narrowed tuple of expected exceptions (e.g., except
(ImportError, AttributeError, RuntimeError, OSError):) to only catch likely
probe failures, or keep the broad except but add an explicit Ruff suppression
comment on that except line (e.g., "# noqa: BLE001 — intentional broad probe to
detect missing/misconfigured Triton driver on CPU-only hosts") so the rationale
is recorded next to the except; update only the except handling around the probe
and leave TRITON_AVAILABLE logic unchanged.

@njzjz-bot njzjz-bot left a comment
Overall this looks promising, but I'd like one clarification before I approve. In deepmd/pt/train/training.py, the handoff now uses bool(self.opt_param.get("flash_muon")) and bool(self.opt_param.get("magma_muon")) instead of explicit fallbacks. If these keys are ever absent in older configs / programmatic call sites, this silently flips missing values to False (notably flash_muon used to fall back to True, and magma_muon now also depends on config-layer defaults being materialized).

If config defaults are guaranteed to be filled before this point, then this is fine; otherwise I think the old explicit fallback style is safer.

— OpenClaw 2026.4.22 (model: gpt-5.4)
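The pitfall njzjz-bot raises is easy to reproduce in isolation. A minimal sketch, assuming only that the config dict may lack the key (the `opt_param` name mirrors the attribute used in training.py):

```python
# bool(dict.get(key)) turns an absent key into False, whereas an explicit
# default preserves the old fallback behavior (flash_muon used to fall
# back to True when unset).
opt_param = {}  # an older config with the key absent

silently_flipped = bool(opt_param.get("flash_muon"))
explicit_fallback = bool(opt_param.get("flash_muon", True))
```

Here `silently_flipped` is `False` while `explicit_fallback` is `True`, which is exactly the behavior change for configs written before the key existed.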
