@@ -99,6 +99,24 @@ class Step3VLTextModel(Qwen3Model):
9999class Step35Model (TextModel ):
100100 model_arch = gguf .MODEL_ARCH .STEP35
101101
102+ # --mtp / --no-mtp toggles (see convert_hf_to_gguf.py main()).
103+ # Unlike Qwen3.5 which stores MTP under a `mtp.*` namespace, Step3.5 just
104+ # appends MTP layers at `model.layers.{num_hidden_layers + i}`; these flags
105+ # filter by layer index instead of by name prefix.
106+ no_mtp : bool = False
107+ mtp_only : bool = False
108+
def __init__(self, *args, **kwargs):
    """Build the model, widening the tensor map when MTP layers are kept.

    All arguments are forwarded unchanged to the parent constructor.
    """
    super().__init__(*args, **kwargs)

    mtp_layers = int(self.hparams.get("num_nextn_predict_layers", 0))

    # Nothing to widen: either the checkpoint declares no NextN/MTP layers
    # or --no-mtp is dropping them, so the block count set by the parent
    # constructor stays as-is and no unused slots are reserved.
    if self.no_mtp or mtp_layers <= 0:
        return

    # MTP blocks sit right after the main stack, at indices
    # num_hidden_layers .. num_hidden_layers + mtp_layers - 1. The tensor
    # name map has to cover those indices so their tensors receive
    # correctly indexed names.
    main_layers = int(self.hparams["num_hidden_layers"])
    self.block_count = main_layers + mtp_layers
    self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
119+
102120 def set_gguf_parameters (self ):
103121 rope_theta = self .hparams .get ("rope_theta" )
104122 if isinstance (rope_theta , list ):
@@ -119,8 +137,25 @@ def set_gguf_parameters(self):
119137 n_head_swa = attn_other .get ("num_attention_heads" , n_head_base )
120138 n_kv_swa = attn_other .get ("num_attention_groups" , n_kv_base )
121139
122- layer_types = layer_types [: self .block_count ]
123- partial_rotary_factors = partial_rotary_factors [: self .block_count ]
140+ n_nextn = int (self .hparams .get ("num_nextn_predict_layers" , 0 ))
141+
142+ # The Step3p5 HF checkpoint stores layer_types/partial_rotary_factors
143+ # entries for the MTP blocks past num_hidden_layers; preserve them so
144+ # the MTP layer's attention shape, SWA flag, and partial RoPE dim are
145+ # set correctly. Pad with full-attention defaults if the checkpoint
146+ # truncated them.
def _pad(arr, n, default):
    """Return a new list of length ``n`` built from ``arr``.

    ``arr`` is truncated when it is longer than ``n`` and right-padded
    with ``default`` when it is shorter. The input is never mutated.
    """
    items = list(arr)[:n]
    # A non-positive repeat count yields an empty list, so this is a
    # no-op whenever ``items`` already has at least ``n`` entries.
    items.extend([default] * (n - len(items)))
    return items
