
Commit 37744cb

feat(deepseek-v4): support HC-backed MTP training
1 parent: c228ec4

3 files changed, with 245 additions and 176 deletions (commit total)

File tree

nemo_automodel/components/models/deepseek_v4/model.py

Lines changed in this file: 20 additions & 5 deletions
@@ -384,8 +384,9 @@ def forward(
         position_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         padding_mask: torch.Tensor | None = None,
+        return_hc_hidden: bool = False,
         **attn_kwargs: Any,
-    ) -> torch.Tensor:
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         # PP-aware forward (same pattern as DeepseekV3Model.forward).
         # Stage 0 of pipeline parallelism owns ``embed_tokens`` and receives
         # raw token ids; subsequent stages have ``embed_tokens=None`` and
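
Because the return annotation is now a union, callers get either a plain tensor or a (reduced_hidden, hc_stream) pair depending on the flag. A small sketch of one way to keep that union type-safe with typing.overload follows; the class and parameter names are illustrative only, not the model's actual API.

from typing import Literal, overload

import torch


class HCBackboneSketch:
    """Illustrative stand-in showing only the return_hc_hidden contract."""

    @overload
    def forward(self, h: torch.Tensor, return_hc_hidden: Literal[False] = ...) -> torch.Tensor: ...

    @overload
    def forward(self, h: torch.Tensor, return_hc_hidden: Literal[True]) -> tuple[torch.Tensor, torch.Tensor]: ...

    def forward(self, h, return_hc_hidden=False):
        out = h  # placeholder for the real per-layer computation
        # Return the collapsed output alone, or together with the raw HC stream.
        return (out, h) if return_hc_hidden else out
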
@@ -460,6 +461,8 @@ def forward(
             **attn_kwargs,
         )
 
+        mtp_hc_hidden = h if return_hc_hidden else None
+
         # Reduce hc_mult copies -> [B,S,dim] via the learned HC head, then
         # apply the shared RMSNorm. Both modules live ONLY on the last PP
         # stage (intermediate stages keep h at 4D so the next stage can
@@ -468,6 +471,10 @@ def forward(
             h = self.hc_head(h)
         if getattr(self, "norm", None) is not None:
             h = self.norm(h)
+        if return_hc_hidden:
+            if mtp_hc_hidden is None:
+                raise ValueError("return_hc_hidden requested before HC stream was available")
+            return h, mtp_hc_hidden
         return h
 
     def update_moe_gate_bias(self) -> None:
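
For readers following the shapes: the 4D HC stream is captured before hc_head collapses it, so the MTP branch can see all hc_mult copies while lm_head still receives the reduced [B, S, hidden] tensor. A minimal shape sketch follows; the Linear reducer and LayerNorm are stand-ins for the learned hc_head and the shared RMSNorm, and the sizes are made up.

import torch
import torch.nn as nn

B, S, HC_MULT, HIDDEN = 2, 16, 4, 64

hc_head = nn.Linear(HC_MULT * HIDDEN, HIDDEN)   # stand-in reducer: hc_mult copies -> 1
norm = nn.LayerNorm(HIDDEN)                     # stand-in for the shared RMSNorm

h = torch.randn(B, S, HC_MULT, HIDDEN)          # HC stream after the decoder layers
mtp_hc_hidden = h                               # captured BEFORE the reduction (what MTP sees)

# Collapse the hc_mult axis, then normalize: this is what lm_head consumes.
h_reduced = norm(hc_head(h.reshape(B, S, HC_MULT * HIDDEN)))

assert h_reduced.shape == (B, S, HIDDEN)
assert mtp_hc_hidden.shape == (B, S, HC_MULT, HIDDEN)
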
@@ -608,20 +615,28 @@ def forward(
         )
         attention_mask = None
 
-        hidden_states = self.model(
+        use_mtp = self.mtp is not None and self.training
+        model_out = self.model(
             input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
             padding_mask=padding_mask,
+            return_hc_hidden=use_mtp,
             **attn_kwargs,
         )
+        if use_mtp:
+            hidden_states, mtp_hc_hidden = model_out
+        else:
+            hidden_states = model_out
+            mtp_hc_hidden = None
         logits = self.lm_head(hidden_states) if self.lm_head else hidden_states
         if thd_mode:
             logits = logits.unsqueeze(0)
 
         mtp_per_depth_h = None
-        if self.mtp is not None and self.training:
-            # hidden_states is [B, S, hidden] after hc_head+norm — correct MTP input.
+        if use_mtp:
+            # MTP consumes the pre-final-head HC stream [B, S, hc_mult, hidden]
+            # and returns collapsed per-depth [B, S, hidden] tensors for CE.
             seq_len = hidden_states.shape[1]
             batch_size = hidden_states.shape[0]
             if position_ids is None:
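
The new comment says MTP returns collapsed per-depth [B, S, hidden] tensors for CE. As a hedged illustration of how such per-depth tensors are commonly turned into auxiliary losses in DeepSeek-style MTP, a sketch follows; the shared lm_head, the labels layout, and the shift logic are assumptions, not this repo's loss code.

import torch
import torch.nn.functional as F


def mtp_ce_loss(per_depth_h: list[torch.Tensor], lm_head: torch.nn.Module,
                labels: torch.Tensor) -> torch.Tensor:
    # per_depth_h[k]: [B, S, hidden]; labels: [B, S] token ids aligned with positions.
    # Under this assumed layout, depth k predicts the token k+1 steps ahead.
    losses = []
    for k, h_k in enumerate(per_depth_h):
        logits = lm_head(h_k)                  # [B, S, vocab]
        shift = k + 1
        pred = logits[:, :-shift]              # positions that still have a target
        tgt = labels[:, shift:]                # targets k+1 steps ahead
        losses.append(F.cross_entropy(pred.reshape(-1, pred.size(-1)), tgt.reshape(-1)))
    return torch.stack(losses).mean()
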
@@ -637,7 +652,7 @@ def forward(
             )
             mtp_per_depth_h = self.mtp(
                 input_ids=input_ids,
-                hidden_states=hidden_states,
+                hidden_states=mtp_hc_hidden,
                 embed_fn=self.model.embed_tokens,
                 position_ids=position_ids,
                 attention_mask=mtp_attn_mask,
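
Interface-wise, the only call-site change is hidden_states=mtp_hc_hidden, so the MTP module now receives the 4D HC stream instead of the collapsed hidden states. The stub below restates that contract using only the keyword names visible in this hunk; its body is a placeholder, not the real MTP implementation, and the return container (a list, one tensor per depth) is an assumption.

import torch
import torch.nn as nn


class MTPStub(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,                 # [B, S]
        hidden_states: torch.Tensor,             # [B, S, hc_mult, hidden] (mtp_hc_hidden)
        embed_fn: nn.Module,                     # shared token embedding (self.model.embed_tokens)
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
    ) -> list[torch.Tensor]:
        # Placeholder: a real MTP head would run its own depth-wise blocks and
        # collapse the HC axis; here we just average it to show the expected
        # output shape, one [B, S, hidden] tensor per depth.
        return [hidden_states.mean(dim=2)]
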
