@@ -377,6 +377,10 @@ def save_pretrained(
         # Add multimodal components to state_dict
         state_dict.update(multimodal_state_dict)

+        mtp_state_dict = self._get_mtp_state_dict()
+        state_dict.update(mtp_state_dict)
+        print(f"Successfully loaded {len(mtp_state_dict)} MTP tensors")
+
         # Barrier to ensure the export_dir has been created.
         torch.distributed.barrier()
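For readers skimming the diff: the new block above folds the MTP tensors into the main state dict before the checkpoint is written. A minimal, self-contained sketch of that merge-then-save pattern, using `safetensors.torch.save_file` with dummy tensors (the names and the single-file save are illustrative; the real exporter writes sharded checkpoints):

```python
import torch
from safetensors.torch import save_file

def merge_and_save(state_dict: dict[str, torch.Tensor],
                   mtp_state_dict: dict[str, torch.Tensor],
                   export_path: str) -> None:
    """Fold MTP tensors into the main state dict, then write one file."""
    state_dict.update(mtp_state_dict)
    print(f"Merged {len(mtp_state_dict)} MTP tensors")
    save_file(state_dict, export_path)

# Toy usage with dummy tensors standing in for real weights.
base = {"model.layers.0.weight": torch.zeros(4, 4)}
mtp = {"mtp.layers.0.weight": torch.ones(4, 4, dtype=torch.bfloat16)}
merge_and_save(base, mtp, "model.safetensors")
```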
@@ -478,9 +482,7 @@ def _get_state_dict(self):
            else:
                raise ValueError("Only TransformerLayer or MambaLayer are supported.")

-        # Get MTP layer if exists. Only on rank 0 to avoid duplicate weights.
-        if torch.distributed.get_rank() == 0:
-            self._get_mtp_state_dict()
+        # TODO export MTP layer in the future

     def _get_transformer_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.input_layernorm, IdentityOp):
@@ -558,13 +560,14 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
             self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id)
             self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)

-    def _get_mtp_state_dict(self):
+    def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
         """Export the MTP module.

         Currently, we copy the BF16 MTP weights from the pretrained model if the pretrained model has MTP layers.
         """
         # TODO Implement MTP export for quantized MTP
         # Hacky version for now: copy MTP weights from pretrained model
+        mtp_state_dict = {}
         if self._hf_pretrained_model_name:
             if os.path.isdir(self._hf_pretrained_model_name):
                 safetensors_index_file = (
@@ -583,11 +586,12 @@ def _get_mtp_state_dict(self):
                 model_dir = Path(safetensors_index_file).parent
                 for key in safetensors_index["weight_map"]:
                     if key.startswith("mtp.") and key not in self._state_dict:
-                        self._state_dict[key] = get_safetensor(model_dir, key)
+                        mtp_state_dict[key] = get_safetensor(model_dir, key)
                         mtp_exists = True

         if mtp_exists:
             self.exclude_modules.append("mtp*")
+        return mtp_state_dict

     def _get_mamba_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.norm, IdentityOp):
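The loop above resolves each `mtp.*` key through the checkpoint's safetensors index and loads it via `get_safetensor(model_dir, key)`. A sketch of how such a lookup can be implemented with the `safetensors` library; this is an assumed stand-in, and the repo's actual `get_safetensor` helper may cache file handles or differ in signature:

```python
import json
from pathlib import Path

import torch
from safetensors import safe_open

def load_tensor_by_name(model_dir: Path, key: str) -> torch.Tensor:
    """Find which shard holds `key` via the index file, then load it lazily."""
    index_path = model_dir / "model.safetensors.index.json"
    index = json.loads(index_path.read_text())
    shard_file = index["weight_map"][key]  # which shard holds this tensor
    with safe_open(str(model_dir / shard_file), framework="pt") as f:
        return f.get_tensor(key)

# Usage mirroring the MTP loop above (paths are illustrative):
# mtp_state_dict = {
#     key: load_tensor_by_name(model_dir, key)
#     for key in index["weight_map"] if key.startswith("mtp.")
# }
```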
@@ -855,7 +859,6 @@ def _name_remapping(
             else:
                 source_key = mapping.get(key, key)
             self._state_dict[prefix + source_key] = val
-            print(f"{prefix + source_key}: {self._state_dict[prefix + source_key].dtype}")

     def _gated_mlp_slicing(
         self, module, prefix, gate_proj_name="gate_proj", up_proj_name="up_proj"