Skip to content

Commit 293517a

Browse files
committed
prompt prefill fix
1 parent b9be095 commit 293517a

7 files changed

Lines changed: 298 additions & 89 deletions

File tree

common/speculative.cpp

Lines changed: 196 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ struct common_speculative_state {
154154

155155
virtual void accept(uint16_t n_accepted) = 0;
156156

157+
// Optional hook: invoked by the server after each successful llama_decode
158+
// on ctx_tgt. MTP uses it (only when is_prompt_prefill) to mirror the
159+
// ubatch into ctx_mtp's KV.
160+
virtual void on_target_decoded(const llama_batch & /*batch*/,
161+
llama_seq_id /*slot_seq_id*/,
162+
bool /*is_prompt_prefill*/) {}
163+
157164
virtual int32_t n_max(const common_params_speculative & params) const = 0;
158165
virtual int32_t n_min(const common_params_speculative & params) const = 0;
159166
};
@@ -642,6 +649,15 @@ struct common_speculative_state_mtp : public common_speculative_state {
642649
// where ctx_tgt's t_h_pre_norm has only the prompt's last-position row.
643650
int32_t last_n_accepted = -1;
644651

652+
// No prompt-prefill accumulator: instead of harvesting trunk h rows into
653+
// a host buffer and replaying them in one big MTP decode at begin(), we
654+
// do an MTP ubatch decode FROM INSIDE on_target_decoded — i.e. each time
655+
// ctx_target finishes a ubatch, we immediately push those rows + tokens
656+
// through ctx_mtp. ctx_mtp's KV grows incrementally as the trunk's
657+
// prompt prefill progresses, so by the time begin() is called the MTP
658+
// KV already covers the full prompt, no matter how many ubatches it
659+
// took on the trunk side.
660+
645661
common_speculative_state_mtp(enum common_speculative_type type,
646662
llama_context * ctx_tgt,
647663
llama_context * ctx_mtp)
@@ -651,8 +667,11 @@ struct common_speculative_state_mtp : public common_speculative_state {
651667
const int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model_mtp));
652668
logits_buf.resize(n_vocab);
653669

654-
// Single-token batches drive the MTP draft step.
655-
batch = llama_batch_init(/*n_tokens=*/ 1, /*n_embd=*/ 0, /*n_seq_max=*/ 1);
670+
// Sized to a full ctx_mtp ubatch: largest case is the prompt-prefill
671+
// mirror in on_target_decoded, which can run up to n_ubatch tokens
672+
// per chunk; per-step drafts only use 1.
673+
const int32_t n_batch_max = (int32_t) llama_n_ubatch(ctx_mtp);
674+
batch = llama_batch_init(/*n_tokens=*/ n_batch_max, /*n_embd=*/ 0, /*n_seq_max=*/ 1);
656675

657676
// Warmup decode on ctx_mtp: builds the graph for real (not just reserve)
658677
// and populates ctx_mtp->gf_res_prev->t_inp_h so the relay function can
@@ -683,30 +702,44 @@ struct common_speculative_state_mtp : public common_speculative_state {
683702
}
684703

