@@ -384,8 +384,9 @@ def forward(
         position_ids: torch.Tensor | None = None,
         attention_mask: torch.Tensor | None = None,
         padding_mask: torch.Tensor | None = None,
+        return_hc_hidden: bool = False,
         **attn_kwargs: Any,
-    ) -> torch.Tensor:
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         # PP-aware forward (same pattern as DeepseekV3Model.forward).
         # Stage 0 of pipeline parallelism owns ``embed_tokens`` and receives
         # raw token ids; subsequent stages have ``embed_tokens=None`` and
@@ -460,6 +461,8 @@ def forward(
                 **attn_kwargs,
             )
 
+        mtp_hc_hidden = h if return_hc_hidden else None
+
         # Reduce hc_mult copies -> [B,S,dim] via the learned HC head, then
         # apply the shared RMSNorm. Both modules live ONLY on the last PP
         # stage (intermediate stages keep h at 4D so the next stage can
@@ -468,6 +471,10 @@ def forward(
             h = self.hc_head(h)
         if getattr(self, "norm", None) is not None:
             h = self.norm(h)
+        if return_hc_hidden:
+            if mtp_hc_hidden is None:
+                raise ValueError("return_hc_hidden requested before HC stream was available")
+            return h, mtp_hc_hidden
         return h
 
     def update_moe_gate_bias(self) -> None:
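
As a reading aid, here is a minimal shape-level sketch of the contract the three hunks above establish on the last PP stage. The hc_head and norm below are hypothetical stand-ins (a plain Linear over the flattened copies and torch.nn.RMSNorm), not the repository's modules:

import torch

B, S, hc_mult, dim = 2, 8, 4, 16
h = torch.randn(B, S, hc_mult, dim)                      # HC stream leaving the decoder stack
hc_head = torch.nn.Linear(hc_mult * dim, dim)            # stand-in for the learned HC head
norm = torch.nn.RMSNorm(dim)                             # stand-in for the shared RMSNorm (PyTorch >= 2.4)

mtp_hc_hidden = h                                        # captured while still 4D, before the HC head
hidden = norm(hc_head(h.reshape(B, S, hc_mult * dim)))   # collapsed to [B, S, dim]
result = (hidden, mtp_hc_hidden)                         # what return_hc_hidden=True yields on the last stage
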
@@ -608,20 +615,28 @@ def forward(
             )
             attention_mask = None
 
-        hidden_states = self.model(
+        use_mtp = self.mtp is not None and self.training
+        model_out = self.model(
             input_ids,
             position_ids=position_ids,
             attention_mask=attention_mask,
             padding_mask=padding_mask,
+            return_hc_hidden=use_mtp,
             **attn_kwargs,
         )
+        if use_mtp:
+            hidden_states, mtp_hc_hidden = model_out
+        else:
+            hidden_states = model_out
+            mtp_hc_hidden = None
         logits = self.lm_head(hidden_states) if self.lm_head else hidden_states
         if thd_mode:
             logits = logits.unsqueeze(0)
 
         mtp_per_depth_h = None
-        if self.mtp is not None and self.training:
-            # hidden_states is [B, S, hidden] after hc_head+norm — correct MTP input.
+        if use_mtp:
+            # MTP consumes the pre-final-head HC stream [B, S, hc_mult, hidden]
+            # and returns collapsed per-depth [B, S, hidden] tensors for CE.
             seq_len = hidden_states.shape[1]
             batch_size = hidden_states.shape[0]
             if position_ids is None:
@@ -637,7 +652,7 @@ def forward(
             )
             mtp_per_depth_h = self.mtp(
                 input_ids=input_ids,
-                hidden_states=hidden_states,
+                hidden_states=mtp_hc_hidden,
                 embed_fn=self.model.embed_tokens,
                 position_ids=position_ids,
                 attention_mask=mtp_attn_mask,
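
Putting the last two hunks together, a condensed sketch of the caller-side routing with the collaborators passed in explicitly; the mtp_attn_mask and position-id construction that the diff elides is taken as given here:

def lm_forward(model, lm_head, mtp, input_ids, position_ids, mtp_attn_mask, training, **attn_kwargs):
    # Mirrors the diff: unpack the tuple only when MTP is active during training.
    use_mtp = mtp is not None and training
    model_out = model(input_ids, position_ids=position_ids, return_hc_hidden=use_mtp, **attn_kwargs)
    if use_mtp:
        hidden_states, mtp_hc_hidden = model_out          # [B, S, dim], [B, S, hc_mult, dim]
    else:
        hidden_states, mtp_hc_hidden = model_out, None
    logits = lm_head(hidden_states) if lm_head else hidden_states
    mtp_per_depth_h = None
    if use_mtp:
        mtp_per_depth_h = mtp(
            input_ids=input_ids,
            hidden_states=mtp_hc_hidden,                  # pre-head HC stream, not the collapsed hidden_states
            embed_fn=model.embed_tokens,
            position_ids=position_ids,
            attention_mask=mtp_attn_mask,
        )
    return logits, mtp_per_depth_h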