
Commit 7cd7d18

Authored by yueshen2016claude and Claude Sonnet 4.6
[minor] Refactor TE fused-norm handling in GPTModelExporter (#1061)
### What does this PR do?

**Type of change:** Refactor (no behavior change)

Extract a `_get_fused_norm_weight` helper method in `GPTModelExporter` to consolidate the repeated TE fused-norm detection logic that previously appeared as three separate multi-condition `elif` blocks: in `_get_transformer_layer_state_dict` (attention and MLP paths) and in `_get_mamba_layer_state_dict`.

Changes:

- Add `_get_fused_norm_weight(module)`, which checks `"fused_norm" in self.rules` and returns `getattr(module, "layer_norm_weight", None)`
- Replace the double `hasattr` chains with `getattr(..., None)` chaining; `getattr` with a default already returns `None` for missing attributes (see the sketch after this description)
- Remove the redundant `isinstance(layer.norm, IdentityOp)` check in the Mamba `elif`; it is already guaranteed by being an `elif` branch of `not isinstance(layer.norm, IdentityOp)`
- Use the walrus operator (`:=`) to capture `norm_weight` in the condition, so the call line does not repeat the attribute traversal

### Before / After

Before (one of three nearly identical blocks):

```python
elif (
    hasattr(layer.self_attention, "linear_qkv")
    and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
    and layer.self_attention.linear_qkv.layer_norm_weight is not None
    and "fused_norm" in self.rules
):
    self.rules["fused_norm"](layer.self_attention.linear_qkv.layer_norm_weight, layer_id)
```

After:

```python
elif (
    norm_weight := self._get_fused_norm_weight(
        getattr(layer.self_attention, "linear_qkv", None)
    )
) is not None:
    self.rules["fused_norm"](norm_weight, layer_id)
```

### Testing

No behavior change; existing tests cover all paths.

- Is this change backward compatible? ✅ Pure refactor, no API or logic change
- Did you write any new necessary tests? N/A
- Did you update the Changelog? N/A (minor refactor)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

## Summary by CodeRabbit

* **Refactor**
  * Centralized and simplified fused-normalization handling in model export, reducing duplicated checks and streamlining control flow while preserving existing behavior and compatibility. Improved maintainability and consistency across export paths.

---

Signed-off-by: James Shen <yueshen@nvidia.com>
Signed-off-by: Yue Shen <yueshen@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
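As a side note for reviewers, here is a minimal, self-contained sketch of the two idioms the refactor leans on: `getattr` chaining and the walrus operator. The `_Outer`/`_Inner` classes and the `block` variable are illustrative stand-ins, not Megatron modules.

```python
class _Inner:
    layer_norm_weight = "w"

class _Outer:
    linear_qkv = _Inner()

block = _Outer()

# Old pattern: a double hasattr chain guarding each attribute access.
old = None
if hasattr(block, "linear_qkv") and hasattr(block.linear_qkv, "layer_norm_weight"):
    old = block.linear_qkv.layer_norm_weight

# New pattern: getattr chaining. getattr with a default never raises, and
# getattr(None, "attr", None) simply returns None, so the chain is safe even
# when the intermediate attribute is missing.
new = getattr(getattr(block, "linear_qkv", None), "layer_norm_weight", None)
assert old == new == "w"
assert getattr(getattr(object(), "linear_qkv", None), "layer_norm_weight", None) is None

# Walrus operator: bind the looked-up value inside the condition so the body
# does not have to repeat the attribute traversal.
if (w := getattr(getattr(block, "linear_qkv", None), "layer_norm_weight", None)) is not None:
    print(w)  # -> w
```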
1 parent (de55e8a) · commit 7cd7d18

1 file changed: modelopt/torch/export/unified_export_megatron.py (20 additions, 18 deletions)
```diff
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -419,16 +419,25 @@ def _get_state_dict(self):
         if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
             self.rules["output_layer"](model.output_layer)
 
+    def _get_fused_norm_weight(self, module):
+        """Return ``module.layer_norm_weight`` when TE fuses the norm into a linear layer.
+
+        Returns ``None`` when the ``"fused_norm"`` rule is absent or the module has no
+        ``layer_norm_weight`` attribute (or its value is ``None``).
+        """
+        if "fused_norm" not in self.rules:
+            return None
+        return getattr(module, "layer_norm_weight", None)
+
     def _get_transformer_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id)
         elif (
-            hasattr(layer.self_attention, "linear_qkv")
-            and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
-            and layer.self_attention.linear_qkv.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
-            self.rules["fused_norm"](layer.self_attention.linear_qkv.layer_norm_weight, layer_id)
+            norm_weight := self._get_fused_norm_weight(
+                getattr(layer.self_attention, "linear_qkv", None)
+            )
+        ) is not None:
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         if not isinstance(layer.self_attention, IdentityOp):
             if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -470,12 +479,10 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
         elif (
             not isinstance(layer.mlp, IdentityOp)
             and "MoE" not in str(type(layer.mlp))
-            and hasattr(layer.mlp, "linear_fc1")
-            and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
-            and layer.mlp.linear_fc1.layer_norm_weight is not None
-            and "fused_norm" in self.rules
+            and (norm_weight := self._get_fused_norm_weight(getattr(layer.mlp, "linear_fc1", None)))
+            is not None
         ):
-            self.rules["fused_norm"](layer.mlp.linear_fc1.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         if not isinstance(layer.mlp, IdentityOp):
             if "MoE" in str(type(layer.mlp)):
@@ -555,14 +562,9 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
     def _get_mamba_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.norm, IdentityOp):
             self.rules["norm"](layer.norm, layer_id)
-        elif (
-            isinstance(layer.norm, IdentityOp)
-            and hasattr(layer.mixer.in_proj, "layer_norm_weight")
-            and layer.mixer.in_proj.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
+        elif (norm_weight := self._get_fused_norm_weight(layer.mixer.in_proj)) is not None:
             # TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear).
-            self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         self.rules["mixer_norm"](layer.mixer.norm, layer_id)
         self.rules["A_log"](layer.mixer.A_log, layer_id)
```
