
Commit 3c08682

fix(deepseek-v4): load MTP checkpoint weights
1 parent 37744cb commit 3c08682

2 files changed: 120 additions & 90 deletions

nemo_automodel/components/models/deepseek_v4/state_dict_adapter.py: 66 additions & 42 deletions
@@ -173,6 +173,9 @@
         # layers.{i}.hc_ffn_{fn,base,scale} -> model.layers.{i}.ffn_hc.{fn,base,scale}
         (re.compile(r"^layers\.(\d+)\.hc_attn_(base|fn|scale)$"), r"model.layers.\1.attn_hc.\2"),
         (re.compile(r"^layers\.(\d+)\.hc_ffn_(base|fn|scale)$"), r"model.layers.\1.ffn_hc.\2"),
+        # MTP-local HC head. Native MTP keys are normalized to temporary
+        # ``layers.{k}.*`` keys before the rename table is applied.
+        (re.compile(r"^layers\.(\d+)\.hc_head_(fn|base|scale)$"), r"model.layers.\1.hc_head.hc_\2"),
         # Final HC-head collapse module:
         # hc_head_{fn,base,scale} -> model.hc_head.hc_{fn,base,scale}
         # (HF uses ``hc_fn`` / ``hc_base`` / ``hc_scale`` inside HyperHead, in
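For context, a minimal standalone sketch of how the new rename entry rewrites a key (the key below is illustrative, not taken from a real checkpoint):

    import re

    # MTP-local HC-head rule added in this hunk.
    pattern = re.compile(r"^layers\.(\d+)\.hc_head_(fn|base|scale)$")
    replacement = r"model.layers.\1.hc_head.hc_\2"

    # Hypothetical key for MTP depth 0, after native ``mtp.0.*`` keys have been
    # renumbered to temporary ``layers.0.*`` keys.
    key = "layers.0.hc_head_fn"
    print(pattern.sub(replacement, key))  # -> model.layers.0.hc_head.hc_fn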
@@ -228,9 +231,10 @@ def from_hf(
         """Convert HF checkpoint to internal format.

         Steps:
-        1. Split MTP layers (index >= num_hidden_layers) from backbone keys
-           and renumber them as ``layers.{k}.*`` so the standard pipeline
-           (dequantize / aggregate-experts / rename) handles them too.
+        1. Split native ``mtp.{k}.*`` keys (and legacy
+           ``layers.{num_hidden_layers+k}.*`` keys) from backbone keys and
+           renumber them as temporary ``layers.{k}.*`` keys so the standard
+           pipeline (dequantize / aggregate-experts / rename) handles them too.
         2. Dequantize FP8 / FP4 weights for both backbone and MTP.
         3. Aggregate per-expert routed weights into stacked tensors.
         4. Rename keys using the HF -> internal mapping table.
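Step 1 accepts two MTP key spellings. A small helper sketch of the normalization it describes (key names and N are illustrative; this is not the adapter's actual code):

    import re

    def normalize_mtp_key(key: str, N: int) -> str | None:
        """Return the temporary ``layers.{k}.*`` form of an MTP key, or None."""
        m = re.match(r"^mtp\.(\d+)\.", key)  # native DSV4-Flash format
        if m:
            return f"layers.{m.group(1)}." + key[m.end():]
        m = re.match(r"^(model\.)?layers\.(\d+)\.", key)  # legacy export format
        if m and int(m.group(2)) >= N:
            return f"layers.{int(m.group(2)) - N}." + key[m.end():]
        return None

    print(normalize_mtp_key("mtp.0.enorm.weight", N=4))           # layers.0.enorm.weight
    print(normalize_mtp_key("model.layers.4.enorm.weight", N=4))  # layers.0.enorm.weight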
@@ -244,26 +248,39 @@ def from_hf(
         # split regex must accept either form.
         _layer_re = re.compile(r"^(model\.)?layers\.(\d+)\.")

-        # Split MTP keys from backbone keys. MTP layers in HF format are
-        # ``[model.]layers.{N+k}.*`` — strip the optional ``model.`` prefix
-        # and renumber to ``layers.{k}.*`` so we can run them through the
-        # same dequantize / aggregate / rename pipeline as the backbone
-        # (FP4 routed experts and FP8 attention projections live under MTP
-        # too, so they need the same handling).
+        # Split MTP keys from backbone keys. Current DSV4-Flash stores MTP as
+        # ``mtp.{k}.*``; HF/intermediate exports can also use
+        # ``[model.]layers.{N+k}.*``. Normalize either format to temporary
+        # ``layers.{k}.*`` keys so the standard dequantize / aggregate / rename
+        # pipeline can handle FP4 routed experts and FP8 projections uniformly.
         mtp_hf: dict[str, Any] = {}
-        if num_mtp > 0:
-            backbone_hf: dict[str, Any] = {}
-            for key in list(hf_state_dict.keys()):
-                val = hf_state_dict[key]
-                m = _layer_re.match(key)
-                if m and int(m.group(2)) >= N:
-                    orig_idx = int(m.group(2))
-                    mtp_depth = orig_idx - N
+        backbone_hf: dict[str, Any] = {}
+        native_mtp_re = re.compile(r"^mtp\.(\d+)\.")
+        for key in list(hf_state_dict.keys()):
+            val = hf_state_dict[key]
+            native_m = native_mtp_re.match(key)
+            if native_m is not None:
+                mtp_depth = int(native_m.group(1))
+                if mtp_depth < num_mtp:
+                    renumbered = f"layers.{mtp_depth}." + key[native_m.end() :]
+                    mtp_hf[renumbered] = val
+                # Drop checkpoint MTP tensors when the runtime config disables
+                # MTP. Otherwise loading DSV4-Flash with
+                # num_nextn_predict_layers=0 produces a large set of dangling
+                # ``mtp.0.*`` keys.
+                continue
+
+            m = _layer_re.match(key)
+            if m and int(m.group(2)) >= N and num_mtp > 0:
+                orig_idx = int(m.group(2))
+                mtp_depth = orig_idx - N
+                if mtp_depth < num_mtp:
                     renumbered = f"layers.{mtp_depth}." + key[m.end() :]
                     mtp_hf[renumbered] = val
-                else:
-                    backbone_hf[key] = val
-            hf_state_dict = backbone_hf
+                continue
+
+            backbone_hf[key] = val
+        hf_state_dict = backbone_hf

         hf_state_dict = self._dequantize(hf_state_dict)
         hf_state_dict = self._aggregate_experts(hf_state_dict, device_mesh)
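A rough toy run of what this split loop does, including the fix for num_mtp = 0, where native ``mtp.0.*`` tensors are now dropped instead of left dangling (keys and values are made up, and the depth check is simplified to depth 0):

    hf_state_dict = {
        "model.layers.0.input_layernorm.weight": "backbone tensor",
        "mtp.0.enorm.weight": "mtp tensor",
    }

    for num_mtp in (1, 0):
        mtp_hf, backbone_hf = {}, {}
        for key, val in hf_state_dict.items():
            if key.startswith("mtp.0."):
                if 0 < num_mtp:  # MTP depth 0 enabled in the runtime config?
                    mtp_hf["layers.0." + key[len("mtp.0."):]] = val
                continue         # otherwise the checkpoint tensor is dropped
            backbone_hf[key] = val
        print(num_mtp, sorted(mtp_hf), sorted(backbone_hf))
    # 1 ['layers.0.enorm.weight'] ['model.layers.0.input_layernorm.weight']
    # 0 [] ['model.layers.0.input_layernorm.weight']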
@@ -277,9 +294,9 @@ def from_hf(
             # After _rename_all, layer-indexed keys are in one of two forms:
             # - ``model.layers.{k}.*`` if a rename rule matched (norms,
             # attn, mlp, experts, hc), or
-            # - ``layers.{k}.*`` if no rule matched — V4 MTP fusion-only
-            # modules (``eh_proj`` / ``enorm`` / ``hnorm`` /
-            # ``final_layernorm``) have no specific rename rule.
+            # - ``layers.{k}.*`` if no rule matched — V4 MTP-only
+            # modules (``e_proj`` / ``h_proj`` / ``enorm`` / ``hnorm`` /
+            # ``norm``) have no backbone rename rule.
             # Re-prefix both forms into the ``mtp.layers.{k}.*`` namespace.
             if key.startswith("model.layers."):
                 state_dict["mtp" + key[len("model") :]] = val
@@ -564,6 +581,7 @@ def _drop_hash_layer_gate_bias(self, state_dict: dict[str, Any], scope: "_HashBi
         # Reverse of the HC submodule renames above.
         (re.compile(r"^model\.layers\.(\d+)\.attn_hc\.(fn|base|scale)$"), r"layers.\1.hc_attn_\2"),
         (re.compile(r"^model\.layers\.(\d+)\.ffn_hc\.(fn|base|scale)$"), r"layers.\1.hc_ffn_\2"),
+        (re.compile(r"^model\.layers\.(\d+)\.hc_head\.hc_(fn|base|scale)$"), r"layers.\1.hc_head_\2"),
         (re.compile(r"^model\.hc_head\.hc_(fn|base|scale)$"), r"hc_head_\1"),
     ]
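A quick illustrative round trip showing that this reverse rule mirrors the forward HC-head rule added earlier in the commit (the key is made up):

    import re

    forward = (re.compile(r"^layers\.(\d+)\.hc_head_(fn|base|scale)$"), r"model.layers.\1.hc_head.hc_\2")
    reverse = (re.compile(r"^model\.layers\.(\d+)\.hc_head\.hc_(fn|base|scale)$"), r"layers.\1.hc_head_\2")

    key = "layers.3.hc_head_scale"
    internal = forward[0].sub(forward[1], key)
    assert reverse[0].sub(reverse[1], internal) == key
    print(internal)  # model.layers.3.hc_head.hc_scale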

@@ -575,21 +593,19 @@ def _internal_key_to_hf(self, key: str) -> str:
         return key

     def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[tuple[str, Any]]:
-        # MTP keys (``mtp.layers.{k}.*``) share the same on-disk layout as
-        # backbone layers ``layers.{N+k}.*``, so rewrite the fqn into the
-        # equivalent backbone-internal form and run it through the standard
-        # split / rename / quantize pipeline. This keeps expert splitting,
-        # FP4/FP8 quantization, and exclude-key filtering symmetric with the
-        # ``from_hf`` path instead of bypassing them.
-        mtp_hf_idx: int | None = None
+        # MTP keys (``mtp.layers.{k}.*``) share the same per-block layout as
+        # backbone layers, but current DSV4-Flash stores them under native
+        # ``mtp.{k}.*`` keys. Rewrite to an equivalent temporary
+        # ``model.layers.{k}.*`` form for splitting / renaming / quantization,
+        # then replace the emitted ``layers.{k}.`` prefix with ``mtp.{k}.``.
+        mtp_depth: int | None = None
         if fqn.startswith("mtp."):
-            N = self.config.num_hidden_layers
             rest = fqn[len("mtp.") :]
             m = re.match(r"^layers\.(\d+)\.", rest)
             if m is None:
                 return [(fqn, tensor)]
-            mtp_hf_idx = N + int(m.group(1))
-            fqn = f"model.layers.{mtp_hf_idx}." + rest[m.end() :]
+            mtp_depth = int(m.group(1))
+            fqn = f"model.layers.{mtp_depth}." + rest[m.end() :]

         quantization = kwargs.get("quantization", False)
         exclude_key_regex = kwargs.get("exclude_key_regex", None)
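A small sketch of the new fqn rewrite at the top of this method: the temporary index is now the MTP depth itself rather than num_hidden_layers + depth (the module name is hypothetical):

    import re

    fqn = "mtp.layers.1.enorm.weight"  # hypothetical internal MTP key
    rest = fqn[len("mtp."):]
    m = re.match(r"^layers\.(\d+)\.", rest)
    mtp_depth = int(m.group(1))        # 1
    fqn = f"model.layers.{mtp_depth}." + rest[m.end():]
    print(fqn)  # model.layers.1.enorm.weight (previously model.layers.{N+1}.enorm.weight)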
@@ -603,16 +619,24 @@ def convert_single_tensor_to_hf(self, fqn: str, tensor: Any, **kwargs) -> list[t
         # Rename internal keys to HF keys
         result = [(self._internal_key_to_hf(k), v) for k, v in result]

-        if mtp_hf_idx is not None:
-            # MTP fusion-only modules (``eh_proj`` / ``enorm`` / ``hnorm`` /
-            # ``final_layernorm``) have no entry in ``_INTERNAL_TO_HF_RENAMES``,
-            # so they leave ``_internal_key_to_hf`` still carrying the
-            # ``model.layers.{N+k}.*`` prefix. Strip ``model.`` here so they
-            # land at the HF-side ``layers.{N+k}.*`` like every other MTP key.
-            internal_prefix = f"model.layers.{mtp_hf_idx}."
-            hf_prefix = f"layers.{mtp_hf_idx}."
+        if mtp_depth is not None:
+            # MTP-only modules (``e_proj`` / ``h_proj`` / ``enorm`` /
+            # ``hnorm`` / ``norm``) have no generic backbone rename rule, so
+            # they can still carry ``model.layers.{k}.`` here. Normalize both
+            # possible temporary prefixes to the checkpoint's native MTP prefix.
+            internal_prefix = f"model.layers.{mtp_depth}."
+            layer_prefix = f"layers.{mtp_depth}."
+            mtp_prefix = f"mtp.{mtp_depth}."
             result = [
-                (hf_prefix + k[len(internal_prefix) :] if k.startswith(internal_prefix) else k, v) for k, v in result
+                (
+                    mtp_prefix + k[len(internal_prefix) :]
+                    if k.startswith(internal_prefix)
+                    else mtp_prefix + k[len(layer_prefix) :]
+                    if k.startswith(layer_prefix)
+                    else k,
+                    v,
+                )
+                for k, v in result
             ]

         if quantization:
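And a toy run of that final prefix normalization (keys and values are made up), showing how both possible temporary prefixes land under the checkpoint's native ``mtp.{k}.`` prefix:

    mtp_depth = 0
    internal_prefix = f"model.layers.{mtp_depth}."
    layer_prefix = f"layers.{mtp_depth}."
    mtp_prefix = f"mtp.{mtp_depth}."

    result = [
        ("model.layers.0.input_layernorm.weight", "renamed tensor"),  # matched a rename rule
        ("layers.0.enorm.weight", "mtp-only tensor"),                  # no backbone rename rule
    ]
    result = [
        (
            mtp_prefix + k[len(internal_prefix):]
            if k.startswith(internal_prefix)
            else mtp_prefix + k[len(layer_prefix):]
            if k.startswith(layer_prefix)
            else k,
            v,
        )
        for k, v in result
    ]
    print([k for k, _ in result])  # ['mtp.0.input_layernorm.weight', 'mtp.0.enorm.weight']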