685704
void begin(const llama_tokens & prompt) override {
686-
// Reset ctx_mtp's KV. Step 7 will replay the prompt here so MTP
687-
// attention has full history before the first draft.
688-
llama_memory_clear(llama_get_memory(ctx_mtp), /*data=*/ true);
689-
690-
// Seed a single token at position 0 so the cache has a "last position"
691-
// baseline. M-RoPE's X<Y check fires if a fresh batch tries to start
692-
// at the same position the cache just saw, so the first real draft
693-
// will land at position 1.
694-
const llama_model * model = llama_get_model(ctx_mtp);
695-
const llama_token bos = llama_vocab_bos(llama_model_get_vocab(model));
696-
batch.n_tokens = 1;
697-
batch.token[0] = bos;
698-
batch.pos[0] = 0;
699-
batch.n_seq_id[0] = 1;
700-
batch.seq_id[0][0] = 0;
701-
batch.logits[0] = 0; // we don't need logits from this seed decode
702-
const int32_t rc = llama_decode(ctx_mtp, batch);
703-
if (rc != 0) {
704-
LOG_WRN("%s: ctx_mtp seed decode rc=%d\n", __func__, rc);
705-
}
706-
mtp_pos = 1;
707-
last_n_accepted = -1; // signal "first draft of this generation"
705+
// ctx_mtp's KV has been incrementally populated by on_target_decoded
706+
// as the trunk processed each prompt-prefill ubatch. By the time
707+
// begin() is called, MTP KV covers positions 0..N-1 (matching the
708+
// trunk's prompt) — provided the server-side toggle and the
709+
// contiguous-rows precondition held. We just need to set up the
710+
// tracking state for the first draft.
711+
last_n_accepted = -1;
708712

709-
GGML_UNUSED(prompt);
713+
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
714+
const int32_t N = (int32_t) prompt.size();
715+
LOG_INF("mtp begin: N=%d mtp_pos_max=%d (KV %s)\n",
716+
N, (int) pos_max,
717+
(pos_max + 1 == N) ? "fully prefilled" :
718+
(pos_max < 0) ? "empty (no prefill)" : "partial");
719+
720+
if (pos_max < 0) {
721+
// No prefill happened (e.g. server toggle off for non-MTP slot,
722+
// or contiguous-rows precondition failed). Seed BOS at position
723+
// 0 so the first draft has a non-empty KV to attend to. RoPE
724+
// will be misaligned with trunk for short prompts that's
725+
// tolerable; for long prompts the prefill path should always
726+
// win this race.
727+
const llama_model * model_mtp = llama_get_model(ctx_mtp);
728+
const llama_token bos = llama_vocab_bos(llama_model_get_vocab(model_mtp));
729+
batch.n_tokens = 1;
730+
batch.token[0] = bos;
731+
batch.pos[0] = 0;
732+
batch.n_seq_id[0] = 1;
733+
batch.seq_id[0][0] = 0;
734+
batch.logits[0] = 0;
735+
const int32_t rc = llama_decode(ctx_mtp, batch);
736+
if (rc != 0) {
737+
LOG_WRN("%s: ctx_mtp seed decode rc=%d\n", __func__, rc);
738+
}
739+
mtp_pos = 1;
740+
} else {
741+
mtp_pos = pos_max + 1;
742+
}
710743
}
711744

