@@ -419,16 +419,25 @@ def _get_state_dict(self):
         if hasattr(model, "output_layer") and not model.share_embeddings_and_output_weights:
             self.rules["output_layer"](model.output_layer)
 
+    def _get_fused_norm_weight(self, module):
+        """Return ``module.layer_norm_weight`` when TE fuses the norm into a linear layer.
+
+        Returns ``None`` when the ``"fused_norm"`` rule is absent or the module has no
+        ``layer_norm_weight`` attribute (or its value is ``None``).
+        """
+        if "fused_norm" not in self.rules:
+            return None
+        return getattr(module, "layer_norm_weight", None)
+
     def _get_transformer_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id)
         elif (
-            hasattr(layer.self_attention, "linear_qkv")
-            and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
-            and layer.self_attention.linear_qkv.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
-            self.rules["fused_norm"](layer.self_attention.linear_qkv.layer_norm_weight, layer_id)
+            norm_weight := self._get_fused_norm_weight(
+                getattr(layer.self_attention, "linear_qkv", None)
+            )
+        ) is not None:
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         if not isinstance(layer.self_attention, IdentityOp):
             if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -470,12 +479,10 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
         elif (
             not isinstance(layer.mlp, IdentityOp)
             and "MoE" not in str(type(layer.mlp))
-            and hasattr(layer.mlp, "linear_fc1")
-            and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
-            and layer.mlp.linear_fc1.layer_norm_weight is not None
-            and "fused_norm" in self.rules
+            and (norm_weight := self._get_fused_norm_weight(getattr(layer.mlp, "linear_fc1", None)))
+            is not None
         ):
-            self.rules["fused_norm"](layer.mlp.linear_fc1.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         if not isinstance(layer.mlp, IdentityOp):
             if "MoE" in str(type(layer.mlp)):
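The `:=` binding in this `elif` chain relies on `and` short-circuiting: `norm_weight` is only evaluated and bound once the `IdentityOp` and `MoE` clauses pass, and the explicit `is not None` avoids an implicit `bool()` on a tensor, which would raise. A condensed sketch of the pattern; `visit_mlp` and the stand-in objects are hypothetical, not part of the patch:

```python
from types import SimpleNamespace


def visit_mlp(rules, layer, layer_id, is_identity=False, is_moe=False):
    # norm_weight is bound only when the earlier clauses pass (short-circuit),
    # and is compared with `is not None` rather than truth-tested.
    if (
        not is_identity
        and not is_moe
        and (
            norm_weight := getattr(
                getattr(layer.mlp, "linear_fc1", None), "layer_norm_weight", None
            )
        )
        is not None
    ):
        rules["fused_norm"](norm_weight, layer_id)


calls = []
rules = {"fused_norm": lambda w, lid: calls.append((w, lid))}
mlp = SimpleNamespace(linear_fc1=SimpleNamespace(layer_norm_weight="gamma"))
visit_mlp(rules, SimpleNamespace(mlp=mlp), layer_id=3)
assert calls == [("gamma", 3)]
```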
@@ -555,14 +562,9 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
     def _get_mamba_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.norm, IdentityOp):
             self.rules["norm"](layer.norm, layer_id)
-        elif (
-            isinstance(layer.norm, IdentityOp)
-            and hasattr(layer.mixer.in_proj, "layer_norm_weight")
-            and layer.mixer.in_proj.layer_norm_weight is not None
-            and "fused_norm" in self.rules
-        ):
+        elif (norm_weight := self._get_fused_norm_weight(layer.mixer.in_proj)) is not None:
             # TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear).
-            self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
+            self.rules["fused_norm"](norm_weight, layer_id)
 
         self.rules["mixer_norm"](layer.mixer.norm, layer_id)
         self.rules["A_log"](layer.mixer.A_log, layer_id)
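The `# TE spec` comment is the key context for all three call sites: in Transformer Engine's LayerNormLinear-style modules (such as the quantized wrapper named in the comment), the preceding norm's weight lives on the linear module itself as `layer_norm_weight`, so the exporter has to pick it up there whenever the standalone norm is an `IdentityOp`. A toy sketch of that module shape; `FusedNormLinear` is a stand-in, and only the attribute name `layer_norm_weight` is taken from the patch:

```python
import torch
import torch.nn as nn


class FusedNormLinear(nn.Module):
    """Stand-in for a TE LayerNormLinear: owns both the norm weight and the
    projection weight, so the layer's own norm module is an IdentityOp."""

    def __init__(self, dim, out_dim):
        super().__init__()
        self.layer_norm_weight = nn.Parameter(torch.ones(dim))
        self.weight = nn.Parameter(torch.randn(out_dim, dim))

    def forward(self, x):
        # The norm is applied inside the linear, just before the projection.
        x = nn.functional.layer_norm(x, x.shape[-1:], weight=self.layer_norm_weight)
        return x @ self.weight.t()


in_proj = FusedNormLinear(8, 16)
assert in_proj(torch.randn(2, 8)).shape == (2, 16)
# This is the tensor the exporter routes to rules["fused_norm"].
assert getattr(in_proj, "layer_norm_weight", None) is not None
```

Also worth noting: the old `elif` here re-checked `isinstance(layer.norm, IdentityOp)`, which the failed `if` already guarantees; the helper-based rewrite drops that redundancy, and `layer.mixer.in_proj` is passed directly (no `getattr` guard), presumably because a Mamba mixer always exposes `in_proj`.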