@@ -200,6 +200,12 @@ def _name_remapping(
                 state_dict[key] = val
             else:
                 source_key = mapping.get(key, key)
+                # A mapping value of None means "skip this key" (keep existing value).
+                # This is used for fused TE modules where layer_norm_weight is loaded
+                # separately from a different HF path.
+                if source_key is None:
+                    state_dict[key] = val
+                    continue
                 # For bias tensors in ROW_TP layers, don't use parallel config to avoid sharding
                 # since bias should always be replicated, not sharded
                 if (
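
Below is a minimal, self-contained sketch of the skip convention introduced in this hunk (the tensors, keys, and mapping entries are made up for illustration; they are not the importer's real mapping). A destination key mapped to None keeps whatever value the module already holds and is left for a later rule such as "fused_norm" to fill in:

    import torch

    module_state = {
        "weight": torch.zeros(4, 4),
        "layer_norm_weight": torch.ones(4),  # populated separately by another rule
    }
    hf_state = {"model.layers.0.mlp.up_proj.weight": torch.randn(4, 4)}
    mapping = {
        "weight": "model.layers.0.mlp.up_proj.weight",
        "layer_norm_weight": None,  # skip: loaded from the HF norm path elsewhere
    }

    state_dict = {}
    for key, val in module_state.items():
        source_key = mapping.get(key, key)
        if source_key is None:
            state_dict[key] = val  # keep the existing value untouched
            continue
        state_dict[key] = hf_state[source_key]

    assert torch.equal(state_dict["layer_norm_weight"], module_state["layer_norm_weight"])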
@@ -537,6 +543,15 @@ def _import_mamba_layer(self, layer, layer_id, layer_pbar):
         self.rules["in_proj"](layer.mixer.in_proj, layer_id)
         self.rules["out_proj"](layer.mixer.out_proj, layer_id)
 
+        # TE spec: layer norm is fused into in_proj (TELayerNormColumnParallelLinear).
+        # Load the fused layer_norm_weight from the HF norm path.
+        if (
+            isinstance(layer.norm, IdentityOp)
+            and hasattr(layer.mixer.in_proj, "layer_norm_weight")
+            and "fused_norm" in self.rules
+        ):
+            self.rules["fused_norm"](layer.mixer.in_proj.layer_norm_weight, layer_id)
+
     def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool = False):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id, is_mtp=is_mtp)
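
The same "fused_norm" hook is applied again in the hunks below for linear_qkv and linear_fc1 in the transformer layer. The following toy sketch shows what the rule has to accomplish for a fused TE module; the classes and the HF key are placeholders standing in for megatron.core's IdentityOp, TELayerNormColumnParallelLinear, and the real checkpoint layout:

    import torch
    import torch.nn as nn

    class IdentityOp(nn.Identity):  # placeholder for megatron.core's IdentityOp
        pass

    class FusedNormLinear(nn.Module):  # placeholder for TELayerNormColumnParallelLinear
        def __init__(self, dim):
            super().__init__()
            self.weight = nn.Parameter(torch.empty(dim, dim))
            self.layer_norm_weight = nn.Parameter(torch.empty(dim))

    hf_state = {"backbone.layers.0.norm.weight": torch.ones(8)}  # hypothetical HF key

    norm = IdentityOp()           # the layer's norm slot is an identity in the TE spec
    in_proj = FusedNormLinear(8)  # the norm weight lives on the projection instead

    # What rules["fused_norm"] boils down to: copy the HF norm tensor into the
    # fused layer_norm_weight parameter, since _name_remapping skipped it (None).
    if isinstance(norm, IdentityOp) and hasattr(in_proj, "layer_norm_weight"):
        with torch.no_grad():
            in_proj.layer_norm_weight.copy_(hf_state["backbone.layers.0.norm.weight"])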
@@ -578,6 +593,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
                 attention.core_attention.softmax_offset, layer_id, is_mtp=is_mtp
             )
 
+        # TE spec: input_layernorm is fused into linear_qkv (TELayerNormColumnParallelLinear).
+        # Load the fused layer_norm_weight from the HF norm path.
+        if (
+            isinstance(layer.input_layernorm, IdentityOp)
+            and hasattr(attention, "linear_qkv")
+            and hasattr(attention.linear_qkv, "layer_norm_weight")
+            and "fused_norm" in self.rules
+        ):
+            self.rules["fused_norm"](
+                attention.linear_qkv.layer_norm_weight, layer_id, is_mtp=is_mtp
+            )
+
         if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
             self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id, is_mtp=is_mtp)
 
@@ -671,6 +698,18 @@ def _import_transformer_layer(self, layer, layer_id, layer_pbar, is_mtp: bool =
             self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id, is_mtp=is_mtp)
             self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id, is_mtp=is_mtp)
 
+            # TE spec: pre_mlp_layernorm is fused into linear_fc1
+            # (TELayerNormColumnParallelLinear).
+            # Load the fused layer_norm_weight from the HF norm path.
+            if (
+                isinstance(layer.pre_mlp_layernorm, IdentityOp)
+                and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
+                and "fused_norm" in self.rules
+            ):
+                self.rules["fused_norm"](
+                    layer.mlp.linear_fc1.layer_norm_weight, layer_id, is_mtp=is_mtp
+                )
+
     def _import_state_dict(self):
         model = self.model
         layer_pbar = tqdm(model.decoder.layers, disable=self.disable_tqdm)
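
A quick way to sanity-check an import after these changes is to compare the fused parameter against the HF checkpoint directly. The attribute path and key below are assumptions for a typical attention layer, not something this diff guarantees:

    import torch

    def check_fused_norm(mcore_layer, hf_state_dict,
                         hf_key="model.layers.0.input_layernorm.weight"):
        # Hypothetical check: the fused layer_norm_weight should now equal the
        # HF norm tensor that "fused_norm" loaded for this layer.
        fused = mcore_layer.self_attention.linear_qkv.layer_norm_weight
        reference = hf_state_dict[hf_key].to(fused.dtype)
        assert torch.allclose(fused.detach().cpu(), reference.cpu()), (
            "fused layer_norm_weight was not populated from the HF norm path"
        )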