@@ -549,10 +549,19 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
549549
550550 const size_t row_bytes = (size_t ) n_embd * sizeof (float );
551551
552+ // Stacked MTP needs each block's KV pre-populated with the chain
553+ // context, not just MTP1. For multi-block archs we request logits at
554+ // every batch position so the masked-mode hidden-state extraction in
555+ // ctx_dft captures every row (we feed those rows into the next chain
556+ // step). The extra LM-head matmuls only fire during prefill / verify,
557+ // not during drafting, so the cost is bounded.
558+ const bool chain_prefill = (n_mtp_layers > 1 );
559+ const int8_t want_logits = chain_prefill ? 1 : 0 ;
560+
552561 common_batch_clear (batch);
553562
554563 for (int k = 0 ; k < n_tokens; ++k) {
555- common_batch_add (batch, batch_in.token [k], batch_in.pos [k], { batch_in.seq_id [k][0 ] }, 0 );
564+ common_batch_add (batch, batch_in.token [k], batch_in.pos [k], { batch_in.seq_id [k][0 ] }, want_logits );
556565 }
557566
558567 // shift the tgt embeddings to the right by one position
@@ -587,12 +596,97 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
587596 set_h (i_batch_beg[seq_id], pending_h[seq_id].data ());
588597 }
589598
599+ llama_set_mtp_step (ctx_dft, 0 );
590600 const int32_t rc = llama_decode (ctx_dft, batch);
591601 if (rc != 0 ) {
592602 LOG_ERR (" %s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n " , __func__, (int ) rc, (int ) batch_in.pos [0 ]);
593603 return false ;
594604 }
595605
606+ // For stacked MTP, run the remaining chain steps so blocks 2..N also
607+ // get their KV slots filled with the verified context. Each step uses:
608+ // token[i] = batch_in.token[i+1] (the next trunk token)
609+ // pos [i] = batch_in.pos [i+1] (the next trunk position)
610+ // embd [i] = prev block's hidden at original batch index i
611+ // Each step drops the leading position of every sequence (no prior
612+ // chain hidden available there), so the deeper blocks lose a few
613+ // entries at sequence starts — negligible for prompts of meaningful
614+ // length, and correct: those positions never come up during drafting.
615+ if (chain_prefill) {
616+ // Snapshot MTP_{k-1}'s pre-norm hiddens, indexed by the original
617+ // batch_in position so we can chain across the per-seq remap.
618+ std::vector<float > prev_hiddens ((size_t ) n_tokens * n_embd);
619+ {
620+ const float * h_dft = llama_get_embeddings_pre_norm (ctx_dft);
621+ if (h_dft != nullptr ) {
622+ std::memcpy (prev_hiddens.data (), h_dft, (size_t ) n_tokens * row_bytes);
623+ } else {
624+ LOG_WRN (" %s: chain prefill skipped (ctx_dft pre-norm embeddings unavailable)\n " , __func__);
625+ return true ;
626+ }
627+ }
628+
629+ for (int step = 1 ; step < n_mtp_layers; ++step) {
630+ common_batch_clear (batch);
631+
632+ // Maps step-batch index -> original batch_in index so the
633+ // next iteration can pick this step's hidden by absolute pos.
634+ std::vector<int32_t > step_idx_to_in (n_tokens, -1 );
635+
636+ for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
637+ const int32_t beg = i_batch_beg[seq_id];
638+ const int32_t end = i_batch_end[seq_id];
639+ if (beg < 0 || end < 0 ) {
640+ continue ;
641+ }
642+ if (end - beg < step) {
643+ continue ; // this seq is shorter than the chain depth so far
644+ }
645+
646+ for (int32_t i = beg + step; i <= end; ++i) {
647+ const int32_t prev_idx = i - 1 ; // MTP_{step-1}'s hidden at the source position
648+ const int32_t step_i = batch.n_tokens ;
649+
650+ common_batch_add (batch, batch_in.token [i], batch_in.pos [i], { seq_id }, 1 );
651+ std::memcpy (batch.embd + (size_t ) step_i * n_embd,
652+ prev_hiddens.data () + (size_t ) prev_idx * n_embd,
653+ row_bytes);
654+
655+ step_idx_to_in[step_i] = i;
656+ }
657+ }
658+
659+ if (batch.n_tokens == 0 ) {
660+ break ;
661+ }
662+
663+ llama_set_mtp_step (ctx_dft, (uint32_t ) step);
664+ const int32_t rc_step = llama_decode (ctx_dft, batch);
665+ if (rc_step != 0 ) {
666+ LOG_WRN (" %s: chain prefill step %d llama_decode failed rc=%d\n " , __func__, step, rc_step);
667+ break ;
668+ }
669+
670+ // Gather this step's hidden states for the next iteration,
671+ // remapped to the original batch_in indexing.
672+ if (step + 1 < n_mtp_layers) {
673+ std::vector<float > next_prev ((size_t ) n_tokens * n_embd, 0 .0f );
674+ for (int32_t step_i = 0 ; step_i < (int32_t ) batch.n_tokens ; ++step_i) {
675+ const int32_t in_i = step_idx_to_in[step_i];
676+ if (in_i < 0 ) {
677+ continue ;
678+ }
679+ const float * h = llama_get_embeddings_pre_norm_ith (ctx_dft, step_i);
680+ std::memcpy (next_prev.data () + (size_t ) in_i * n_embd, h, row_bytes);
681+ }
682+ prev_hiddens = std::move (next_prev);
683+ }
684+ }
685+
686+ // Reset so subsequent draft() starts the chain at block 0.
687+ llama_set_mtp_step (ctx_dft, 0 );
688+ }
689+
596690 for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
597691 if (i_batch_end[seq_id] < 0 ) {
598692 continue ;
0 commit comments