@@ -56,6 +56,7 @@ def preprocess_weights(self, weights: dict) -> dict:
5656 config = self .config .pretrained_config
5757 tp_size = self .config .mapping .tp_size
5858 tp_rank = self .config .mapping .tp_rank
59+ mtp_layer_offset = config .num_hidden_layers
5960
6061 if self .config .mapping .enable_attention_dp :
6162 tp_size = 1
@@ -66,10 +67,28 @@ def preprocess_weights(self, weights: dict) -> dict:
6667 linear_key_dim = config .linear_key_head_dim * config .linear_num_key_heads # 16 * 128
6768 linear_value_dim = config .linear_value_head_dim * config .linear_num_value_heads # 32 * 128
6869
70+ mtp_mapping = {
71+ "mtp.fc" : "fc" ,
72+ "mtp.norm" : "shared_head.norm" ,
73+ "mtp.pre_fc_norm_embedding" : "pre_fc_norm_embedding" ,
74+ "mtp.pre_fc_norm_hidden" : "pre_fc_norm_hidden" ,
75+ }
76+
6977 new_weights = {}
7078 for name , _ in weights .items ():
7179 key = name
7280
81+ if key .startswith ("mtp.layers." ):
82+ _ , _ , mtp_layer_idx , module_name = key .split ("." , 3 )
83+ key = (f"model.layers.{ mtp_layer_offset + int (mtp_layer_idx )} ."
84+ f"{ module_name } " )
85+ elif key .startswith ("mtp." ):
86+ for mtp_prefix , trtllm_name in mtp_mapping .items ():
87+ if key .startswith (mtp_prefix ):
88+ suffix = key [len (mtp_prefix ):]
89+ key = f"model.layers.{ mtp_layer_offset } .{ trtllm_name } { suffix } "
90+ break
91+
7392 if "A_log" in key :
7493 w = split (weights [name ], tp_size , tp_rank )
7594 w = w .to (torch .float32 )
0 commit comments