712745
void draft(
@@ -725,6 +758,10 @@ struct common_speculative_state_mtp : public common_speculative_state {
725758
const int32_t n_vocab = (int32_t) logits_buf.size();
726759
llama_token cond_tok = id_last;
727760

761+
const llama_pos pos_max_before = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
762+
LOG_INF("mtp draft: id_last=%d n_max=%d last_n_accepted=%d mtp_pos_max=%d\n",
763+
(int) id_last, (int) n_max, (int) last_n_accepted, (int) pos_max_before);
764+
728765
for (int32_t k = 0; k < n_max; ++k) {
729766
// Stage h. Step 0: from ctx_tgt's t_h_pre_norm at the row whose
730767
// hidden produced id_last. After a previous verify [sampled, d0,
@@ -735,14 +772,16 @@ struct common_speculative_state_mtp : public common_speculative_state {
735772
// ctx_tgt only computed the prompt's last position → row 0.
736773
// Step k>0: self-relay from ctx_mtp's previous t_mtp_out.
737774
int32_t rc_relay;
775+
int32_t src_row_used = -1;
738776
if (k == 0) {
739-
const int32_t src_row = (last_n_accepted < 0) ? 0 : last_n_accepted;
740-
rc_relay = llama_mtp_relay_h(ctx_tgt, ctx_mtp, src_row, /*n_rows=*/ 1);
777+
src_row_used = (last_n_accepted < 0) ? 0 : last_n_accepted;
778+
rc_relay = llama_mtp_relay_h(ctx_tgt, ctx_mtp, src_row_used, /*n_rows=*/ 1);
741779
} else {
742780
rc_relay = llama_mtp_relay_h_self(ctx_mtp, /*n_rows=*/ 1);
743781
}
744782
if (rc_relay != 0) {
745-
LOG_DBG("%s: relay rc=%d at k=%d; stopping chain\n", __func__, rc_relay, k);
783+
LOG_WRN("%s: relay rc=%d at k=%d (src_row=%d); stopping chain\n",
784+
__func__, rc_relay, k, src_row_used);
746785
return;
747786
}
748787

@@ -775,6 +814,8 @@ struct common_speculative_state_mtp : public common_speculative_state {
775814
for (int i = 1; i < n_vocab; ++i) {
776815
if (logits_buf[i] > bv) { bv = logits_buf[i]; best = i; }
777816
}
817+
LOG_INF("mtp draft k=%d pos=%d cond=%d -> %d (logit=%.2f)\n",
818+
(int) k, (int) pos, (int) cond_tok, best, bv);
778819
draft_tokens.push_back(best);
779820
cond_tok = best;
780821
}
@@ -790,12 +831,14 @@ struct common_speculative_state_mtp : public common_speculative_state {
790831
// positions from ctx_mtp's KV so the next draft writes K/V at the
791832
// right slots.
792833
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
834+
const int32_t n_drafted_last = (int32_t) last_n_drafted;
835+
const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted);
836+
LOG_INF("mtp accept: n_drafted=%d n_accepted=%d n_to_drop=%d mtp_pos_max=%d\n",
837+
n_drafted_last, (int) n_accepted, n_to_drop, (int) pos_max);
793838
if (pos_max < 0) {
794839
last_n_accepted = (int32_t) n_accepted;
795840
return;
796841
}
797-
const int32_t n_drafted_last = (int32_t) last_n_drafted;
798-
const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted);
799842
if (n_to_drop > 0) {
800843
const llama_pos drop_from = pos_max - n_to_drop + 1;
801844
llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
@@ -807,6 +850,116 @@ struct common_speculative_state_mtp : public common_speculative_state {
807850
last_n_accepted = (int32_t) n_accepted;
808851
}
809852

853+
void on_target_decoded(const llama_batch & batch, llama_seq_id slot_seq_id, bool is_prompt_prefill) override {
854+
if (!is_prompt_prefill) {
855+
return; // verify-batch decodes are owned by the draft path
856+
}
857+
// Mirror the trunk's just-finished ubatch into ctx_mtp by running one
858+
// MTP forward over the same positions. ctx_target's t_h_pre_norm
859+
// currently carries one row per output position of THIS ubatch (the
860+
// server toggles output=true on every prompt-prefill token for MTP
861+
// slots), and its data is still fresh — graph_compute_async finished
862+
// before this hook fires.
863+
//
864+
// Conditions for staging a real prefill MTP decode:
865+
// - we're in prompt prefill (not the verify decode that draft()
866+
// handles itself: skip if any slot tokens have logits=true,
867+
// since the verify batch always sets logits everywhere). We
868+
// detect this by checking that ALL of OUR slot's tokens carry
869+
// logits=true AND the trunk t_h_pre_norm has rows for all of
870+
// them — i.e. this is the prefill regime.
871+
// - the slot is single-seq (n_parallel=1 is enforced for MTP).
872+
//
873+
// For each token at trunk pos p in the slot, we feed (h_p, prompt[p])
874+
// to the MTP block at MTP pos p. This is a "no-shift" approximation
875+
// — MTP was trained on (h_p, x_{p+1}) → predict x_{p+2}, so feeding
876+
// (h_p, x_p) puts slightly off-distribution K/V into MTP's KV, but
877+
// the K/V values are at the right positions for attention. The
878+
// alternative (proper shift) requires looking ahead to the next
879+
// ubatch's first token, which we don't have here.
880+
if (batch.n_tokens <= 0) {
881+
return;
882+
}
883+
ggml_tensor * h = llama_context_get_t_h_pre_norm(ctx_tgt);
884+
if (!h) {
885+
return; // trunk didn't produce t_h_pre_norm this decode
886+
}
887+
const int64_t n_rows = h->ne[1];
888+
if (n_rows < batch.n_tokens) {
889+
return; // not all positions have output rows; can't safely match
890+
}
891+
892+
// Filter tokens belonging to this slot, preserving batch order.
893+
// For n_parallel=1 every token belongs to the slot; the filter is a
894+
// no-op there.
895+
struct entry { int batch_idx; int row_idx; };
896+
std::vector<entry> mine;
897+
mine.reserve(batch.n_tokens);
898+
int row_idx = -1;
899+
for (int i = 0; i < batch.n_tokens; ++i) {
900+
const bool has_out = batch.logits && batch.logits[i];
901+
if (has_out) row_idx++;
902+
bool is_mine = false;
903+
if (batch.n_seq_id && batch.n_seq_id[i] > 0 && batch.seq_id) {
904+
for (int j = 0; j < batch.n_seq_id[i]; ++j) {
905+
if (batch.seq_id[i][j] == slot_seq_id) { is_mine = true; break; }
906+
}
907+
}
908+
if (is_mine && has_out && row_idx >= 0 && row_idx < n_rows) {
909+
mine.push_back({i, row_idx});
910+
}
911+
}
912+
if (mine.empty()) {
913+
return;
914+
}
915+
// Heuristic: only run prefill if the rows in t_h_pre_norm are
916+
// contiguous starting at 0 (they will be when our slot's tokens are
917+
// the only ones with output=true). Otherwise we'd need to gather
918+
// non-contiguous rows — skip rather than risk wrong h.
919+
for (size_t k = 0; k < mine.size(); ++k) {
920+
if (mine[k].row_idx != (int) k) {
921+
LOG_INF("mtp prefill skip: non-contiguous rows (slot=%d)\n", (int) slot_seq_id);
922+
return;
923+
}
924+
}
925+
926+
const int n = (int) mine.size();
927+
// Run MTP forwards in chunks of at most n_ubatch tokens — single
928+
// huge MTP forwards (e.g. 1500-token prompts) exceed compute scratch
929+
// and crash in ggml. The KV result is identical regardless of split,
930+
// since each chunk attends to all earlier MTP KV positions.
931+
const int chunk_max = (int) llama_n_ubatch(ctx_mtp);
932+
for (int off = 0; off < n; off += chunk_max) {
933+
const int n_chunk = std::min(chunk_max, n - off);
934+
935+
this->batch.n_tokens = n_chunk;
936+
for (int k = 0; k < n_chunk; ++k) {
937+
const int bi = mine[off + k].batch_idx;
938+
this->batch.token[k] = batch.token[bi];
939+
this->batch.pos[k] = batch.pos ? batch.pos[bi] : (off + k);
940+
this->batch.n_seq_id[k] = 1;
941+
this->batch.seq_id[k][0] = 0;
942+
this->batch.logits[k] = 0;
943+
}
944+
const int32_t rc_relay = llama_mtp_relay_h(ctx_tgt, ctx_mtp,
945+
/*src_row=*/ off, /*n_rows=*/ n_chunk);
946+
if (rc_relay != 0) {
947+
LOG_WRN("mtp prefill: relay rc=%d (chunk_off=%d, n=%d, slot=%d)\n",
948+
rc_relay, off, n_chunk, (int) slot_seq_id);
949+
return;
950+
}
951+
const int32_t rc = llama_decode(ctx_mtp, this->batch);
952+
if (rc != 0) {
953+
LOG_WRN("mtp prefill: decode rc=%d (chunk_off=%d, n=%d, slot=%d)\n",
954+
rc, off, n_chunk, (int) slot_seq_id);
955+
return;
956+
}
957+
}
958+
const llama_pos new_pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
959+
LOG_INF("mtp prefill: slot=%d n=%d chunks=%d mtp_pos_max=%d\n",
960+
(int) slot_seq_id, n, (n + chunk_max - 1) / chunk_max, (int) new_pos_max);
961+
}
962+
810963
int32_t n_max(const common_params_speculative & params) const override {
811964
return std::max(1, params.draft.n_max);
812965
}
@@ -1423,6 +1576,19 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
14231576
}
14241577
}
14251578

1579+
void common_speculative_on_target_decoded(
1580+
common_speculative * spec,
1581+
const llama_batch & batch,
1582+
llama_seq_id slot_seq_id,
1583+
bool is_prompt_prefill) {
1584+
if (!spec) {
1585+
return;
1586+
}
1587+
for (auto & impl : spec->impls) {
1588+
impl->on_target_decoded(batch, slot_seq_id, is_prompt_prefill);
1589+
}
1590+
}
1591+
14261592
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params) {
14271593
if (spec == nullptr) {
14281594
return 0;

common/speculative.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,19 @@ llama_tokens common_speculative_draft(
3737
// informs the speculative decoder that n_accepted tokens were accepted by the target model
3838
void common_speculative_accept(common_speculative * spec, uint16_t n_accepted);
3939

40+
// Notifies the speculative decoder that ctx_tgt just decoded a batch. MTP
41+
// uses this hook (only when is_prompt_prefill = true) to mirror the just-
42+
// decoded ubatch into ctx_mtp — i.e. each trunk prompt-prefill ubatch
43+
// triggers one MTP ubatch decode with the same positions and tokens, so
44+
// MTP's KV grows incrementally as the trunk's prompt prefill progresses.
45+
// Pass is_prompt_prefill=false for verify-batch decodes (drafting flow
46+
// owns those) so MTP's draft-time K/V isn't clobbered.
47+
void common_speculative_on_target_decoded(
48+
common_speculative * spec,
49+
const llama_batch & batch,
50+
llama_seq_id slot_seq_id,
51+
bool is_prompt_prefill);
52+
4053
int32_t common_speculative_n_max(const common_speculative * spec, const common_params_speculative & params);
4154
int32_t common_speculative_n_min(const common_speculative * spec, const common_params_speculative & params);
4255

include/llama.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,15 @@ extern "C" {
988988
// hidden state plus a token batch to produce draft logits, with its own KV
989989
// cache populated by build_attn the same way any other layer's is.
990990
//
991+
// Returns ctx's most recent t_h_pre_norm tensor (the trunk's pre-output-
992+
// norm hidden state) for an MTP-enabled trunk arch, or NULL. Used by the
993+
// common speculative MTP implementation to harvest hidden-state rows
994+
// across ubatches during prompt prefill (the trunk's gf_res_prev only
995+
// carries the last ubatch's rows, so we accumulate as they're produced).
996+
// The returned tensor's data lives in ctx's compute or output buffer;
997+
// call llama_synchronize(ctx) before reading via ggml_backend_tensor_get.
998+
LLAMA_API struct ggml_tensor * llama_context_get_t_h_pre_norm(struct llama_context * ctx);
999+
9911000
// Stages a copy of n_rows of ctx_target's t_h_pre_norm starting at index
9921001
// `src_row` into rows [0, n_rows) of ctx_mtp's t_inp_h. The copy is
9931002
// deferred to the next llama_decode on ctx_mtp — by then the destination
@@ -1017,6 +1026,7 @@ extern "C" {
10171026
struct llama_context * ctx_mtp,
10181027
int32_t n_rows);
10191028

1029+
10201030
// Set abort callback
10211031
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
10221032

src/llama-context.cpp

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3100,18 +3100,39 @@ ggml_tensor * llama_context::get_t_h_pre_norm() const {
31003100
return gf_res_prev ? gf_res_prev->t_h_pre_norm : nullptr;
31013101
}
31023102

3103+
ggml_tensor * llama_context_get_t_h_pre_norm(struct llama_context * ctx) {
3104+
return ctx ? ctx->get_t_h_pre_norm() : nullptr;
3105+
}
3106+
31033107
ggml_tensor * llama_context::get_t_mtp_out() const {
31043108
return gf_res_prev ? gf_res_prev->t_mtp_out : nullptr;
31053109
}
31063110

31073111
void llama_context::set_mtp_h_source(struct llama_context * ctx_src, ggml_tensor * src,
31083112
int32_t row_first, int32_t n_rows) {
3109-
mtp_h_staging.ctx_src = ctx_src;
3110-
mtp_h_staging.src = src;
3111-
mtp_h_staging.row_first = row_first;
3112-
mtp_h_staging.n_rows = n_rows;
3113+
GGML_ASSERT(ctx_src && src && n_rows > 0);
3114+
GGML_ASSERT(row_first >= 0 && row_first + n_rows <= src->ne[1]);
3115+
3116+
// Wait for the source's compute to finish before reading its rows.
3117+
ctx_src->synchronize();
3118+
3119+
const size_t row_bytes = src->nb[1];
3120+
mtp_h_staging.host_buf.resize(row_bytes * (size_t) n_rows);
3121+
mtp_h_staging.n_rows = n_rows;
3122+
mtp_h_staging.n_embd = (int32_t) src->ne[0];
3123+
3124+
// Synchronous device-to-host of the requested row range.
3125+
ggml_backend_tensor_get(src, mtp_h_staging.host_buf.data(),
3126+
(size_t) row_first * row_bytes,
3127+
row_bytes * (size_t) n_rows);
3128+
3129+
LLAMA_LOG_DEBUG("mtp_relay stage: src=%s ne=[%lld,%lld] rows=[%d,%d) embd=%d bytes=%zu\n",
3130+
src->name, (long long) src->ne[0], (long long) src->ne[1],
3131+
row_first, row_first + n_rows, mtp_h_staging.n_embd,
3132+
mtp_h_staging.host_buf.size());
31133133
}
31143134

3135+
31153136
ggml_tensor * llama_context::get_t_inp_h() const {
31163137
// gf_res_prev->t_inp_h is set by the model's graph builder (e.g.
31173138
// llm_build_qwen35_mtp). After the first real llama_decode it lives there.

0 commit comments

Comments
 (0)