Skip to content

Commit ff4332f

Browse files
committed
Fix export of fused layernorm weights for TE spec.
Signed-off-by: James Shen <yueshen@nvidia.com>
1 parent 5e43b2a commit ff4332f

File tree

1 file changed

+47
-0
lines changed

1 file changed

+47
-0
lines changed

modelopt/torch/export/unified_export_megatron.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,26 @@
8585
]
8686

8787

88+
class _FusedLayerNormProxy(torch.nn.Module):
    """Stand-in layernorm module backed by a fused linear layer's norm parameters.

    With the TE spec, the input layernorm and the pre-MLP layernorm are folded into
    the linear layer that follows them (TELayerNormColumnParallelLinear). That fused
    module keeps the norm scale under ``layer_norm_weight`` rather than exposing a
    standalone layernorm module with its own ``weight``.

    The existing export rules expect a module with a ``.weight`` attribute, so this
    proxy re-exposes the fused parameters under the conventional names, letting
    those rules export them with the correct HF key name.

    Args:
        fused_linear: Module providing ``layer_norm_weight`` and, optionally,
            ``layer_norm_bias``.
    """

    def __init__(self, fused_linear: torch.nn.Module):
        super().__init__()
        # Shares the fused layer's tensor object; nothing is copied.
        self.weight = fused_linear.layer_norm_weight

        fused_bias = getattr(fused_linear, "layer_norm_bias", None)
        if fused_bias is None:
            # No norm bias on the fused layer: leave ``bias`` undefined on the proxy.
            return
        self.bias = fused_bias
106+
107+
88108
class GPTModelExporter:
89109
"""Megatron Core GPTModel Exporter.
90110
@@ -489,6 +509,17 @@ def _get_state_dict(self):
489509
def _get_transformer_layer_state_dict(self, layer, layer_id):
490510
if not isinstance(layer.input_layernorm, IdentityOp):
491511
self.rules["input_layernorm"](layer.input_layernorm, layer_id)
512+
elif (
513+
"input_layernorm" in self.rules
514+
and hasattr(layer, "self_attention")
515+
and not isinstance(layer.self_attention, IdentityOp)
516+
and hasattr(layer.self_attention, "linear_qkv")
517+
and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
518+
):
519+
# TE spec: input layernorm is fused into TELayerNormColumnParallelLinear
520+
self.rules["input_layernorm"](
521+
_FusedLayerNormProxy(layer.self_attention.linear_qkv), layer_id
522+
)
492523

493524
if not isinstance(layer.self_attention, IdentityOp):
494525
if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -527,6 +558,14 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
527558

528559
if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
529560
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id)
561+
elif (
562+
"pre_mlp_layernorm" in self.rules
563+
and not isinstance(layer.mlp, IdentityOp)
564+
and hasattr(layer.mlp, "linear_fc1")
565+
and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
566+
):
567+
# TE spec: pre-MLP layernorm is fused into TELayerNormColumnParallelLinear
568+
self.rules["pre_mlp_layernorm"](_FusedLayerNormProxy(layer.mlp.linear_fc1), layer_id)
530569

531570
if not isinstance(layer.mlp, IdentityOp):
532571
if "MoE" in str(type(layer.mlp)):
@@ -598,6 +637,14 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
598637
def _get_mamba_layer_state_dict(self, layer, layer_id):
599638
if not isinstance(layer.norm, IdentityOp):
600639
self.rules["norm"](layer.norm, layer_id)
640+
elif (
641+
"norm" in self.rules
642+
and hasattr(layer, "mixer")
643+
and hasattr(layer.mixer, "in_proj")
644+
and hasattr(layer.mixer.in_proj, "layer_norm_weight")
645+
):
646+
# TE spec: norm is fused into TELayerNormColumnParallelLinear (in_proj)
647+
self.rules["norm"](_FusedLayerNormProxy(layer.mixer.in_proj), layer_id)
601648

602649
self.rules["mixer_norm"](layer.mixer.norm, layer_id)
603650
self.rules["A_log"](layer.mixer.A_log, layer_id)

0 commit comments

Comments
 (0)