152+
153+ layer_types = _pad (layer_types , self .block_count , "full_attention" )
154+ partial_rotary_factors = _pad (
155+ partial_rotary_factors ,
156+ self .block_count ,
157+ 0.5 , # full_attention default for Step3p5
158+ )
124159 assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types ] == partial_rotary_factors
125160 head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types ]
126161 kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types ]
@@ -157,14 +192,25 @@ def set_gguf_parameters(self):
157192
158193 self .gguf_writer .add_layer_norm_rms_eps (self .hparams .get ("rms_norm_eps" , 1e-5 ))
159194
160- # Optional per-layer SwiGLU clamps.
195+ # Optional per-layer SwiGLU clamps. MTP layers default to no clamping (0.0).
161196 if (limits := self .hparams .get ("swiglu_limits" )) is not None :
162- limits_f = [0.0 if v is None else float (v ) for v in limits [: self .block_count ]]
197+ limits_f = _pad (
198+ [0.0 if v is None else float (v ) for v in limits ],
199+ self .block_count ,
200+ 0.0 ,
201+ )
163202 self .gguf_writer .add_swiglu_clamp_exp (limits_f )
164203 if (limits_shared := self .hparams .get ("swiglu_limits_shared" )) is not None :
165- limits_shared_f = [0.0 if v is None else float (v ) for v in limits_shared [: self .block_count ]]
204+ limits_shared_f = _pad (
205+ [0.0 if v is None else float (v ) for v in limits_shared ],
206+ self .block_count ,
207+ 0.0 ,
208+ )
166209 self .gguf_writer .add_swiglu_clamp_shexp (limits_shared_f )
167210
211+ if n_nextn > 0 and not self .no_mtp :
212+ self .gguf_writer .add_nextn_predict_layers (n_nextn )
213+
168214 @classmethod
169215 def filter_tensors (cls , item : tuple [str , Callable [[], Tensor ]]) -> tuple [str , Callable [[], Tensor ]] | None :
170216 name , gen = item
@@ -175,13 +221,41 @@ def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Ca
175221
176222 return super ().filter_tensors ((name , gen ))
177223
224+ def _is_mtp_layer (self , bid : int | None ) -> bool :
225+ if bid is None :
226+ return False
227+ n_main = int (self .hparams .get ("num_hidden_layers" , self .block_count ))
228+ return bid >= n_main
229+
178230 def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ):
179- # remove mtp layers
180- if (m := re .match (r"model\.layers\.(\d+)\." , name )) is not None :
181- il = int (m .group (1 ))
182- n_main = int (self .hparams .get ("num_hidden_layers" , self .block_count ))
183- if il >= n_main :
231+ is_mtp = self ._is_mtp_layer (bid )
232+
233+ # --no-mtp: drop the appended MTP block(s) entirely.
234+ if is_mtp and self .no_mtp :
235+ return
236+ # --mtp: keep ONLY MTP-block tensors plus the shared embeddings/norm/lm_head
237+ # (so the resulting GGUF carries just the draft head).
238+ if self .mtp_only and not is_mtp and bid is not None :
239+ return
240+ if self .mtp_only and bid is None :
241+ # Top-level tensors: keep only shared embeddings/norm/lm_head.
242+ keep = name in (
243+ "model.embed_tokens.weight" , "model.norm.weight" , "lm_head.weight" ,
244+ )
245+ if not keep :
184246 return
247+
248+ # The checkpoint nests the per-MTP-layer shared head under
249+ # `model.layers.{N+i}.transformer.shared_head.{norm,output}.weight`;
250+ # strip the `transformer.` infix and rename `output` → `head` so the
251+ # existing NEXTN_SHARED_HEAD_{NORM,HEAD} tensor mapping picks them up.
252+ # Mirrors vllm's `_rewrite_spec_layer_name` (step3p5_mtp.py).
253+ if is_mtp :
254+ if ".transformer." in name :
255+ name = name .replace (".transformer." , "." )
256+ if "shared_head.output" in name :
257+ name = name .replace ("shared_head.output" , "shared_head.head" )
258+
185259 if name .endswith ("norm.weight" ):
186260 data_torch += 1.0
187261
@@ -190,6 +264,21 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
190264
191265 yield from super ().modify_tensors (data_torch , name , bid )
192266
def prepare_metadata(self, vocab_only: bool):
    """Prepare metadata, then give a draft-only export its own filename.

    After delegating to the parent implementation, a draft-only
    (``mtp_only``) conversion whose output target was a directory gets
    its output path replaced by ``mtp-<default name>.gguf`` in the same
    directory, so it does not collide with the trunk model's file.

    Args:
        vocab_only: Forwarded unchanged to the parent implementation.
    """
    # Sample this BEFORE calling the parent: the rename below takes
    # ``.parent`` of ``fname_out``, so it relies on the parent call having
    # turned a directory target into a file path inside that directory.
    # NOTE(review): the parent implementation is not visible here — confirm.
    out_is_dir = self.fname_out.is_dir()
    super().prepare_metadata(vocab_only=vocab_only)

    # Only a draft-only export written into a directory is renamed; an
    # explicit output file name is left alone.
    if not (self.mtp_only and out_is_dir):
        return

    # Text after the first "_" of the ftype enum member name.
    output_type: str = self.ftype.name.partition("_")[2]
    fname_default: str = gguf.naming_convention(
        self.metadata.name,
        self.metadata.basename,
        self.metadata.finetune,
        self.metadata.version,
        size_label=None,
        output_type=output_type,
        model_type=None,
    )
    # Fix: the extension follows the name directly. The previous literal
    # had a space before ".gguf", producing names like "mtp-NAME .gguf".
    self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"
281+
193282 def generate_extra_tensors (self ) -> Iterable [tuple [str , Tensor ]]:
194283 # Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3").
195284 # llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS).
0 commit comments