perf(pt): optimize HybridMuon optimizer #5412
OutisLi wants to merge 4 commits into deepmodeling:master from
Conversation
📝 Walkthrough
Adds a compiled Gram Newton–Schulz orthogonalization path for rectangular Muon matrices, refactors Muon bucketing/dispatch and foreach-aware update kernels, and exposes new optimizer options (

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer as Trainer/Caller
    participant Optimizer as HybridMuon Optimizer
    participant Bucket as Bucket Dispatcher
    participant GramNS as Gram Newton-Schulz (rectangular)
    participant SquareNS as Standard Newton-Schulz (square)
    participant Params as Parameters/State
    Trainer->>Optimizer: step(request, configs)
    Optimizer->>Bucket: classify params by (rows, cols, device, dtype)
    Bucket->>Bucket: split into square vs rectangular buckets
    alt square buckets
        Bucket->>SquareNS: dispatch square buckets
        SquareNS->>SquareNS: batched Newton-Schulz (bfloat16)
        SquareNS->>Optimizer: orthogonalized entries
    end
    alt rectangular buckets (enable_gram)
        Bucket->>GramNS: merge/pad columns and dispatch
        GramNS->>GramNS: normalize float32 -> Gram NS float16 (restarts) -> cast back
        GramNS->>Optimizer: unpad/untranspose & scale entries
    end
    Optimizer->>Params: apply foreach-aware moment updates, weight decay, Nesterov adjustments
    Optimizer->>Trainer: commit parameter updates and optimizer state
```
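For orientation, the square-bucket path above is the standard Muon Newton–Schulz iteration. A minimal NumPy sketch of that iteration follows; the quintic coefficients are those used by public Muon implementations, and the helper name is illustrative, not this PR's function:

```python
import numpy as np

# Quintic Newton-Schulz coefficients from public Muon implementations.
MUON_NS_COEFFS = (3.4445, -4.7750, 2.0315)

def newton_schulz_orth(g, steps=5, eps=1e-7):
    """Approximately orthogonalize g so its singular values approach 1."""
    a, b, c = MUON_NS_COEFFS
    # Frobenius norm bounds the spectral norm, so all singular values <= 1.
    x = g / (np.linalg.norm(g) + eps)
    transposed = x.shape[0] > x.shape[1]
    if transposed:  # iterate on the wide orientation so A is the smaller Gram
        x = x.T
    for _ in range(steps):
        A = x @ x.T
        B = b * A + c * (A @ A)
        x = a * x + B @ x
    return x.T if transposed else x
```

The iteration does not converge to an exactly orthogonal matrix; after 5 steps the singular values land in a band around 1, which is sufficient for Muon's update direction.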
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~70 minutes

🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
deepmd/pt/train/training.py (1)

899-913: ⚠️ Potential issue | 🟠 Major
Asymmetric distributed handling: `use_foreach`/`flash_muon` not guarded like `enable_gram`.
You correctly force `enable_gram=False` when `self.is_distributed` to dodge `torch.compile` + DTensor issues, but:
- `use_foreach` is never passed in `extra`, so the optimizer resolves it to `True` unconditionally — under FSDP2 (`zero_stage >= 2`) that can trip DTensor dispatch errors in `torch._foreach_*` calls in the Muon/Adam hot paths.
- `flash_muon` is still forwarded verbatim, even though the triton symmetric-matmul kernel operates on raw tensors and will likely fail on sharded DTensor grads with `zero_stage >= 2`.

Consider mirroring the `enable_gram` guard for these two, e.g.:

🛠️ Suggested diff

```diff
 elif self.opt_type == "HybridMuon":
     cls = HybridMuonOptimizer
+    fsdp2 = self.zero_stage >= 2
     extra = {
         "adam_betas": adam_betas,
         "momentum": float(self.opt_param["momentum"]),
         "lr_adjust": float(self.opt_param["lr_adjust"]),
         "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
         "muon_mode": str(self.opt_param.get("muon_mode", "slice")),
         "named_parameters": tuple(self.wrapper.named_parameters()),
         "enable_gram": False
         if self.is_distributed
         else bool(self.opt_param.get("enable_gram")),
-        "flash_muon": bool(self.opt_param.get("flash_muon")),
+        "flash_muon": False if fsdp2 else bool(self.opt_param.get("flash_muon")),
         "magma_muon": bool(self.opt_param.get("magma_muon")),
+        "use_foreach": False if fsdp2 else None,
     }
```

Also note that `bool(self.opt_param.get("enable_gram"))` (and the two `bool(...get(...))` siblings) silently becomes `False` when the key is missing. `argcheck.normalize` is expected to fill defaults, but for a belt-and-suspenders safeguard pass a default (e.g. `self.opt_param.get("enable_gram", True)`) matching the argcheck default.

deepmd/utils/argcheck.py (1)
3146-3156: ⚠️ Potential issue | 🟡 Minor
Flipping the `magma_muon` default `False` → `True` is a silent behavior change.
Existing HybridMuon configs that don't explicitly set `magma_muon` will now get Magma-lite damping applied on the Muon route, which rescales updates and alters the training trajectory (and final weights/loss curves) compared to prior releases. This is a backward-incompatible change for anyone resuming or reproducing runs.
Consider one of the following:
- Keep the default at `False` and let the new path be opt-in, or
- Document this as a breaking change in the release notes / PR description and call it out in the schema doc (e.g., "Default changed from False to True in vX.Y").

The same concern, to a lesser degree, applies to `enable_gram` defaulting to `True` — any numerical divergence between Gram NS (fp16/fp32) and the previous bf16 standard NS path on rectangular matrices will also shift training results under the default config.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `deepmd/utils/argcheck.py` around lines 3146-3156: the change flips the `Argument("magma_muon", ..., default=True)` default, which silently alters training for configs that don't set it; revert the default to `False` so Magma-lite damping remains opt-in, and either keep or separately document the intentional default flip for `enable_gram` (`Argument("enable_gram", ..., default=True)`) by adding a clear breaking-change note in the PR/changelog and the schema docs; update any unit/integration tests that assume the old defaults.
🧹 Nitpick comments (4)

deepmd/pt/optimizer/hybrid_muon.py (1)

564-640: Remove dead helper functions left behind by refactoring.
`_stack_bucket_updates` (lines 564–595) and `_orthogonalize_standard_stacked` (lines 598–640) are defined but never used anywhere in the codebase. The `step()` path builds stacked tensors inline via `_reshape_update_to_matrix_batch` + `torch.cat` and dispatches `_batched_newton_schulz_orth`/`_flash_newton_schulz_orth` directly (lines 1858–1886), and the Gram path uses its own flow in `_process_merged_gram_buckets`. Remove these helpers to keep the module's public surface clean.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `deepmd/pt/optimizer/hybrid_muon.py` around lines 564-640: delete the full definitions of `_stack_bucket_updates` and `_orthogonalize_standard_stacked`, remove any now-unreferenced imports or symbols they relied on, and update `__all__` or public exports if they were listed there; after removal, run tests/linters to ensure no usages remain — `_reshape_update_to_matrix_batch`, `_batched_newton_schulz_orth`, and `_flash_newton_schulz_orth` remain the canonical paths.

source/tests/pt/test_hybrid_muon.py (2)
670-734: E2E coverage is good; consider asserting finite outputs during each step too.
`test_optimizer_step_column_pad_merge_e2e` is a solid integration check for the column-pad merge path. One small hardening suggestion: NaN/Inf during an intermediate step would still leave `param != init_state[name]`, so the "was updated" assertion can pass even when a step produced non-finite values that later overwrote finite ones. Consider asserting `torch.isfinite(param).all()` inside the per-step loop (or at least on all final parameter tensors) in addition to the forward output check.

```diff
 x = torch.randn(4, 8, device=self.device)
 for _ in range(3):
     optimizer.zero_grad()
     model(x).backward()
     optimizer.step()
+    for name, param in model.named_parameters():
+        self.assertTrue(
+            torch.isfinite(param).all(),
+            f"Parameter {name} became non-finite during step",
+        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `source/tests/pt/test_hybrid_muon.py` around lines 670-734: inside the training loop that runs `optimizer.zero_grad(); model(x).backward(); optimizer.step()`, assert after each step that every parameter tensor in `model.named_parameters()` is finite (use `torch.isfinite`), and/or assert the model forward output is finite; also add a final per-parameter `torch.isfinite` check before the existing "was updated" assertions so the checks are colocated with the update and output assertions.
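The column-pad merge the test exercises rests on simple pad/stack bookkeeping: matrices sharing a row count are zero-padded to a common width, batched through one orthogonalization call, then sliced back. A NumPy sketch with hypothetical helper names (not the PR's actual functions):

```python
import numpy as np

def pad_stack_columns(mats):
    """Zero-pad same-row matrices to a common column count and stack them,
    so a single batched orthogonalization call can cover several shapes."""
    max_cols = max(m.shape[1] for m in mats)
    stacked = np.stack(
        [np.pad(m, ((0, 0), (0, max_cols - m.shape[1]))) for m in mats]
    )
    return stacked, [m.shape[1] for m in mats]

def unpad_stack(stacked, cols):
    """Invert pad_stack_columns: slice each entry back to its true width."""
    return [stacked[i][:, :c] for i, c in enumerate(cols)]
```

The e2e test's job is then to confirm that the pad → batch → unpad round trip produces the same update each matrix would have received individually.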
42-53: Nit: `_fp16_matmul_supported` may succeed on CPU with poor precision.
On CPU, fp16 matmul often runs without raising (dispatched via upcast), so this probe will return `True` even when fp16 is numerically unreliable. That's acceptable for gating skips (you prefer a false positive over a false negative), but tests relying on fp16 tolerance (`atol=1e-3`) may become flaky on certain CPUs.
Optionally mirror `_bf16_matmul_supported`'s CUDA fast-path and be more conservative on CPU (e.g., only enable `FP16_SUPPORTED` when CUDA with `torch.cuda.is_available()`), or document the tolerance rationale.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `source/tests/pt/test_hybrid_muon.py` around lines 42-53: the FP16 probe `_fp16_matmul_supported` currently returns `True` on CPU because fp16 ops may be upcast there, causing false positives and flaky `atol=1e-3` tests; change `_fp16_matmul_supported` to mirror the BF16 fast-path by gating FP16 support to CUDA devices only (check `torch.cuda.is_available()` or follow the same CUDA-specific checks used in `_bf16_matmul_supported`) and update the `FP16_SUPPORTED` assignment so CPU-only environments do not report `FP16_SUPPORTED=True`.

deepmd/utils/argcheck.py (1)
3136-3145: UX: with both defaults `True`, `flash_muon` is dead-by-default.
Under the new defaults (`enable_gram=True`, `flash_muon=True`), `flash_muon` is silently ignored per the updated doc on line 3144. Users reading the schema may not immediately notice. Two optional refinements:
- Make the doc more explicit, e.g. prepend "(Effective only when `enable_gram=false`.)" so it shows up at the top of the entry.
- Or switch the `flash_muon` default to `False` to reflect that it's now a fallback path — this avoids a misleading "enabled" status in logs/dumps.

Not a blocker; just reduces user confusion when debugging perf.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `deepmd/utils/argcheck.py` around lines 3136-3145: the `flash_muon` Argument is effectively ignored when `enable_gram=True`, which makes its default of `True` misleading; either (a) prepend a clear note such as "(Effective only when `enable_gram=false`.)" to the doc string to surface that dependency, or (b) change the default from `True` to `False` so `flash_muon` reflects the fallback behavior now that `enable_gram` defaults to `True`.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 1404-1406: The loop in hybrid_muon.py iterates "for (min_dim,
_dev, _dt), sub_list in super_buckets.items()" but never uses min_dim, causing
Ruff B007; rename the unused loop variable to _min_dim (i.e., change min_dim ->
_min_dim) in that for-statement (the loop that computes padded_max_dim using
max(r, c) over sub_list) so the linter recognizes it as intentionally unused and
the warning is silenced.
- Around line 961-978: The current default in _resolve_foreach returns True even
for FSDP2/DTensor cases; fix by plumbing a use_foreach flag through the training
stack or by autodetection: either (A) add an explicit use_foreach argument in
the same place enable_gram is passed (training.py) and ensure training
constructs the optimizer with use_foreach=False when self.zero_stage >= 2, or
(B) change the optimizer flow so step() inspects a sample parameter/gradient for
torch.distributed.tensor.DTensor (or checks torch.distributed.is_initialized() +
gradient type) and calls _resolve_foreach(False) when a DTensor is detected;
update _resolve_foreach signature/usage accordingly so foreach is disabled for
FSDP2 by default unless explicitly True.
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 590-596: The comment is wrong about bypassing torch.compile
because calling nn.Module._call_impl does not reliably avoid the compiled path;
replace calls to gram_orth._call_impl(X) and gram_orth._call_impl(X_padded) with
explicit calls to gram_orth.forward(X) and gram_orth.forward(X_padded) (or,
alternatively, remove/change the comment to state you are comparing forward
outputs rather than claiming compilation is bypassed) so the test intent matches
behavior and references the correct entrypoint (use gram_orth.forward instead of
_call_impl).
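Option (B) in the `_resolve_foreach` item above amounts to a small tri-state resolver. A torch-free sketch of its shape follows; the function name and the `sharded_types` hook are assumptions (the real check would be `isinstance(p, torch.distributed.tensor.DTensor)`):

```python
def resolve_foreach(requested, params, sharded_types=()):
    """Resolve a tri-state use_foreach flag.

    An explicit True/False from the user wins; when it is None, default
    to foreach kernels only if no parameter looks sharded, since several
    torch._foreach_* ops lack DTensor sharding propagation.
    """
    if requested is not None:
        return bool(requested)
    return not any(isinstance(p, sharded_types) for p in params)
```

Calling this once at the start of `step()` with a sample of the parameters would make the FSDP2 fallback automatic instead of relying on the trainer to pass the flag.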
📒 Files selected for processing (4)
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_hybrid_muon.py
Pull request overview
This PR optimizes the PyTorch HybridMuonOptimizer by introducing a compiled Gram Newton–Schulz orthogonalization path for rectangular matrices and by batching/merging optimizer work to reduce Python and kernel-launch overhead.
Changes:
- Add a compiled Gram Newton–Schulz orthogonalizer for rectangular Muon matrices, plus column-pad merge to fuse multiple rectangular shapes into fewer orth calls.
- Add foreach-based fast paths for Adam moment updates, Muon momentum updates, and in-place weight decay.
- Extend the CLI/config surface (`enable_gram`) and broaden tests to cover Gram defaults, fallbacks, and column-pad merge equivalence/end-to-end behavior.
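The Gram path above trades an m×n orthogonalization for n×n matmuls: for a tall matrix, the polar factor X(XᵀX)^(-1/2) is computed by running a Newton–Schulz inverse-square-root iteration on the small Gram matrix. A NumPy sketch of the idea (the coupled iteration, step count, and helper name are illustrative, not this PR's compiled kernel):

```python
import numpy as np

def gram_newton_schulz_orth(x, steps=15, eps=1e-7):
    """Polar-factor orthogonalization of a tall matrix via its Gram matrix.

    Computes x @ (x^T x)^{-1/2} with the coupled Newton-Schulz
    inverse-square-root iteration, so the heavy matmuls are n x n
    rather than m x m for an m x n input with m >= n.
    """
    m, n = x.shape
    assert m >= n, "expects a tall matrix; transpose wide inputs first"
    g = x.T @ x                       # n x n Gram matrix
    scale = np.linalg.norm(g) + eps   # Frobenius norm >= spectral norm
    a = g / scale                     # eigenvalues now in (0, 1]
    y, z = a.copy(), np.eye(n)
    for _ in range(steps):
        t = 0.5 * (3.0 * np.eye(n) - z @ y)
        y, z = y @ t, t @ z           # y -> a^{1/2}, z -> a^{-1/2}
    return (x @ z) / np.sqrt(scale)   # equals x @ (x^T x)^{-1/2}
```

Unlike the quintic Muon iteration, this converges to an exactly column-orthonormal result for full-rank input, which is why the PR's float16 variant needs the normalization and restart handling the walkthrough mentions.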
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| deepmd/pt/optimizer/hybrid_muon.py | Adds the Gram NS implementation plus the bucket/merge strategy and foreach accelerations in `step()` to improve performance. |
| deepmd/pt/train/training.py | Wires new optimizer knobs (`enable_gram`, `flash_muon`, `magma_muon`) into training-time optimizer construction. |
| deepmd/utils/argcheck.py | Exposes the new `enable_gram` option and updates defaults/docs for HybridMuon-related options. |
| source/tests/pt/test_hybrid_muon.py | Adds coverage for Gram defaults/fallbacks and validates column-pad merge equivalence and an e2e mixed-shape optimizer step. |
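On the foreach accelerations: `torch._foreach_mul_`/`torch._foreach_add_` replace a Python loop of per-tensor kernel launches with one multi-tensor launch. A pure-NumPy reference for the momentum-update semantics (the fused torch calls are the real fast path; this helper is only an illustration):

```python
import numpy as np

def momentum_update_(bufs, grads, momentum, nesterov=False):
    """Reference semantics for a foreach-style momentum update.

    The fused fast path would be
        torch._foreach_mul_(bufs, momentum)
        torch._foreach_add_(bufs, grads)
    followed by an optional Nesterov lookahead; this plain loop mutates
    the momentum buffers the same way, one tensor at a time.
    """
    for buf, g in zip(bufs, grads):
        buf *= momentum
        buf += g
    if nesterov:
        return [g + momentum * buf for g, buf in zip(grads, bufs)]
    return [buf.copy() for buf in bufs]
```

The win is purely overhead reduction: the math is identical, but hundreds of small parameter tensors are updated in a handful of kernel launches instead of one launch each.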
Comments suppressed due to low confidence (1)
deepmd/utils/argcheck.py:3155
`magma_muon` is now defaulted to `True` here, but `HybridMuonOptimizer.__init__` still defaults `magma_muon=False` and its docstring still states the default is False. This makes CLI/config behavior diverge from direct-Python usage and documentation. Please align the default across argcheck, the optimizer signature, and the docstring (or explicitly document why they differ).

```python
"magma_muon",
bool,
optional=True,
default=True,
doc=doc_only_pt_supported
+ "Enable Magma-lite damping on the Muon route only. "
"When enabled, HybridMuon computes momentum-gradient alignment "
"per Muon block, applies EMA smoothing, and rescales Muon updates "
"to improve stability. Adam/AdamW routes are unchanged.",
```
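The doc string above describes the scheme only loosely; a hypothetical reconstruction for intuition follows. The helper name, the exact EMA form, and the clipping range are all assumptions, not the PR's actual formula:

```python
import numpy as np

def magma_lite_rescale(update, grad, ema_align, beta=0.95, eps=1e-12):
    """Hypothetical sketch: damp a Muon block update when it disagrees
    with the raw gradient.

    Per block: cosine alignment between the (orthogonalized) update and
    the gradient, EMA-smoothed across steps, then mapped to a clipped
    multiplicative damping factor.
    """
    cos = float(
        np.vdot(update, grad)
        / (np.linalg.norm(update) * np.linalg.norm(grad) + eps)
    )
    ema_align = beta * ema_align + (1.0 - beta) * cos
    scale = float(np.clip(0.5 * (1.0 + ema_align), 0.1, 1.0))
    return update * scale, ema_align
```

Whatever the exact formula, the reviewer's point stands: turning any such rescaling on by default changes the update magnitudes relative to prior releases.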
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5412      +/-   ##
==========================================
- Coverage   80.46%   80.45%   -0.01%
==========================================
  Files         821      821
  Lines       86075    86251     +176
  Branches     4140     4140
==========================================
+ Hits        69263    69397     +134
- Misses      15536    15579      +43
+ Partials     1276     1275       -1
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 2
🧹 Nitpick comments (1)
deepmd/pt/optimizer/hybrid_muon.py (1)
800-815: Document the public `use_foreach` option.
Line 839 adds a user-visible optimizer knob, but the constructor docstring stops at `magma_muon`. Please document `use_foreach`, including that `None` defaults to foreach kernels and that FSDP2/DTensor callers should pass `False`.
Also applies to: 836-839

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `deepmd/pt/optimizer/hybrid_muon.py` around lines 800-815: the constructor docstring (where `enable_gram`, `flash_muon`, and `magma_muon` are documented) is missing the public parameter `use_foreach`; update the docstring to state that it accepts True/False/None, that None defaults to using PyTorch foreach kernels, and that distributed callers such as FSDP2/DTensor should pass False to disable foreach; place the new parameter description immediately after the `magma_muon` entry so the docs list remains complete and ordered.
📒 Files selected for processing (3)
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- source/tests/pt/test_hybrid_muon.py
```python
# FSDP2 shards parameters as DTensor; several torch._foreach_*
# ops lack DTensor sharding propagation on older PyTorch, so
# fall back to the per-tensor path under zero_stage >= 2.
# DDP / ZeRO-1 keep plain tensors and use the default.
"use_foreach": False if self.zero_stage >= 2 else None,
```
Honor configured `use_foreach` outside FSDP2.
Line 917 always passes `None` unless `zero_stage >= 2`, so a user-provided `optimizer.use_foreach: false` is silently ignored for single-device/DDP/ZeRO-1 runs. Preserve the FSDP2 override, but otherwise forward the configured value.

Proposed fix

```diff
-    "use_foreach": False if self.zero_stage >= 2 else None,
+    "use_foreach": False
+    if self.zero_stage >= 2
+    else self.opt_param.get("use_foreach"),
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/train/training.py` around lines 913 - 917, The current code always
sets the "use_foreach" kwarg to None except when self.zero_stage >= 2, which
ignores a user-configured optimizer.use_foreach for non-FSDP2 runs; change the
assignment so it preserves the FSDP2 override (False when self.zero_stage >= 2)
but otherwise forwards the configured value (the optimizer's configured
use_foreach flag) instead of forcing None—i.e., set "use_foreach" = False if
self.zero_stage >= 2 else the configured optimizer.use_foreach (read from your
optimizer config/attribute) when building the optimizer kwargs.
```python
@unittest.skipIf(
    not FP16_SUPPORTED,
    "float16 matmul not supported on this device",
)
def test_enable_gram_rectangular_step_runs(self) -> None:
```
🧩 Analysis chain (repository: deepmodeling/deepmd-kit)

🏁 Scripts executed:

```shell
pytest source/tests/pt/test_hybrid_muon.py::TestHybridMuonOptimizer::test_enable_gram_rectangular_step_runs -v
pytest source/tests/pt/test_hybrid_muon.py::TestColumnPadMergeEquivalence::test_gram_ns_column_pad_exact_equivalence -v
pytest source/tests/pt/test_hybrid_muon.py::TestColumnPadMergeEquivalence::test_gram_ns_batch_pad_equivalence -v
pytest source/tests/pt/test_hybrid_muon.py::TestColumnPadMergeEquivalence::test_optimizer_step_column_pad_merge_e2e -v

# First, check if the file exists and get its line count
wc -l source/tests/pt/test_hybrid_muon.py

# Read the section around line 201-205 to see the class and decorator context
sed -n '160,250p' source/tests/pt/test_hybrid_muon.py | cat -n

# Read earlier to find the class definition
sed -n '1,170p' source/tests/pt/test_hybrid_muon.py | cat -n

# Now look at the TestColumnPadMergeEquivalence class around lines 500-600
sed -n '500,650p' source/tests/pt/test_hybrid_muon.py | cat -n

# Continue reading the rest of TestColumnPadMergeEquivalence
sed -n '650,758p' source/tests/pt/test_hybrid_muon.py | cat -n
```
Avoid shadowing FP16 Gram tests with BF16 class-level skips.
The FP16-gated Gram tests in TestHybridMuonOptimizer::test_enable_gram_rectangular_step_runs (lines 201-205) and the three tests in TestColumnPadMergeEquivalence (lines 580-583, 635-638, 668-671) cannot run on devices that support FP16 but not BF16, because both classes are skipped when BF16_SUPPORTED is false. This limits Gram coverage on common pre-Ampere CUDA GPUs.
Move the Gram-only tests into an FP16_SUPPORTED-gated class, or restructure to remove the class-level BF16 skip and add method-level BF16 skips only to tests that specifically exercise the standard bf16 Newton-Schulz path.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@source/tests/pt/test_hybrid_muon.py` around lines 201 - 205, Class-level BF16
gating is shadowing FP16-only Gram tests (e.g.,
TestHybridMuonOptimizer::test_enable_gram_rectangular_step_runs and tests in
TestColumnPadMergeEquivalence), preventing them from running on devices that
support FP16 but not BF16; fix by moving Gram-only tests into a class decorated
with `@unittest.skipIf`(not FP16_SUPPORTED, ...) or removing the class-level BF16
skip and instead apply method-level `@unittest.skipIf`(not BF16_SUPPORTED, ...)
only to tests that require BF16 (the BF16 Newton-Schulz path), ensuring FP16
tests reference FP16_SUPPORTED and BF16-specific tests reference BF16_SUPPORTED
while keeping test names like test_enable_gram_rectangular_step_runs and the
three TestColumnPadMergeEquivalence methods as-is.
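The conservative probe pattern both this comment and the earlier `_fp16_matmul_supported` nit converge on — report support only on trusted device types, and only when a trial op actually succeeds there — can be sketched generically (torch-free, names illustrative):

```python
def capability_probe(trial, device_type, trusted=("cuda",)):
    """Report a capability as supported only on trusted device types,
    and only if a small trial op actually succeeds there.

    On other devices (e.g. CPU, where fp16 matmul may silently upcast
    and "succeed" with poor precision), return False rather than risk
    a numerically misleading True.
    """
    if device_type not in trusted:
        return False
    try:
        trial()
        return True
    except (RuntimeError, TypeError, NotImplementedError):
        return False
```

Skip decorators then read the probe result once at import time, so a CPU-only environment cleanly skips fp16-tolerance tests instead of running them flakily.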
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
deepmd/utils/argcheck.py (1)
3146-3155: ⚠️ Potential issue | 🟡 Minor
Align the config default with the optimizer API default.
The schema now defaults `magma_muon` to `True`, but `HybridMuonOptimizer(..., magma_muon=False)` still defaults direct Python usage to `False`. If this default change is intentional for all HybridMuon users, update the optimizer constructor/docstring too; otherwise keep the schema default consistent.

🛠️ Possible fix if the new default is intended globally

```diff
-    default=True,
+    default=False,
```

or update `HybridMuonOptimizer` to default `magma_muon=True`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `deepmd/utils/argcheck.py` around lines 3146-3155: the schema sets `Argument("magma_muon", ..., default=True)` but the `HybridMuonOptimizer` constructor/docstring and its Python default remain `False`; make them consistent: either revert the Argument default to `False` to match `HybridMuonOptimizer`, or update `HybridMuonOptimizer`'s `__init__` signature and docstring to default `magma_muon=True` (and adjust any tests/docs referencing the old default), then update the docstring text to reflect the chosen global default.
♻️ Duplicate comments (1)
source/tests/pt/test_hybrid_muon.py (1)
91-92: ⚠️ Potential issue | 🟡 Minor
Separate BF16 standard-NS coverage from FP16 Gram coverage.
The class-level BF16 skips still suppress FP16-only Gram tests on devices that support fp16 but not bf16. Conversely, several rectangular default-Gram tests in `TestHybridMuonOptimizer` can require fp16 but are only protected by the BF16 class gate. Move Gram-only tests under an FP16 gate, and keep BF16 skips only on tests that exercise the standard bf16 Newton-Schulz path.
Also applies to: 201-205, 562-583, 634-671
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/test_hybrid_muon.py` around lines 91 - 92, The class-level BF16 skip on TestHybridMuonOptimizer is too broad; remove the `@unittest.skipIf`(not BF16_SUPPORTED, ...) decorator from the class and instead: 1) add `@unittest.skipIf`(not BF16_SUPPORTED, ...) only to the individual tests that exercise the standard bf16 Newton-Schulz path (search for tests that invoke the Newton-Schulz / bf16-specific code paths), and 2) add `@unittest.skipIf`(not FP16_SUPPORTED, "fp16 Gram not supported on this device") to the Gram-only tests (tests that validate Gram matrix behavior or default-Gram rectangular cases). Apply the same change pattern to the other affected blocks mentioned (the tests around the other ranges) so BF16 gating is limited to bf16-NS tests and FP16 gating covers Gram-only tests.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt/optimizer/hybrid_muon.py`:
- Around line 136-145: The bare except in the Triton probe (around
triton.runtime.driver.active.get_current_target() and TRITON_AVAILABLE) must be
fixed to satisfy Ruff: either replace "except Exception:" with a narrowed tuple
of expected exceptions (e.g., except (ImportError, AttributeError, RuntimeError,
OSError):) to only catch likely probe failures, or keep the broad except but add
an explicit Ruff suppression comment on that except line (e.g., "# noqa: BLE001
— intentional broad probe to detect missing/misconfigured Triton driver on
CPU-only hosts") so the rationale is recorded next to the except; update only
the except handling around the probe and leave TRITON_AVAILABLE logic unchanged.
---
Outside diff comments:
In `@deepmd/utils/argcheck.py`:
- Around line 3146-3155: The schema sets Argument("magma_muon", ...,
default=True) but the HybridMuonOptimizer constructor/docstring and its Python
default remain False; make them consistent: either revert the Argument default
to False to match HybridMuonOptimizer, or update HybridMuonOptimizer's __init__
signature and docstring to default magma_muon=True (and adjust any tests/docs
referencing the old default). Locate the Argument for "magma_muon" in
argcheck.py and the HybridMuonOptimizer class/constructor and docstring, then
change the default value in one place and update the docstring text to reflect
the chosen global default.
---
Duplicate comments:
In `@source/tests/pt/test_hybrid_muon.py`:
- Around line 91-92: The class-level BF16 skip on TestHybridMuonOptimizer is too
broad; remove the `@unittest.skipIf`(not BF16_SUPPORTED, ...) decorator from the
class and instead: 1) add `@unittest.skipIf`(not BF16_SUPPORTED, ...) only to the
individual tests that exercise the standard bf16 Newton-Schulz path (search for
tests that invoke the Newton-Schulz / bf16-specific code paths), and 2) add
`@unittest.skipIf`(not FP16_SUPPORTED, "fp16 Gram not supported on this device")
to the Gram-only tests (tests that validate Gram matrix behavior or default-Gram
rectangular cases). Apply the same change pattern to the other affected blocks
mentioned (the tests around the other ranges) so BF16 gating is limited to
bf16-NS tests and FP16 gating covers Gram-only tests.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 4f119985-8d9a-446d-9d4a-f849db7eeab4
📒 Files selected for processing (4)
- deepmd/pt/optimizer/hybrid_muon.py
- deepmd/pt/train/training.py
- deepmd/utils/argcheck.py
- source/tests/pt/test_hybrid_muon.py
🚧 Files skipped from review as they are similar to previous changes (1)
- deepmd/pt/train/training.py
```python
try:
    # Touching ``driver.active`` forces the lazy proxy to initialize the
    # backend driver. ``get_current_target`` is the lightest public call
    # that exercises the same path as ``Autotuner.__init__``.
    triton.runtime.driver.active.get_current_target()
    TRITON_AVAILABLE = True
except Exception:
    # No usable runtime driver (no CUDA/ROCm/XPU, or a mis-configured
    # one): fall back to the pure-PyTorch Newton-Schulz path.
    TRITON_AVAILABLE = False
```
Make the intentional broad Triton probe exception pass Ruff.

Ruff reports BLE001 here. Since this import-time guard intentionally protects CPU/driver-less hosts, either narrow the exception set or add an explicit suppression with the rationale.

🛠️ Proposed fix

```diff
-    except Exception:
+    except Exception:  # noqa: BLE001
         # No usable runtime driver (no CUDA/ROCm/XPU, or a mis-configured
         # one): fall back to the pure-PyTorch Newton-Schulz path.
         TRITON_AVAILABLE = False
```

As per coding guidelines: "**/*.py: Install linter and run ruff check . before committing changes or the CI will fail".
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

```python
try:
    # Touching ``driver.active`` forces the lazy proxy to initialize the
    # backend driver. ``get_current_target`` is the lightest public call
    # that exercises the same path as ``Autotuner.__init__``.
    triton.runtime.driver.active.get_current_target()
    TRITON_AVAILABLE = True
except Exception:  # noqa: BLE001
    # No usable runtime driver (no CUDA/ROCm/XPU, or a mis-configured
    # one): fall back to the pure-PyTorch Newton-Schulz path.
    TRITON_AVAILABLE = False
```
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 142-142: Do not catch blind exception: Exception (BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt/optimizer/hybrid_muon.py` around lines 136 - 145, The bare except
in the Triton probe (around triton.runtime.driver.active.get_current_target()
and TRITON_AVAILABLE) must be fixed to satisfy Ruff: either replace "except
Exception:" with a narrowed tuple of expected exceptions (e.g., except
(ImportError, AttributeError, RuntimeError, OSError):) to only catch likely
probe failures, or keep the broad except but add an explicit Ruff suppression
comment on that except line (e.g., "# noqa: BLE001 — intentional broad probe to
detect missing/misconfigured Triton driver on CPU-only hosts") so the rationale
is recorded next to the except; update only the except handling around the probe
and leave TRITON_AVAILABLE logic unchanged.
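As a sketch of the "narrowed tuple" alternative the prompt mentions (the exception set is an assumption about what a missing or misconfigured driver would raise, not a verified list):

```python
# Hedged sketch: catch only the exception types a missing or misconfigured
# Triton driver is likely to raise, instead of a blanket Exception, so Ruff's
# BLE001 no longer fires. The tuple below is an assumption, not a verified
# enumeration of Triton's failure modes.
try:
    import triton

    triton.runtime.driver.active.get_current_target()
    TRITON_AVAILABLE = True
except (ImportError, AttributeError, RuntimeError, OSError):
    TRITON_AVAILABLE = False
```

On a host without Triton installed, the `ImportError` branch leaves `TRITON_AVAILABLE = False`, matching the pure-PyTorch fallback behavior of the broad-except version.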
njzjz-bot left a comment
Overall this looks promising, but I'd like one clarification before I approve. In `deepmd/pt/train/training.py`, the handoff now uses `bool(self.opt_param.get("flash_muon"))` and `bool(self.opt_param.get("magma_muon"))` instead of explicit fallbacks. If these keys are ever absent in older configs or programmatic call sites, this silently flips missing values to `False` (notably `flash_muon` used to fall back to `True`, and `magma_muon` now also depends on config-layer defaults being materialized).

If config defaults are guaranteed to be filled before this point, then this is fine; otherwise I think the old explicit fallback style is safer.
— OpenClaw 2026.4.22 (model: gpt-5.4)
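The concern above can be demonstrated in two lines. `opt_param` here is a stand-in for the trainer's parsed optimizer config, not the real object:

```python
# Hedged sketch of the reviewer's concern: with bool(d.get(key)) a missing key
# silently becomes False, while an explicit fallback preserves the old default.
opt_param = {}  # older config / programmatic call site with neither key set

flash_muon_now = bool(opt_param.get("flash_muon"))        # new handoff style
flash_muon_old = bool(opt_param.get("flash_muon", True))  # former explicit fallback
```

The two styles only diverge when the key is missing, which is exactly the case the reviewer asks the config layer to rule out.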
Summary by CodeRabbit
New Features
Chores
Tests