Commit 8151232

Merge branch 'main' into jingyux/diffusion-skip-softmax
2 parents 1f8f0d3 + fcb09bf commit 8151232

2 files changed: 32 additions & 18 deletions

modelopt/torch/export/layer_utils.py (12 additions & 0 deletions)
@@ -1184,6 +1184,18 @@ def sync_moe_gate_up_amax(model: nn.Module) -> int:
             up_amax = getattr(up_wq, "amax", None)
             if gate_amax is None or up_amax is None:
                 break
+            # Meta tensors have no storage (e.g. CPU-offloaded experts that
+            # were never activated during calibration). Skip — there is no
+            # real amax data to sync.
+            if gate_amax.is_meta or up_amax.is_meta:
+                warn(
+                    f"Skipping gate/up amax sync for expert with meta tensors "
+                    f"(gate_amax.is_meta={gate_amax.is_meta}, "
+                    f"up_amax.is_meta={up_amax.is_meta}). "
+                    f"This typically means the expert was CPU-offloaded and "
+                    f"not activated during calibration."
+                )
+                break
             if not torch.equal(gate_amax, up_amax):
                 shared_amax = torch.max(gate_amax, up_amax)
                 gate_wq.amax = shared_amax
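
For context on what the new guard catches, a minimal, self-contained sketch of the failure mode follows. This is an illustration only, not the modelopt code path; the two amax tensors below are hypothetical stand-ins for the gate/up weight-quantizer buffers.

import torch
from warnings import warn

# A meta tensor carries shape and dtype metadata but no storage, so there is
# no real calibration statistic to read or compare.
gate_amax = torch.empty(1, device="meta")  # looks like a never-calibrated, offloaded expert
up_amax = torch.tensor([2.5])

if gate_amax.is_meta or up_amax.is_meta:
    # Mirrors the new guard: comparing or reducing a meta tensor has no
    # meaningful result, so the sync is skipped with a warning.
    warn("Skipping gate/up amax sync: meta tensor has no real amax data.")
else:
    # Normal path from the surrounding code: gate and up share the larger
    # amax so the fused gate/up projection exports one quantization scale.
    shared_amax = torch.max(gate_amax, up_amax)
    print("shared amax:", shared_amax.item())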

modelopt/torch/export/unified_export_megatron.py (20 additions & 18 deletions)
@@ -419,16 +419,25 @@ def _get_state_dict(self):
         if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
             self.rules["output_layer"](model.output_layer)
 
+    def _get_fused_norm_weight(self, module):
+        """Return ``module.layer_norm_weight`` when TE fuses the norm into a linear layer.
+
+        Returns ``None`` when the ``"fused_norm"`` rule is absent or the module has no
+        ``layer_norm_weight`` attribute (or its value is ``None``).
+        """
+        if "fused_norm" not in self.rules:
+            return None
+        return getattr(module, "layer_norm_weight", None)
+
     def _get_transformer_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id)
         elif (
-            hasattr(layer.self_attention, "linear_qkv")
-            and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
-            and layer.self_attention.linear_qkv.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
-            self.rules["fused_norm"](layer.self_attention.linear_qkv.layer_norm_weight, layer_id)
+            norm_weight := self._get_fused_norm_weight(
+                getattr(layer.self_attention, "linear_qkv", None)
+            )
+        ) is not None:
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         if not isinstance(layer.self_attention, IdentityOp):
             if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -470,12 +479,10 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
         elif (
             not isinstance(layer.mlp, IdentityOp)
             and "MoE" not in str(type(layer.mlp))
-            and hasattr(layer.mlp, "linear_fc1")
-            and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
-            and layer.mlp.linear_fc1.layer_norm_weight is not None
-            and "fused_norm" in self.rules
+            and (norm_weight := self._get_fused_norm_weight(getattr(layer.mlp, "linear_fc1", None)))
+            is not None
         ):
-            self.rules["fused_norm"](layer.mlp.linear_fc1.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         if not isinstance(layer.mlp, IdentityOp):
             if "MoE" in str(type(layer.mlp)):
@@ -555,14 +562,9 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
     def _get_mamba_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.norm, IdentityOp):
             self.rules["norm"](layer.norm, layer_id)
-        elif (
-            isinstance(layer.norm, IdentityOp)
-            and hasattr(layer.mixer.in_proj, "layer_norm_weight")
-            and layer.mixer.in_proj.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
+        elif (norm_weight := self._get_fused_norm_weight(layer.mixer.in_proj)) is not None:
             # TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear).
-            self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         self.rules["mixer_norm"](layer.mixer.norm, layer_id)
         self.rules["A_log"](layer.mixer.A_log, layer_id)
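
The recurring change in this file is a single refactor: each four-clause hasattr chain collapses into the new _get_fused_norm_weight helper plus an assignment expression (the walrus operator). A minimal standalone sketch of the pattern follows, using illustrative stand-in names (rules, get_fused_norm_weight, layer) rather than the real exporter class:

from types import SimpleNamespace

rules = {"fused_norm": lambda weight: print(f"export fused norm: {weight}")}

def get_fused_norm_weight(module):
    """Return module.layer_norm_weight, or None if the rule or attribute is missing."""
    if "fused_norm" not in rules:
        return None
    # getattr with a default collapses hasattr + attribute access + None check,
    # and also tolerates module itself being None.
    return getattr(module, "layer_norm_weight", None)

layer = SimpleNamespace(linear_qkv=SimpleNamespace(layer_norm_weight="gamma"))

# The walrus operator binds the lookup result and tests it in one expression,
# so each call site shrinks from a four-clause condition to a single check.
if (norm_weight := get_fused_norm_weight(getattr(layer, "linear_qkv", None))) is not None:
    rules["fused_norm"](norm_weight)

Passing getattr(layer, "linear_qkv", None) instead of guarding with hasattr also covers the case where the submodule itself is absent, which is why the Mamba call site can pass layer.mixer.in_proj directly: that attribute always exists on the layers it handles.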
