
[OMNIML-3495] Add TEGroupedMLP export support for NemotronH models#967

Merged
yueshen2016 merged 1 commit into main from yueshen/Support-Nemotron-Export
Mar 9, 2026

Conversation

Contributor

@yueshen2016 yueshen2016 commented Mar 4, 2026

What does this PR do?

Type of change: New feature

Add export support for TEGroupedMLP (fused grouped-GEMM experts) in the MCore-to-HuggingFace checkpoint exporter. Previously, the exporter only supported SequentialMLP, which exposes its experts via a local_experts ModuleList. TEGroupedMLP instead stores per-expert weights as weight0, weight1, ..., weight{N-1} on a single TEGroupedLinear module, so exporting NemotronH models failed with AttributeError: 'QuantTEGroupedMLP' object has no attribute 'local_experts'.
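For context, the two expert-weight layouts can be contrasted with minimal stand-in classes (illustrative only, not the actual Megatron-Core modules):

```python
class SequentialMLPLike:
    """SequentialMLP-style: one submodule per expert in a `local_experts` list."""
    def __init__(self, num_experts):
        self.local_experts = [object() for _ in range(num_experts)]

class TEGroupedLinearLike:
    """TEGroupedLinear-style: a single fused module whose state dict holds
    per-expert parameters named weight0 .. weight{N-1}."""
    def __init__(self, num_experts):
        self.num_gemms = num_experts
        self._params = {f"weight{i}": f"<tensor {i}>" for i in range(num_experts)}

    def state_dict(self):
        return dict(self._params)

grouped = TEGroupedLinearLike(4)
print(sorted(grouped.state_dict()))       # ['weight0', 'weight1', 'weight2', 'weight3']
print(hasattr(grouped, "local_experts"))  # False: this is what raised the AttributeError
```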

Changes:

  • Add GroupedMLPSlicing class in mcore_custom.py — the export counterpart of GroupedMLPMerging
  • Add _grouped_mlp_slicing method in GPTModelExporter that iterates TEGroupedLinear's per-expert weights and exports them as individual HF-format weights with proper quantization scale handling
  • Add "experts.linear_fc1" and "experts.linear_fc2" rules using GroupedMLPSlicing to nemotron_h_causal_lm_export
  • Route TEGroupedMLP (detected by absence of local_experts attribute) to the new "experts.linear_fc1" rule in _get_transformer_layer_state_dict

Usage

No API change. NemotronH models using TEGroupedMLP can now be exported:

import modelopt.torch.export as mtex

mtex.export_mcore_gpt_to_hf(
    model=megatron_model,
    export_dir="/path/to/hf_export",
    pretrained_model_name_or_path="/path/to/hf_model",
)

Testing

Inside Model-Bridge

torchrun --nproc_per_node 4 examples/quantization/export.py \
    --hf-model-id /models/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16/ \
    --megatron-load-path /models/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4-MLM \
    --export-dir /models/NVIDIA-Nemotron-3-Nano-30B-A3B-NVFP4-MLM_hf \
    --pp 4 \
    --dtype bfloat16 \
    --trust-remote-code

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, using torch.load(..., weights_only=True), avoiding pickle, etc.).

  • Is this change backward compatible?: ✅ The existing SequentialMLP (local_experts) path is guarded by hasattr(layer.mlp.experts, "local_experts") and remains unchanged. The new TEGroupedMLP path only activates when local_experts is absent and "experts.linear_fc1" is defined in the architecture's rules.
  • If you copied code from any other source, did you follow IP policy in CONTRIBUTING.md?: N/A
  • Did you write any new necessary tests?: ❌ Tested manually with Nemotron-3-Nano-30B-A3B. Unit test coverage should be added for _grouped_mlp_slicing.
  • Did you update Changelog?: ❌ New feature for a specific model architecture.

Additional Information

  • The import counterpart (GroupedMLPMerging / _grouped_mlp_merging) was added by @jennifchen in PR Latent MOE & Repeated MTP support for NemotronH; fix KV cache quant export #830. This PR completes the round-trip by adding the export side.
  • _grouped_mlp_slicing temporarily assigns module.weight = module.weight0 so that _get_quantized_state can extract qformat/scales from the module's quantizers, then removes it afterward. This follows the same pattern used by _QuantTEGroupedLinear._setup() in the quantization plugin.
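The temporary-alias pattern in the last bullet, written exception-safely, looks roughly like this (a sketch: `extract_quant_state` and `FakeGroupedLinear` stand in for the real `_get_quantized_state` helper and TEGroupedLinear module):

```python
def with_temporary_weight_alias(module, extract_quant_state):
    had_weight = hasattr(module, "weight")
    if not had_weight:
        module.weight = module.weight0  # alias so quantizer helpers see `weight`
    try:
        return extract_quant_state(module)
    finally:
        # Always remove the alias, even if extraction raised.
        if not had_weight and hasattr(module, "weight"):
            delattr(module, "weight")

class FakeGroupedLinear:
    def __init__(self):
        self.weight0 = "expert-0 tensor"

mod = FakeGroupedLinear()
print(with_temporary_weight_alias(mod, lambda m: m.weight))  # expert-0 tensor
print(hasattr(mod, "weight"))  # False: alias cleaned up
```

The try/finally guarantees the module is left unchanged even when the quantization helpers throw, which is the exception-safety point raised in review below.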

Summary by CodeRabbit

  • New Features
    • Export now supports grouped-expert MLP slicing to split fused expert weights into per-expert tensors for downstream formats.
    • Per-expert export logic enhanced with clear fallbacks between packed and per-expert layouts, including a grouped-MLP export path.
    • Nemotron H causal LM import/export mappings updated to better align with grouped local-expert exports.
    • Added fused-normalization export support and safer handling when loading remote model code.

@yueshen2016 yueshen2016 requested a review from a team as a code owner March 4, 2026 09:01
Contributor

coderabbitai Bot commented Mar 4, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a GroupedMLPSlicing custom mapping and integrates grouped/per-expert TEGroupedMLP export handling into Nemotron mappings and the unified Megatron exporter, including grouped MLP slicing, per-expert or grouped weight emission, quantization-aware state-dict population, and fused-norm export paths.

Changes

  • Grouped MLP mapping (modelopt/torch/export/plugins/mcore_custom.py): adds GroupedMLPSlicing(CustomModuleMapping) with __init__(self, target_name_or_prefix: str = "", func_kwargs: dict[str, Any] = {}), registering func_name="grouped_mlp_slicing".
  • Nemotron mappings (modelopt/torch/export/plugins/mcore_nemotron.py): imports GroupedMLPSlicing and adds mapping entries "fused_norm": NameRemapping("backbone.layers.{}.norm.weight"), "experts.linear_fc1": GroupedMLPSlicing("backbone.layers.{}.mixer.experts.{{}}.up_proj"), and "experts.linear_fc2": GroupedMLPSlicing("backbone.layers.{}.mixer.experts.{{}}.down_proj").
  • Unified Megatron exporter (modelopt/torch/export/unified_export_megatron.py): adds GPTModelExporter._grouped_mlp_slicing(self, module, prefix, parallel_config=None), registers "grouped_mlp_slicing" in custom mappings, implements grouped vs. per-expert export flows (including TEGroupedMLP slicing), emits per-expert weights/scales/quant state, and adds fused-norm handling and trust_remote_code propagation.

Sequence Diagram(s)

sequenceDiagram
    participant Exporter as GPTModelExporter
    participant Registry as MappingRegistry
    participant Handler as _grouped_mlp_slicing
    participant Module as TEGroupedMLPModule
    participant State as StateDict

    Exporter->>Registry: lookup "grouped_mlp_slicing"
    Registry->>Handler: invoke(module, prefix, parallel_config)
    Handler->>Module: inspect local_experts, grouped weights, quant state, fused_norm
    alt per-expert path
        Handler->>Module: iterate experts
        loop per expert
            Handler->>State: emit expert weight, scale, quant metadata
        end
    else grouped/fused path
        Handler->>Module: read fused/grouped linear_fc tensors
        Handler->>State: slice grouped tensors -> per-expert entries
    end
    alt fused_norm present
        Handler->>State: emit fused_norm entry
    end
    State-->>Exporter: state dict entries populated

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 40.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (3 passed)
  • Description Check (✅ Passed): check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check (✅ Passed): the title directly describes the main change, adding TEGroupedMLP export support for NemotronH models.
  • Security Anti-Patterns (✅ Passed): no critical security anti-patterns; trust_remote_code is caller-configurable with a secure False default.


Contributor

@coderabbitai coderabbitai Bot left a comment
Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/torch/export/unified_export_megatron.py (1)

871-931: Please add focused tests for the new grouped slicing path.

Recommended coverage: non-quantized export, quantized export, missing expert-weight key behavior, and cleanup of temporary module.weight on exceptions.

If you want, I can draft a pytest matrix for these cases in a follow-up.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 871 - 931, Add
focused pytest unit tests for _grouped_mlp_slicing covering: (1) non-quantized
export where module has weight0..weightN and exported per-expert "weight"
entries are correct; (2) quantized export where _get_quantized_state returns
qformat/weight_scale(s) and exported per-expert "weight", "weight_scale", and
"weight_scale_2" use to_quantized_weight and cloned scales; (3) behavior when an
expert weight key (e.g., "weight2") is missing from module.state_dict — ensure
slicing skips that expert and others still export; and (4) cleanup when
_get_quantized_state or to_quantized_weight raises: ensure temporary assignment
of module.weight (done in _grouped_mlp_slicing) is removed after the call even
on exception. In tests, instantiate a minimal TEGroupedMLP-like object with
num_gemms, weight0..weightN in state_dict, control _get_quantized_state via
monkeypatch or fixture to simulate quantized/non-quantized returns, call
_grouped_mlp_slicing(prefix=...) and assert resulting self._state_dict
keys/values and that module has no lingering "weight" attribute after success or
exception.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 494-497: Update the inline comment describing the TEGroupedMLP
export path to reflect the correct mapping type: change the reference from
"GroupedMLPMerging" to "GroupedMLPSlicing" and clarify that the export uses the
"experts.linear_fc1" rule with GroupedMLPSlicing (not
"local_experts.linear_fc1"); modify the comment around TEGroupedMLP /
experts.linear_fc1 to name GroupedMLPSlicing so it matches the actual
implementation.
- Around line 883-907: The temporary assignment of module.weight =
module.weight0 before calling _get_quantized_state is not exception-safe and may
leave the alias in place if an error occurs; wrap the assignment and subsequent
calls to _get_quantized_state and _get_weight_scales in a try/finally so you
always deleteattr(module, "weight") when it was not originally present. Also
change the expert loop (using expert_prefix and weight_key =
f"weight{expert_id}") to fail fast instead of continuing silently when a
weight_key is missing from module.state_dict()—raise a clear exception
(including the expert_id and expert_prefix) so incomplete checkpoints are not
exported unnoticed.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f82ada24-3356-48b5-b6c1-453fc3152768

📥 Commits

Reviewing files that changed from the base of the PR and between a34d613 and d076a185af6062f5aa5e9ea5273a3d71c9c61899.

📒 Files selected for processing (3)
  • modelopt/torch/export/plugins/mcore_custom.py
  • modelopt/torch/export/plugins/mcore_nemotron.py
  • modelopt/torch/export/unified_export_megatron.py


codecov Bot commented Mar 4, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.14%. Comparing base (a076e6c) to head (8f9b734).
⚠️ Report is 19 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #967      +/-   ##
==========================================
+ Coverage   72.12%   72.14%   +0.02%     
==========================================
  Files         209      209              
  Lines       23628    23667      +39     
==========================================
+ Hits        17042    17075      +33     
- Misses       6586     6592       +6     


@yueshen2016 yueshen2016 force-pushed the yueshen/Support-Nemotron-Export branch 2 times, most recently from c3e9f46 to d879084 on March 7, 2026 00:24
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_megatron.py (1)

892-903: ⚠️ Potential issue | 🟠 Major

Make the temporary module.weight alias exception-safe and fail fast on missing expert weights.

If _get_quantized_state(...) or _get_weight_scales(...) throws, Line 894 can leave module.weight behind. And Lines 911-912 silently drop experts, which can produce an incomplete checkpoint with no error.

Suggested fix
         has_weight = hasattr(module, "weight")
         if not has_weight:
             module.weight = module.weight0
-
-        name_to_value, qformat, block_size = self._get_quantized_state(
-            module, self.dtype, prefix=prefix
-        )
-        weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
-        name_to_value.pop("weight", None)
-
-        if not has_weight:
-            delattr(module, "weight")
+        try:
+            name_to_value, qformat, block_size = self._get_quantized_state(
+                module, self.dtype, prefix=prefix
+            )
+            weight_scale, weight_scale_2 = self._get_weight_scales(name_to_value, qformat)
+            name_to_value.pop("weight", None)
+        finally:
+            if not has_weight and hasattr(module, "weight"):
+                delattr(module, "weight")
@@
-            if weight_key not in state_dict:
-                continue
+            if weight_key not in state_dict:
+                raise ValueError(
+                    f"Missing expected TEGroupedMLP expert weight {weight_key!r} for {expert_prefix}"
+                )

Also applies to: 911-912

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 892 - 903,
Wrap the temporary aliasing of module.weight (where code sets module.weight =
module.weight0 when has_weight is False) in a try/finally so that module.weight
is always removed in the finally block even if _get_quantized_state or
_get_weight_scales throws; locate the aliasing and restoration around the calls
to _get_quantized_state and _get_weight_scales and move name_to_value, qformat,
block_size assignment into the try. Also replace the silent pop of "weight" from
name_to_value (name_to_value.pop("weight", None)) with an explicit existence
check and raise a clear exception if expected expert weights are missing so the
export fails fast rather than producing an incomplete checkpoint.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 919-929: The export is quantizing each 2D expert weight using the
full grouped-module scale tensors, which breaks when TEGroupedLinear provides
per-expert (batched) scales; modify the export to slice the per-expert portions
of weight_scale (and weight_scale_2) before calling to_quantized_weight so the
scale shape matches the 2D weight being quantized. Locate the block that assigns
self._state_dict[expert_prefix + "weight"] and self._state_dict[expert_prefix +
"weight_scale(_2)"], slice weight_scale (and weight_scale_2 when not None) to
the single-expert index corresponding to the current expert_prefix, then pass
those sliced tensors to to_quantized_weight and store the sliced clones in the
state dict to ensure correct broadcasting and export for per-expert scales.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 4fc251f3-1c3c-4d69-bc1a-e267f4f500cd

📥 Commits

Reviewing files that changed from the base of the PR and between c3e9f461d5f466432deb481e5c2bd6d395b85e9c and d87908454b2b40f56ac947ff4a372d0464e08ec0.

📒 Files selected for processing (3)
  • modelopt/torch/export/plugins/mcore_custom.py
  • modelopt/torch/export/plugins/mcore_nemotron.py
  • modelopt/torch/export/unified_export_megatron.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/plugins/mcore_nemotron.py

@yueshen2016 yueshen2016 force-pushed the yueshen/Support-Nemotron-Export branch 2 times, most recently from f38ee5a to 0e92059 on March 7, 2026 08:08
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (1)
modelopt/torch/export/unified_export_megatron.py (1)

917-932: ⚠️ Potential issue | 🟠 Major

Slice grouped quantization scales down to one expert before packing.

modelopt/torch/export/quant_utils.py:890-991 only handles batched per-expert scales specially when weight.dim() == 3. Here each weight{i} is 2D, so Lines 922-932 reuse the full grouped weight_scale / weight_scale_2 tensor and will mis-broadcast or fail when TEGroupedLinear exposes per-expert scales.

Suggested fix
         for expert_id in range(num_experts):
             expert_prefix = prefix.format(expert_id) + "."
             weight_key = f"weight{expert_id}"
@@
             weight = state_dict[weight_key].to(self.dtype).cpu()
+            expert_weight_scale = weight_scale
+            if (
+                weight_scale is not None
+                and weight_scale.dim() > 0
+                and weight_scale.shape[0] == num_experts
+            ):
+                expert_weight_scale = weight_scale[expert_id]
+
+            expert_weight_scale_2 = weight_scale_2
+            if (
+                weight_scale_2 is not None
+                and weight_scale_2.dim() > 0
+                and weight_scale_2.shape[0] == num_experts
+            ):
+                expert_weight_scale_2 = weight_scale_2[expert_id]
 
             if weight_scale is None:
                 self._state_dict[expert_prefix + "weight"] = weight
             else:
                 self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
                     weight,
-                    weight_scale,
+                    expert_weight_scale,
                     qformat,
-                    weight_scale_2,
+                    expert_weight_scale_2,
                     block_size,
                 )
-                self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()
+                self._state_dict[expert_prefix + "weight_scale"] = (
+                    expert_weight_scale.detach().clone()
+                )
 
-            if weight_scale_2 is not None:
-                self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()
+            if expert_weight_scale_2 is not None:
+                self._state_dict[expert_prefix + "weight_scale_2"] = (
+                    expert_weight_scale_2.detach().clone()
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 917 - 932, The
grouped per-expert scale tensors (weight_scale and weight_scale_2) must be
sliced to the current expert before storing or passing to to_quantized_weight;
update the code around where you build self._state_dict[expert_prefix +
"weight"] to detect if weight_scale/weight_scale_2 are batched per-expert (e.g.,
have an extra leading dimension matching number of experts) and index them for
the current expert, then pass those single-expert scale tensors into
to_quantized_weight and store only single-expert copies in self._state_dict (use
.detach().clone() after slicing); adjust the branches that set
self._state_dict[expert_prefix + "weight_scale"] and
self._state_dict[expert_prefix + "weight_scale_2"] accordingly so they contain
per-expert slices rather than the full grouped tensors.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 543-550: The transformer export path must also invoke the
fused_norm rule when a layer's input_layernorm or pre_mlp_layernorm is an
IdentityOp so the backbone.layers.{i}.norm.weight isn't omitted; in
_get_transformer_layer_state_dict detect when layer.input_layernorm (and
separately layer.pre_mlp_layernorm) is an IdentityOp and if "fused_norm" exists
in self.rules, call self.rules["fused_norm"](...) with the appropriate weight
tensor (e.g., layer.input_layernorm.weight and layer.pre_mlp_layernorm.weight)
and layer_id, guarding for attribute existence and non-None weights before
calling.
- Around line 267-268: Remove the interactive breakpoint call ("import ipdb;
ipdb.set_trace()") that sits under the "Main export process" in
unified_export_megatron.py so exports won't block CI or require ipdb; simply
delete this line (or guard it behind a debug flag/env check) and, if you need
traceability, replace it with a logging.debug call instead of importing ipdb.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: cd6e71ef-3c7a-4e41-96a4-7a29216e75f1

📥 Commits

Reviewing files that changed from the base of the PR and between d87908454b2b40f56ac947ff4a372d0464e08ec0 and f38ee5ad4ad444b7897066d107840aa04d934051.

📒 Files selected for processing (3)
  • modelopt/torch/export/plugins/mcore_custom.py
  • modelopt/torch/export/plugins/mcore_nemotron.py
  • modelopt/torch/export/unified_export_megatron.py

Comment on lines +267 to +268
# Main export process
import ipdb; ipdb.set_trace()

⚠️ Potential issue | 🔴 Critical

Remove the interactive breakpoint.

Line 268 drops into ipdb during every export, which blocks CI/non-interactive runs and also fails outright when ipdb is not installed.

Suggested fix
-        import ipdb; ipdb.set_trace()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 267 - 268,
Remove the interactive breakpoint call ("import ipdb; ipdb.set_trace()") that
sits under the "Main export process" in unified_export_megatron.py so exports
won't block CI or require ipdb; simply delete this line (or guard it behind a
debug flag/env check) and, if you need traceability, replace it with a
logging.debug call instead of importing ipdb.

Contributor

@coderabbitai coderabbitai Bot left a comment


♻️ Duplicate comments (3)
modelopt/torch/export/unified_export_megatron.py (3)

542-549: ⚠️ Potential issue | 🟠 Major

Mirror this fused_norm fallback in GPTModelExporter._get_transformer_layer_state_dict().

This only covers MambaLayer. TE-spec transformer layers still skip backbone.layers.{i}.norm.weight when the standalone norms are IdentityOp, so the new Nemotron H fused_norm export path is still incomplete.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 542 - 549, The
transformer exporter is missing the fused_norm fallback for layers whose
standalone norm is IdentityOp; update
GPTModelExporter._get_transformer_layer_state_dict to mirror the logic in
unified_export_megatron.py: detect when layer.norm is an IdentityOp and
layer.mixer.in_proj has a non-None layer_norm_weight and "fused_norm" exists in
self.rules, then call
self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id) so
backbone.layers.{i}.norm.weight is emitted for TE-spec fused norms (refer to the
existing fused_norm handling block and apply the same checks and rule invocation
inside _get_transformer_layer_state_dict).

496-499: ⚠️ Potential issue | 🟡 Minor

Rename the mapping in this comment to GroupedMLPSlicing.

This export branch goes through self.rules["experts.linear_fc1"], which resolves to GroupedMLPSlicing in modelopt/torch/export/plugins/mcore_nemotron.py, not GroupedMLPMerging.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 496 - 499,
Update the comment for TEGroupedMLP to rename the mapping from
"GroupedMLPMerging" to "GroupedMLPSlicing": clarify that this branch goes
through self.rules["experts.linear_fc1"] which resolves to GroupedMLPSlicing
(not GroupedMLPMerging) as implemented in
modelopt/torch/export/plugins/mcore_nemotron.py; keep the rest of the note about
using "experts.linear_fc1" instead of "local_experts.linear_fc1" intact.

918-931: ⚠️ Potential issue | 🟠 Major

Slice the grouped scales down to the current expert before calling to_quantized_weight().

modelopt/torch/export/quant_utils.py::to_quantized_weight only special-cases batched MoE scales when weight.dim() == 3. Here each expert weight is 2D, so reusing the full grouped weight_scale / weight_scale_2 tensor will mis-broadcast or fail for quantized TEGroupedMLP exports when scales are stored per expert.

Proposed fix
         for expert_id in range(num_experts):
             expert_prefix = prefix.format(expert_id) + "."
             weight_key = f"weight{expert_id}"
@@
             weight = state_dict[weight_key].to(self.dtype).cpu()
+            expert_weight_scale = weight_scale
+            if (
+                weight_scale is not None
+                and weight_scale.dim() > 0
+                and weight_scale.shape[0] == num_experts
+            ):
+                expert_weight_scale = weight_scale[expert_id]
+
+            expert_weight_scale_2 = weight_scale_2
+            if (
+                weight_scale_2 is not None
+                and weight_scale_2.dim() > 0
+                and weight_scale_2.shape[0] == num_experts
+            ):
+                expert_weight_scale_2 = weight_scale_2[expert_id]
 
-            if weight_scale is None:
+            if expert_weight_scale is None:
                 self._state_dict[expert_prefix + "weight"] = weight
             else:
                 self._state_dict[expert_prefix + "weight"] = to_quantized_weight(
                     weight,
-                    weight_scale,
+                    expert_weight_scale,
                     qformat,
-                    weight_scale_2,
+                    expert_weight_scale_2,
                     block_size,
                 )
-                self._state_dict[expert_prefix + "weight_scale"] = weight_scale.detach().clone()
+                self._state_dict[expert_prefix + "weight_scale"] = (
+                    expert_weight_scale.detach().clone()
+                )
 
-            if weight_scale_2 is not None:
-                self._state_dict[expert_prefix + "weight_scale_2"] = weight_scale_2.detach().clone()
+            if expert_weight_scale_2 is not None:
+                self._state_dict[expert_prefix + "weight_scale_2"] = (
+                    expert_weight_scale_2.detach().clone()
+                )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_megatron.py` around lines 918 - 931, The
export is passing grouped per-expert scale tensors into to_quantized_weight
while weight is a single-expert 2D tensor, causing mis-broadcasts; before
calling to_quantized_weight in the block that handles weight_scale not None,
slice weight_scale (and weight_scale_2 if present) to the current expert index
so they match weight's shape, then pass those per-expert slices into
to_quantized_weight and store the detached/cloned per-expert slices into
self._state_dict[expert_prefix + "weight_scale"] and
self._state_dict[expert_prefix + "weight_scale_2"] (adjust the existing
assignments around the to_quantized_weight call and the later weight_scale_2
handling to use the sliced tensors).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Duplicate comments:
In `@modelopt/torch/export/unified_export_megatron.py`:
- Around line 542-549: The transformer exporter is missing the fused_norm
fallback for layers whose standalone norm is IdentityOp; update
GPTModelExporter._get_transformer_layer_state_dict to mirror the logic in
unified_export_megatron.py: detect when layer.norm is an IdentityOp and
layer.mixer.in_proj has a non-None layer_norm_weight and "fused_norm" exists in
self.rules, then call
self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id) so
backbone.layers.{i}.norm.weight is emitted for TE-spec fused norms (refer to the
existing fused_norm handling block and apply the same checks and rule invocation
inside _get_transformer_layer_state_dict).
- Around line 496-499: Update the comment for TEGroupedMLP to rename the mapping
from "GroupedMLPMerging" to "GroupedMLPSlicing": clarify that this branch goes
through self.rules["experts.linear_fc1"] which resolves to GroupedMLPSlicing
(not GroupedMLPMerging) as implemented in
modelopt/torch/export/plugins/mcore_nemotron.py; keep the rest of the note about
using "experts.linear_fc1" instead of "local_experts.linear_fc1" intact.
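
The fused_norm fallback described in the first duplicate comment can be sketched with duck-typed stand-ins. All names below (`IdentityOp`, `emit_fused_norm`, the `rules` dict) are illustrative assumptions modeled on the description, not the exporter's real classes: when the standalone norm is an `IdentityOp` but the fused linear still carries a `layer_norm_weight`, the weight is routed through the `"fused_norm"` rule so `backbone.layers.{i}.norm.weight` is emitted.

```python
from types import SimpleNamespace


class IdentityOp:
    """Stand-in for Megatron's IdentityOp: the norm was fused into in_proj."""


def emit_fused_norm(layer, rules, layer_id):
    # Fallback for TE-spec fused norms: the standalone norm is an IdentityOp,
    # but the fused linear (mixer.in_proj) still holds the layernorm weight.
    in_proj = layer.mixer.in_proj
    if (
        isinstance(layer.norm, IdentityOp)
        and getattr(in_proj, "layer_norm_weight", None) is not None
        and "fused_norm" in rules
    ):
        rules["fused_norm"](in_proj.layer_norm_weight, layer_id)


# Minimal check: the rule records the weight under the expected HF key.
state = {}
rules = {"fused_norm": lambda w, i: state.update({f"backbone.layers.{i}.norm.weight": w})}
layer = SimpleNamespace(
    norm=IdentityOp(),
    mixer=SimpleNamespace(in_proj=SimpleNamespace(layer_norm_weight=[1.0, 1.0])),
)
emit_fused_norm(layer, rules, 3)
assert state == {"backbone.layers.3.norm.weight": [1.0, 1.0]}
```

A layer with a real (non-identity) norm, or one whose `in_proj` has no `layer_norm_weight`, falls through without emitting anything, which matches the guard conditions in the comment.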

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 60292d3b-30e4-4857-8277-fdfbcf4048ec

📥 Commits

Reviewing files that changed from the base of the PR and between f38ee5ad4ad444b7897066d107840aa04d934051 and 0e92059e57cead27d81c534652f43ebc354a8fed.

📒 Files selected for processing (3)
  • modelopt/torch/export/plugins/mcore_custom.py
  • modelopt/torch/export/plugins/mcore_nemotron.py
  • modelopt/torch/export/unified_export_megatron.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • modelopt/torch/export/plugins/mcore_custom.py

@yueshen2016 yueshen2016 force-pushed the yueshen/Support-Nemotron-Export branch 3 times, most recently from cbed1cb to f3134cf Compare March 7, 2026 09:45
Signed-off-by: James Shen <yueshen@nvidia.com>
@yueshen2016 yueshen2016 force-pushed the yueshen/Support-Nemotron-Export branch from f3134cf to 8f9b734 Compare March 7, 2026 10:08
@yueshen2016 yueshen2016 enabled auto-merge (squash) March 7, 2026 21:18
Collaborator

@ChenhanYu ChenhanYu left a comment
Approve; I think the TE fused layernorm handling has too many if conditions to satisfy. Could you try to see if Claude Code can optimize it in a follow-up PR?

@yueshen2016 yueshen2016 merged commit 1d6ec89 into main Mar 9, 2026
40 checks passed
@yueshen2016 yueshen2016 deleted the yueshen/Support-Nemotron-Export branch March 9, 2026 19:47