     # layers.{i}.hc_ffn_{fn,base,scale} -> model.layers.{i}.ffn_hc.{fn,base,scale}
     (re.compile(r"^layers\.(\d+)\.hc_attn_(base|fn|scale)$"), r"model.layers.\1.attn_hc.\2"),
     (re.compile(r"^layers\.(\d+)\.hc_ffn_(base|fn|scale)$"), r"model.layers.\1.ffn_hc.\2"),
+    # MTP-local HC head. Native MTP keys are normalized to temporary
+    # ``layers.{k}.*`` keys before the rename table is applied.
+    (re.compile(r"^layers\.(\d+)\.hc_head_(fn|base|scale)$"), r"model.layers.\1.hc_head.hc_\2"),
     # Final HC-head collapse module:
     # hc_head_{fn,base,scale} -> model.hc_head.hc_{fn,base,scale}
     # (HF uses ``hc_fn`` / ``hc_base`` / ``hc_scale`` inside HyperHead, in
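In isolation, the new rule maps a normalized MTP HC-head key as follows (a standalone sketch with a hypothetical key, not converter code):

```python
import re

# The new MTP-local HC-head rule from the table above, applied to a hypothetical
# key that has already been normalized from ``mtp.0.*`` to ``layers.0.*``.
pattern = re.compile(r"^layers\.(\d+)\.hc_head_(fn|base|scale)$")
repl = r"model.layers.\1.hc_head.hc_\2"
print(pattern.sub(repl, "layers.0.hc_head_fn"))  # model.layers.0.hc_head.hc_fn
```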
@@ -228,9 +231,10 @@ def from_hf(
         """Convert HF checkpoint to internal format.

         Steps:
-        1. Split MTP layers (index >= num_hidden_layers) from backbone keys
-           and renumber them as ``layers.{k}.*`` so the standard pipeline
-           (dequantize / aggregate-experts / rename) handles them too.
+        1. Split native ``mtp.{k}.*`` keys (and legacy
+           ``layers.{num_hidden_layers+k}.*`` keys) from backbone keys and
+           renumber them as temporary ``layers.{k}.*`` keys so the standard
+           pipeline (dequantize / aggregate-experts / rename) handles them too.
         2. Dequantize FP8 / FP4 weights for both backbone and MTP.
         3. Aggregate per-expert routed weights into stacked tensors.
         4. Rename keys using the HF -> internal mapping table.
@@ -244,26 +248,39 @@ def from_hf(
         # split regex must accept either form.
         _layer_re = re.compile(r"^(model\.)?layers\.(\d+)\.")

-        # Split MTP keys from backbone keys. MTP layers in HF format are
-        # ``[model.]layers.{N+k}.*`` — strip the optional ``model.`` prefix
-        # and renumber to ``layers.{k}.*`` so we can run them through the
-        # same dequantize / aggregate / rename pipeline as the backbone
-        # (FP4 routed experts and FP8 attention projections live under MTP
-        # too, so they need the same handling).
+        # Split MTP keys from backbone keys. Current DSV4-Flash stores MTP as
+        # ``mtp.{k}.*``; HF/intermediate exports can also use
+        # ``[model.]layers.{N+k}.*``. Normalize either format to temporary
+        # ``layers.{k}.*`` keys so the standard dequantize / aggregate / rename
+        # pipeline can handle FP4 routed experts and FP8 projections uniformly.
         mtp_hf: dict[str, Any] = {}
-        if num_mtp > 0:
-            backbone_hf: dict[str, Any] = {}
-            for key in list(hf_state_dict.keys()):
-                val = hf_state_dict[key]
-                m = _layer_re.match(key)
-                if m and int(m.group(2)) >= N:
-                    orig_idx = int(m.group(2))
-                    mtp_depth = orig_idx - N
+        backbone_hf: dict[str, Any] = {}
+        native_mtp_re = re.compile(r"^mtp\.(\d+)\.")
+        for key in list(hf_state_dict.keys()):
+            val = hf_state_dict[key]
+            native_m = native_mtp_re.match(key)
+            if native_m is not None:
+                mtp_depth = int(native_m.group(1))
+                if mtp_depth < num_mtp:
+                    renumbered = f"layers.{mtp_depth}." + key[native_m.end() :]
+                    mtp_hf[renumbered] = val
+                # Drop checkpoint MTP tensors when the runtime config disables
+                # MTP. Otherwise loading DSV4-Flash with
+                # num_nextn_predict_layers=0 produces a large set of dangling
+                # ``mtp.0.*`` keys.
+                continue
+
+            m = _layer_re.match(key)
+            if m and int(m.group(2)) >= N and num_mtp > 0:
+                orig_idx = int(m.group(2))
+                mtp_depth = orig_idx - N
+                if mtp_depth < num_mtp:
                     renumbered = f"layers.{mtp_depth}." + key[m.end() :]
                     mtp_hf[renumbered] = val
-                else:
-                    backbone_hf[key] = val
-            hf_state_dict = backbone_hf
+                continue
+
+            backbone_hf[key] = val
+        hf_state_dict = backbone_hf

         hf_state_dict = self._dequantize(hf_state_dict)
         hf_state_dict = self._aggregate_experts(hf_state_dict, device_mesh)
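To make the two accepted input forms concrete, here is a standalone sketch of the split on a toy state dict; the tensor names (``a``, ``b``, ...) and the values of N / num_mtp are made up for illustration:

```python
import re

N, num_mtp = 4, 1  # hypothetical num_hidden_layers / num_nextn_predict_layers
_layer_re = re.compile(r"^(model\.)?layers\.(\d+)\.")
native_mtp_re = re.compile(r"^mtp\.(\d+)\.")

hf_state_dict = {
    "layers.0.a": "backbone tensor",
    "mtp.0.b": "native MTP tensor",           # -> layers.0.b
    "model.layers.4.c": "legacy MTP tensor",  # index N + 0 -> layers.0.c
    "mtp.1.d": "out-of-range MTP tensor",     # depth >= num_mtp -> dropped
}

mtp_hf, backbone_hf = {}, {}
for key, val in hf_state_dict.items():
    native_m = native_mtp_re.match(key)
    if native_m is not None:
        depth = int(native_m.group(1))
        if depth < num_mtp:
            mtp_hf[f"layers.{depth}." + key[native_m.end() :]] = val
        continue
    m = _layer_re.match(key)
    if m and int(m.group(2)) >= N and num_mtp > 0:
        depth = int(m.group(2)) - N
        if depth < num_mtp:
            mtp_hf[f"layers.{depth}." + key[m.end() :]] = val
        continue
    backbone_hf[key] = val

print(backbone_hf)  # {'layers.0.a': 'backbone tensor'}
print(mtp_hf)       # {'layers.0.b': 'native MTP tensor', 'layers.0.c': 'legacy MTP tensor'}
```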
@@ -277,9 +294,9 @@ def from_hf(
             # After _rename_all, layer-indexed keys are in one of two forms:
             #   - ``model.layers.{k}.*`` if a rename rule matched (norms,
             #     attn, mlp, experts, hc), or
-            #   - ``layers.{k}.*`` if no rule matched — V4 MTP fusion-only
-            #     modules (``eh_proj`` / ``enorm`` / ``hnorm`` /
-            #     ``final_layernorm``) have no specific rename rule.
+            #   - ``layers.{k}.*`` if no rule matched — V4 MTP-only
+            #     modules (``e_proj`` / ``h_proj`` / ``enorm`` / ``hnorm`` /
+            #     ``norm``) have no backbone rename rule.
             # Re-prefix both forms into the ``mtp.layers.{k}.*`` namespace.
             if key.startswith("model.layers."):
                 state_dict["mtp" + key[len("model") :]] = val
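A minimal sketch of the re-prefixing with hypothetical keys; the ``model.layers.`` branch is taken verbatim from the hunk above, while the handling of the bare ``layers.`` form is an assumption about the part of the loop not shown in this diff:

```python
state_dict = {}
mtp_pairs = {
    "model.layers.0.attn_hc.fn": "renamed key",  # a rename rule matched
    "layers.0.enorm.weight": "MTP-only module",  # no rule matched
}
for key, val in mtp_pairs.items():
    if key.startswith("model.layers."):
        state_dict["mtp" + key[len("model") :]] = val
    elif key.startswith("layers."):  # assumed handling of the second form
        state_dict["mtp." + key] = val
print(state_dict)
# {'mtp.layers.0.attn_hc.fn': 'renamed key', 'mtp.layers.0.enorm.weight': 'MTP-only module'}
```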
@@ -564,6 +581,7 @@ def _drop_hash_layer_gate_bias(self, state_dict: dict[str, Any], scope: "_HashBi
     # Reverse of the HC submodule renames above.
     (re.compile(r"^model\.layers\.(\d+)\.attn_hc\.(fn|base|scale)$"), r"layers.\1.hc_attn_\2"),
     (re.compile(r"^model\.layers\.(\d+)\.ffn_hc\.(fn|base|scale)$"), r"layers.\1.hc_ffn_\2"),
+    (re.compile(r"^model\.layers\.(\d+)\.hc_head\.hc_(fn|base|scale)$"), r"layers.\1.hc_head_\2"),
     (re.compile(r"^model\.hc_head\.hc_(fn|base|scale)$"), r"hc_head_\1"),
 ]

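The added reverse rule undoes the forward ``hc_head`` rule from the HF -> internal table; a quick standalone check with the two regexes copied from the tables and a hypothetical layer index:

```python
import re

fwd = (re.compile(r"^layers\.(\d+)\.hc_head_(fn|base|scale)$"), r"model.layers.\1.hc_head.hc_\2")
rev = (re.compile(r"^model\.layers\.(\d+)\.hc_head\.hc_(fn|base|scale)$"), r"layers.\1.hc_head_\2")

hf_key = "layers.3.hc_head_base"
internal_key = fwd[0].sub(fwd[1], hf_key)  # model.layers.3.hc_head.hc_base
assert rev[0].sub(rev[1], internal_key) == hf_key
```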
@@ -575,21 +593,19 @@ def _internal_key_to_hf(self, key: str) -> str:
         return key

     def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
-        # MTP keys (``mtp.layers.{k}.*``) share the same on-disk layout as
-        # backbone layers ``layers.{N+k}.*``, so rewrite the fqn into the
-        # equivalent backbone-internal form and run it through the standard
-        # split / rename / quantize pipeline. This keeps expert splitting,
-        # FP4/FP8 quantization, and exclude-key filtering symmetric with the
-        # ``from_hf`` path instead of bypassing them.
-        mtp_hf_idx: int | None = None
+        # MTP keys (``mtp.layers.{k}.*``) share the same per-block layout as
+        # backbone layers, but current DSV4-Flash stores them under native
+        # ``mtp.{k}.*`` keys. Rewrite to an equivalent temporary
+        # ``model.layers.{k}.*`` form for splitting / renaming / quantization,
+        # then replace the emitted ``layers.{k}.`` prefix with ``mtp.{k}.``.
+        mtp_depth: int | None = None
         if fqn.startswith("mtp."):
-            N = self.config.num_hidden_layers
             rest = fqn[len("mtp.") :]
             m = re.match(r"^layers\.(\d+)\.", rest)
             if m is None:
                 return [(fqn, tensor)]
-            mtp_hf_idx = N + int(m.group(1))
-            fqn = f"model.layers.{mtp_hf_idx}." + rest[m.end() :]
+            mtp_depth = int(m.group(1))
+            fqn = f"model.layers.{mtp_depth}." + rest[m.end() :]

         quantization = kwargs.get("quantization", False)
         exclude_key_regex = kwargs.get("exclude_key_regex", None)
@@ -603,16 +619,24 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t
         # Rename internal keys to HF keys
         result = [(self._internal_key_to_hf(k), v) for k, v in result]

-        if mtp_hf_idx is not None:
-            # MTP fusion-only modules (``eh_proj`` / ``enorm`` / ``hnorm`` /
-            # ``final_layernorm``) have no entry in ``_INTERNAL_TO_HF_RENAMES``,
-            # so they leave ``_internal_key_to_hf`` still carrying the
-            # ``model.layers.{N+k}.*`` prefix. Strip ``model.`` here so they
-            # land at the HF-side ``layers.{N+k}.*`` like every other MTP key.
-            internal_prefix = f"model.layers.{mtp_hf_idx}."
-            hf_prefix = f"layers.{mtp_hf_idx}."
+        if mtp_depth is not None:
+            # MTP-only modules (``e_proj`` / ``h_proj`` / ``enorm`` /
+            # ``hnorm`` / ``norm``) have no generic backbone rename rule, so
+            # they can still carry ``model.layers.{k}.`` here. Normalize both
+            # possible temporary prefixes to the checkpoint's native MTP prefix.
+            internal_prefix = f"model.layers.{mtp_depth}."
+            layer_prefix = f"layers.{mtp_depth}."
+            mtp_prefix = f"mtp.{mtp_depth}."
             result = [
-                (hf_prefix + k[len(internal_prefix) :] if k.startswith(internal_prefix) else k, v) for k, v in result
+                (
+                    mtp_prefix + k[len(internal_prefix) :]
+                    if k.startswith(internal_prefix)
+                    else mtp_prefix + k[len(layer_prefix) :]
+                    if k.startswith(layer_prefix)
+                    else k,
+                    v,
+                )
+                for k, v in result
             ]

         if quantization:
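End to end, the prefix normalization above takes whatever temporary prefix a key still carries after ``_internal_key_to_hf`` and rewrites it to the checkpoint's native ``mtp.{k}.`` prefix. A standalone sketch with hypothetical key names (``enorm`` stands in for an MTP-only module with no rename rule):

```python
mtp_depth = 0
internal_prefix = f"model.layers.{mtp_depth}."
layer_prefix = f"layers.{mtp_depth}."
mtp_prefix = f"mtp.{mtp_depth}."

# After _internal_key_to_hf: a key with a reverse rename rule already looks like
# ``layers.0.*``; an MTP-only module still carries ``model.layers.0.*``.
result = [("layers.0.hc_head_fn", "t0"), ("model.layers.0.enorm.weight", "t1")]
result = [
    (
        mtp_prefix + k[len(internal_prefix) :]
        if k.startswith(internal_prefix)
        else mtp_prefix + k[len(layer_prefix) :]
        if k.startswith(layer_prefix)
        else k,
        v,
    )
    for k, v in result
]
print(result)  # [('mtp.0.hc_head_fn', 't0'), ('mtp.0.enorm.weight', 't1')]
```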