@@ -154,6 +154,13 @@ struct common_speculative_state {
154154
155155 virtual void accept (uint16_t n_accepted) = 0;
156156
157+ // Optional hook: invoked by the server after each successful llama_decode
158+ // on ctx_tgt. MTP uses it (only when is_prompt_prefill) to mirror the
159+ // ubatch into ctx_mtp's KV.
160+ virtual void on_target_decoded (const llama_batch & /* batch*/ ,
161+ llama_seq_id /* slot_seq_id*/ ,
162+ bool /* is_prompt_prefill*/ ) {}
163+
157164 virtual int32_t n_max (const common_params_speculative & params) const = 0;
158165 virtual int32_t n_min (const common_params_speculative & params) const = 0;
159166};
@@ -642,6 +649,15 @@ struct common_speculative_state_mtp : public common_speculative_state {
642649 // where ctx_tgt's t_h_pre_norm has only the prompt's last-position row.
643650 int32_t last_n_accepted = -1 ;
644651
652+ // No prompt-prefill accumulator: instead of harvesting trunk h rows into
653+ // a host buffer and replaying them in one big MTP decode at begin(), we
654+ // do an MTP ubatch decode FROM INSIDE on_target_decoded — i.e. each time
655+ // ctx_target finishes a ubatch, we immediately push those rows + tokens
656+ // through ctx_mtp. ctx_mtp's KV grows incrementally as the trunk's
657+ // prompt prefill progresses, so by the time begin() is called the MTP
658+ // KV already covers the full prompt, no matter how many ubatches it
659+ // took on the trunk side.
660+
645661 common_speculative_state_mtp (enum common_speculative_type type,
646662 llama_context * ctx_tgt,
647663 llama_context * ctx_mtp)
@@ -651,8 +667,11 @@ struct common_speculative_state_mtp : public common_speculative_state {
651667 const int32_t n_vocab = llama_vocab_n_tokens (llama_model_get_vocab (model_mtp));
652668 logits_buf.resize (n_vocab);
653669
654- // Single-token batches drive the MTP draft step.
655- batch = llama_batch_init (/* n_tokens=*/ 1 , /* n_embd=*/ 0 , /* n_seq_max=*/ 1 );
670+ // Sized to a full ctx_mtp ubatch: largest case is the prompt-prefill
671+ // mirror in on_target_decoded, which can run up to n_ubatch tokens
672+ // per chunk; per-step drafts only use 1.
673+ const int32_t n_batch_max = (int32_t ) llama_n_ubatch (ctx_mtp);
674+ batch = llama_batch_init (/* n_tokens=*/ n_batch_max, /* n_embd=*/ 0 , /* n_seq_max=*/ 1 );
656675
657676 // Warmup decode on ctx_mtp: builds the graph for real (not just reserve)
658677 // and populates ctx_mtp->gf_res_prev->t_inp_h so the relay function can
@@ -683,30 +702,44 @@ struct common_speculative_state_mtp : public common_speculative_state {
683702 }
684703
685704 void begin (const llama_tokens & prompt) override {
686- // Reset ctx_mtp's KV. Step 7 will replay the prompt here so MTP
687- // attention has full history before the first draft.
688- llama_memory_clear (llama_get_memory (ctx_mtp), /* data=*/ true );
689-
690- // Seed a single token at position 0 so the cache has a "last position"
691- // baseline. M-RoPE's X<Y check fires if a fresh batch tries to start
692- // at the same position the cache just saw, so the first real draft
693- // will land at position 1.
694- const llama_model * model = llama_get_model (ctx_mtp);
695- const llama_token bos = llama_vocab_bos (llama_model_get_vocab (model));
696- batch.n_tokens = 1 ;
697- batch.token [0 ] = bos;
698- batch.pos [0 ] = 0 ;
699- batch.n_seq_id [0 ] = 1 ;
700- batch.seq_id [0 ][0 ] = 0 ;
701- batch.logits [0 ] = 0 ; // we don't need logits from this seed decode
702- const int32_t rc = llama_decode (ctx_mtp, batch);
703- if (rc != 0 ) {
704- LOG_WRN (" %s: ctx_mtp seed decode rc=%d\n " , __func__, rc);
705- }
706- mtp_pos = 1 ;
707- last_n_accepted = -1 ; // signal "first draft of this generation"
705+ // ctx_mtp's KV has been incrementally populated by on_target_decoded
706+ // as the trunk processed each prompt-prefill ubatch. By the time
707+ // begin() is called, MTP KV covers positions 0..N-1 (matching the
708+ // trunk's prompt) — provided the server-side toggle and the
709+ // contiguous-rows precondition held. We just need to set up the
710+ // tracking state for the first draft.
711+ last_n_accepted = -1 ;
708712
709- GGML_UNUSED (prompt);
713+ const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
714+ const int32_t N = (int32_t ) prompt.size ();
715+ LOG_INF (" mtp begin: N=%d mtp_pos_max=%d (KV %s)\n " ,
716+ N, (int ) pos_max,
717+ (pos_max + 1 == N) ? " fully prefilled" :
718+ (pos_max < 0 ) ? " empty (no prefill)" : " partial" );
719+
720+ if (pos_max < 0 ) {
721+ // No prefill happened (e.g. server toggle off for non-MTP slot,
722+ // or contiguous-rows precondition failed). Seed BOS at position
723+ // 0 so the first draft has a non-empty KV to attend to. RoPE
724+ // will be misaligned with trunk for short prompts that's
725+ // tolerable; for long prompts the prefill path should always
726+ // win this race.
727+ const llama_model * model_mtp = llama_get_model (ctx_mtp);
728+ const llama_token bos = llama_vocab_bos (llama_model_get_vocab (model_mtp));
729+ batch.n_tokens = 1 ;
730+ batch.token [0 ] = bos;
731+ batch.pos [0 ] = 0 ;
732+ batch.n_seq_id [0 ] = 1 ;
733+ batch.seq_id [0 ][0 ] = 0 ;
734+ batch.logits [0 ] = 0 ;
735+ const int32_t rc = llama_decode (ctx_mtp, batch);
736+ if (rc != 0 ) {
737+ LOG_WRN (" %s: ctx_mtp seed decode rc=%d\n " , __func__, rc);
738+ }
739+ mtp_pos = 1 ;
740+ } else {
741+ mtp_pos = pos_max + 1 ;
742+ }
710743 }
711744
712745 void draft (
@@ -725,6 +758,10 @@ struct common_speculative_state_mtp : public common_speculative_state {
725758 const int32_t n_vocab = (int32_t ) logits_buf.size ();
726759 llama_token cond_tok = id_last;
727760
761+ const llama_pos pos_max_before = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
762+ LOG_INF (" mtp draft: id_last=%d n_max=%d last_n_accepted=%d mtp_pos_max=%d\n " ,
763+ (int ) id_last, (int ) n_max, (int ) last_n_accepted, (int ) pos_max_before);
764+
728765 for (int32_t k = 0 ; k < n_max; ++k) {
729766 // Stage h. Step 0: from ctx_tgt's t_h_pre_norm at the row whose
730767 // hidden produced id_last. After a previous verify [sampled, d0,
@@ -735,14 +772,16 @@ struct common_speculative_state_mtp : public common_speculative_state {
735772 // ctx_tgt only computed the prompt's last position → row 0.
736773 // Step k>0: self-relay from ctx_mtp's previous t_mtp_out.
737774 int32_t rc_relay;
775+ int32_t src_row_used = -1 ;
738776 if (k == 0 ) {
739- const int32_t src_row = (last_n_accepted < 0 ) ? 0 : last_n_accepted;
740- rc_relay = llama_mtp_relay_h (ctx_tgt, ctx_mtp, src_row , /* n_rows=*/ 1 );
777+ src_row_used = (last_n_accepted < 0 ) ? 0 : last_n_accepted;
778+ rc_relay = llama_mtp_relay_h (ctx_tgt, ctx_mtp, src_row_used , /* n_rows=*/ 1 );
741779 } else {
742780 rc_relay = llama_mtp_relay_h_self (ctx_mtp, /* n_rows=*/ 1 );
743781 }
744782 if (rc_relay != 0 ) {
745- LOG_DBG (" %s: relay rc=%d at k=%d; stopping chain\n " , __func__, rc_relay, k);
783+ LOG_WRN (" %s: relay rc=%d at k=%d (src_row=%d); stopping chain\n " ,
784+ __func__, rc_relay, k, src_row_used);
746785 return ;
747786 }
748787
@@ -775,6 +814,8 @@ struct common_speculative_state_mtp : public common_speculative_state {
775814 for (int i = 1 ; i < n_vocab; ++i) {
776815 if (logits_buf[i] > bv) { bv = logits_buf[i]; best = i; }
777816 }
817+ LOG_INF (" mtp draft k=%d pos=%d cond=%d -> %d (logit=%.2f)\n " ,
818+ (int ) k, (int ) pos, (int ) cond_tok, best, bv);
778819 draft_tokens.push_back (best);
779820 cond_tok = best;
780821 }
@@ -790,12 +831,14 @@ struct common_speculative_state_mtp : public common_speculative_state {
790831 // positions from ctx_mtp's KV so the next draft writes K/V at the
791832 // right slots.
792833 const llama_pos pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
834+ const int32_t n_drafted_last = (int32_t ) last_n_drafted;
835+ const int32_t n_to_drop = std::max (0 , n_drafted_last - (int32_t ) n_accepted);
836+ LOG_INF (" mtp accept: n_drafted=%d n_accepted=%d n_to_drop=%d mtp_pos_max=%d\n " ,
837+ n_drafted_last, (int ) n_accepted, n_to_drop, (int ) pos_max);
793838 if (pos_max < 0 ) {
794839 last_n_accepted = (int32_t ) n_accepted;
795840 return ;
796841 }
797- const int32_t n_drafted_last = (int32_t ) last_n_drafted;
798- const int32_t n_to_drop = std::max (0 , n_drafted_last - (int32_t ) n_accepted);
799842 if (n_to_drop > 0 ) {
800843 const llama_pos drop_from = pos_max - n_to_drop + 1 ;
801844 llama_memory_seq_rm (llama_get_memory (ctx_mtp), /* seq_id=*/ 0 ,
@@ -807,6 +850,116 @@ struct common_speculative_state_mtp : public common_speculative_state {
807850 last_n_accepted = (int32_t ) n_accepted;
808851 }
809852
853+ void on_target_decoded (const llama_batch & batch, llama_seq_id slot_seq_id, bool is_prompt_prefill) override {
854+ if (!is_prompt_prefill) {
855+ return ; // verify-batch decodes are owned by the draft path
856+ }
857+ // Mirror the trunk's just-finished ubatch into ctx_mtp by running one
858+ // MTP forward over the same positions. ctx_target's t_h_pre_norm
859+ // currently carries one row per output position of THIS ubatch (the
860+ // server toggles output=true on every prompt-prefill token for MTP
861+ // slots), and its data is still fresh — graph_compute_async finished
862+ // before this hook fires.
863+ //
864+ // Conditions for staging a real prefill MTP decode:
865+ // - we're in prompt prefill (not the verify decode that draft()
866+ // handles itself: skip if any slot tokens have logits=true,
867+ // since the verify batch always sets logits everywhere). We
868+ // detect this by checking that ALL of OUR slot's tokens carry
869+ // logits=true AND the trunk t_h_pre_norm has rows for all of
870+ // them — i.e. this is the prefill regime.
871+ // - the slot is single-seq (n_parallel=1 is enforced for MTP).
872+ //
873+ // For each token at trunk pos p in the slot, we feed (h_p, prompt[p])
874+ // to the MTP block at MTP pos p. This is a "no-shift" approximation
875+ // — MTP was trained on (h_p, x_{p+1}) → predict x_{p+2}, so feeding
876+ // (h_p, x_p) puts slightly off-distribution K/V into MTP's KV, but
877+ // the K/V values are at the right positions for attention. The
878+ // alternative (proper shift) requires looking ahead to the next
879+ // ubatch's first token, which we don't have here.
880+ if (batch.n_tokens <= 0 ) {
881+ return ;
882+ }
883+ ggml_tensor * h = llama_context_get_t_h_pre_norm (ctx_tgt);
884+ if (!h) {
885+ return ; // trunk didn't produce t_h_pre_norm this decode
886+ }
887+ const int64_t n_rows = h->ne [1 ];
888+ if (n_rows < batch.n_tokens ) {
889+ return ; // not all positions have output rows; can't safely match
890+ }
891+
892+ // Filter tokens belonging to this slot, preserving batch order.
893+ // For n_parallel=1 every token belongs to the slot; the filter is a
894+ // no-op there.
895+ struct entry { int batch_idx; int row_idx; };
896+ std::vector<entry> mine;
897+ mine.reserve (batch.n_tokens );
898+ int row_idx = -1 ;
899+ for (int i = 0 ; i < batch.n_tokens ; ++i) {
900+ const bool has_out = batch.logits && batch.logits [i];
901+ if (has_out) row_idx++;
902+ bool is_mine = false ;
903+ if (batch.n_seq_id && batch.n_seq_id [i] > 0 && batch.seq_id ) {
904+ for (int j = 0 ; j < batch.n_seq_id [i]; ++j) {
905+ if (batch.seq_id [i][j] == slot_seq_id) { is_mine = true ; break ; }
906+ }
907+ }
908+ if (is_mine && has_out && row_idx >= 0 && row_idx < n_rows) {
909+ mine.push_back ({i, row_idx});
910+ }
911+ }
912+ if (mine.empty ()) {
913+ return ;
914+ }
915+ // Heuristic: only run prefill if the rows in t_h_pre_norm are
916+ // contiguous starting at 0 (they will be when our slot's tokens are
917+ // the only ones with output=true). Otherwise we'd need to gather
918+ // non-contiguous rows — skip rather than risk wrong h.
919+ for (size_t k = 0 ; k < mine.size (); ++k) {
920+ if (mine[k].row_idx != (int ) k) {
921+ LOG_INF (" mtp prefill skip: non-contiguous rows (slot=%d)\n " , (int ) slot_seq_id);
922+ return ;
923+ }
924+ }
925+
926+ const int n = (int ) mine.size ();
927+ // Run MTP forwards in chunks of at most n_ubatch tokens — single
928+ // huge MTP forwards (e.g. 1500-token prompts) exceed compute scratch
929+ // and crash in ggml. The KV result is identical regardless of split,
930+ // since each chunk attends to all earlier MTP KV positions.
931+ const int chunk_max = (int ) llama_n_ubatch (ctx_mtp);
932+ for (int off = 0 ; off < n; off += chunk_max) {
933+ const int n_chunk = std::min (chunk_max, n - off);
934+
935+ this ->batch .n_tokens = n_chunk;
936+ for (int k = 0 ; k < n_chunk; ++k) {
937+ const int bi = mine[off + k].batch_idx ;
938+ this ->batch .token [k] = batch.token [bi];
939+ this ->batch .pos [k] = batch.pos ? batch.pos [bi] : (off + k);
940+ this ->batch .n_seq_id [k] = 1 ;
941+ this ->batch .seq_id [k][0 ] = 0 ;
942+ this ->batch .logits [k] = 0 ;
943+ }
944+ const int32_t rc_relay = llama_mtp_relay_h (ctx_tgt, ctx_mtp,
945+ /* src_row=*/ off, /* n_rows=*/ n_chunk);
946+ if (rc_relay != 0 ) {
947+ LOG_WRN (" mtp prefill: relay rc=%d (chunk_off=%d, n=%d, slot=%d)\n " ,
948+ rc_relay, off, n_chunk, (int ) slot_seq_id);
949+ return ;
950+ }
951+ const int32_t rc = llama_decode (ctx_mtp, this ->batch );
952+ if (rc != 0 ) {
953+ LOG_WRN (" mtp prefill: decode rc=%d (chunk_off=%d, n=%d, slot=%d)\n " ,
954+ rc, off, n_chunk, (int ) slot_seq_id);
955+ return ;
956+ }
957+ }
958+ const llama_pos new_pos_max = llama_memory_seq_pos_max (llama_get_memory (ctx_mtp), 0 );
959+ LOG_INF (" mtp prefill: slot=%d n=%d chunks=%d mtp_pos_max=%d\n " ,
960+ (int ) slot_seq_id, n, (n + chunk_max - 1 ) / chunk_max, (int ) new_pos_max);
961+ }
962+
810963 int32_t n_max (const common_params_speculative & params) const override {
811964 return std::max (1 , params.draft .n_max );
812965 }
@@ -1423,6 +1576,19 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
14231576 }
14241577}
14251578
1579+ void common_speculative_on_target_decoded (
1580+ common_speculative * spec,
1581+ const llama_batch & batch,
1582+ llama_seq_id slot_seq_id,
1583+ bool is_prompt_prefill) {
1584+ if (!spec) {
1585+ return ;
1586+ }
1587+ for (auto & impl : spec->impls ) {
1588+ impl->on_target_decoded (batch, slot_seq_id, is_prompt_prefill);
1589+ }
1590+ }
1591+
14261592int32_t common_speculative_n_max (const common_speculative * spec, const common_params_speculative & params) {
14271593 if (spec == nullptr ) {
14281594 return 0 ;
0 commit comments