@@ -285,14 +285,16 @@ def save_pretrained(
285285 self ._hf_config .save_pretrained (save_directory )
286286 try :
287287 generation_config = transformers .GenerationConfig .from_pretrained (
288- self ._hf_pretrained_model_name
288+ self ._hf_pretrained_model_name ,
289+ trust_remote_code = self .trust_remote_code ,
289290 )
290291 generation_config .save_pretrained (save_directory )
291292 except OSError :
292293 pass
293294 try :
294295 tokenizer = transformers .AutoTokenizer .from_pretrained (
295- self ._hf_pretrained_model_name
296+ self ._hf_pretrained_model_name ,
297+ trust_remote_code = self .trust_remote_code ,
296298 )
297299 tokenizer .save_pretrained (save_directory )
298300 except OSError :
@@ -420,6 +422,13 @@ def _get_state_dict(self):
420422 def _get_transformer_layer_state_dict (self , layer , layer_id ):
421423 if not isinstance (layer .input_layernorm , IdentityOp ):
422424 self .rules ["input_layernorm" ](layer .input_layernorm , layer_id )
425+ elif (
426+ hasattr (layer .self_attention , "linear_qkv" )
427+ and hasattr (layer .self_attention .linear_qkv , "layer_norm_weight" )
428+ and layer .self_attention .linear_qkv .layer_norm_weight is not None
429+ and "fused_norm" in self .rules
430+ ):
431+ self .rules ["fused_norm" ](layer .self_attention .linear_qkv .layer_norm_weight , layer_id )
423432
424433 if not isinstance (layer .self_attention , IdentityOp ):
425434 if "MLASelfAttention" in str (type (layer .self_attention )):
@@ -458,6 +467,15 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
458467
459468 if not isinstance (layer .pre_mlp_layernorm , IdentityOp ):
460469 self .rules ["pre_mlp_layernorm" ](layer .pre_mlp_layernorm , layer_id )
470+ elif (
471+ not isinstance (layer .mlp , IdentityOp )
472+ and "MoE" not in str (type (layer .mlp ))
473+ and hasattr (layer .mlp , "linear_fc1" )
474+ and hasattr (layer .mlp .linear_fc1 , "layer_norm_weight" )
475+ and layer .mlp .linear_fc1 .layer_norm_weight is not None
476+ and "fused_norm" in self .rules
477+ ):
478+ self .rules ["fused_norm" ](layer .mlp .linear_fc1 .layer_norm_weight , layer_id )
461479
462480 if not isinstance (layer .mlp , IdentityOp ):
463481 if "MoE" in str (type (layer .mlp )):
@@ -473,22 +491,30 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
473491 self .rules ["shared_experts.linear_fc2" ](
474492 layer .mlp .shared_experts .linear_fc2 , layer_id
475493 )
476- if not self .rules .get ("use_packed_local_experts" , False ):
477- for expert_id , expert in enumerate (layer .mlp .experts .local_experts ):
494+ if hasattr (layer .mlp .experts , "local_experts" ):
495+ if not self .rules .get ("use_packed_local_experts" , False ):
496+ for expert_id , expert in enumerate (layer .mlp .experts .local_experts ):
497+ self .rules ["local_experts.linear_fc1" ](
498+ expert .linear_fc1 , layer_id , expert_id
499+ )
500+ self .rules ["local_experts.linear_fc2" ](
501+ expert .linear_fc2 , layer_id , expert_id
502+ )
503+ else :
504+ # For llama 4, in hf unified checkpoint, all local experts share one scale
478505 self .rules ["local_experts.linear_fc1" ](
479- expert . linear_fc1 , layer_id , expert_id
506+ layer . mlp . experts . local_experts , layer_id
480507 )
481508 self .rules ["local_experts.linear_fc2" ](
482- expert . linear_fc2 , layer_id , expert_id
509+ layer . mlp . experts . local_experts , layer_id
483510 )
484- else :
485- # For llama 4, in hf unified checkpoint, all local experts share one scale
486- self .rules ["local_experts.linear_fc1" ](
487- layer .mlp .experts .local_experts , layer_id
488- )
489- self .rules ["local_experts.linear_fc2" ](
490- layer .mlp .experts .local_experts , layer_id
491- )
511+ elif "experts.linear_fc1" in self .rules :
512+ # TEGroupedMLP: experts use fused grouped GEMM with a single
513+ # linear_fc1/linear_fc2 for all experts (no local_experts attribute).
514+ # Uses "experts.linear_fc1" rule (GroupedMLPMerging) instead of
515+ # "local_experts.linear_fc1" which expects per-expert iteration.
516+ self .rules ["experts.linear_fc1" ](layer .mlp .experts .linear_fc1 , layer_id )
517+ self .rules ["experts.linear_fc2" ](layer .mlp .experts .linear_fc2 , layer_id )
492518 else :
493519 self .rules ["linear_fc1" ](layer .mlp .linear_fc1 , layer_id )
494520 self .rules ["linear_fc2" ](layer .mlp .linear_fc2 , layer_id )
@@ -529,6 +555,14 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
529555 def _get_mamba_layer_state_dict (self , layer , layer_id ):
530556 if not isinstance (layer .norm , IdentityOp ):
531557 self .rules ["norm" ](layer .norm , layer_id )
558+ elif (
559+ isinstance (layer .norm , IdentityOp )
560+ and hasattr (layer .mixer .in_proj , "layer_norm_weight" )
561+ and layer .mixer .in_proj .layer_norm_weight is not None
562+ and "fused_norm" in self .rules
563+ ):
564+ # TE spec: norm is fused into in_proj (QuantTELayerNormColumnParallelLinear).
565+ self .rules ["fused_norm" ](layer .mixer .in_proj .layer_norm_weight , layer_id )
532566
533567 self .rules ["mixer_norm" ](layer .mixer .norm , layer_id )
534568 self .rules ["A_log" ](layer .mixer .A_log , layer_id )
@@ -655,6 +689,7 @@ def _custom_mapping_to_lambda(mapping):
655689 "qkv_slicing" : self ._qkv_slicing ,
656690 "self_attention_scaling" : self ._self_attention_scaling ,
657691 "gated_mlp_slicing" : self ._gated_mlp_slicing ,
692+ "grouped_mlp_slicing" : self ._grouped_mlp_slicing ,
658693 "pack_name_remapping" : self ._pack_name_remapping ,
659694 "pack_name_remapping_gpt_oss" : self ._pack_name_remapping_gpt_oss ,
660695 }
@@ -855,6 +890,67 @@ def _gated_mlp_slicing(
855890 self ._state_dict [gate_proj_key ] = val .detach ().clone ()
856891 self ._state_dict [up_proj_key ] = val .detach ().clone ()
857892
893+ def _grouped_mlp_slicing (self , module , prefix , parallel_config = None ):
894+ """Export TEGroupedMLP weights by splitting per-expert weights into individual HF weights.
895+
896+ TEGroupedMLP (via TEGroupedLinear) stores weights as weight0, weight1, ..., weight{N-1}
897+ in its state_dict, where each weight{i} corresponds to one expert. This method extracts
898+ quantization state from the module, then iterates over experts and saves each expert's
899+ weight (and scales if quantized) under the HF-style per-expert prefix.
900+
901+ This is the reverse of _grouped_mlp_merging in the importer.
902+ """
903+ num_experts = module .num_gemms
904+
905+ # TEGroupedLinear doesn't have module.weight (it has weight0, weight1, ...).
906+ # Temporarily assign weight = weight0 so _get_quantized_state can extract
907+ # qformat, scales, and input_scale from the module's quantizers.
908+ has_weight = hasattr (module , "weight" )
909+ if not has_weight :
910+ module .weight = module .weight0
911+ try :
912+ name_to_value , qformat , block_size = self ._get_quantized_state (
913+ module , self .dtype , prefix = prefix
914+ )
915+ weight_scale , weight_scale_2 = self ._get_weight_scales (name_to_value , qformat )
916+ name_to_value .pop ("weight" , None )
917+ finally :
918+ if not has_weight and hasattr (module , "weight" ):
919+ delattr (module , "weight" )
920+
921+ state_dict = module .state_dict ()
922+
923+ for expert_id in range (num_experts ):
924+ expert_prefix = prefix .format (expert_id ) + "."
925+ weight_key = f"weight{ expert_id } "
926+
927+ if weight_key not in state_dict :
928+ raise ValueError (f"Missing expected TEGroupedMLP expert weight: { weight_key } " )
929+
930+ weight = state_dict [weight_key ].to (self .dtype ).cpu ()
931+
932+ if weight_scale is None :
933+ self ._state_dict [expert_prefix + "weight" ] = weight
934+ else :
935+ self ._state_dict [expert_prefix + "weight" ] = to_quantized_weight (
936+ weight ,
937+ weight_scale ,
938+ qformat ,
939+ weight_scale_2 ,
940+ block_size ,
941+ )
942+ self ._state_dict [expert_prefix + "weight_scale" ] = weight_scale .detach ().clone ()
943+
944+ if weight_scale_2 is not None :
945+ self ._state_dict [expert_prefix + "weight_scale_2" ] = weight_scale_2 .detach ().clone ()
946+
947+ for key , val in name_to_value .items ():
948+ if key == "output_scale" :
949+ continue
950+ for expert_id in range (num_experts ):
951+ expert_prefix = prefix .format (expert_id ) + "."
952+ self ._state_dict [expert_prefix + key ] = val .detach ().clone ()
953+
858954 def _qkv_slicing (
859955 self ,
860956 module ,
0 commit comments