Skip to content

Commit 17d47df

Browse files
committed
mtp: add MTP-K chain via t_mtp_out + self-relay
Generalize the MTP draft path to support chain length K > 1, where each chain step conditions on the previous step's MTP block output instead of the target's pre-output-norm hidden. Pieces: - llama_graph_result gains t_mtp_out: the MTP block's post-FFN hidden (pre-LM-head). qwen35_mtp's graph builder sets it. - llama_context::get_t_mtp_out() exposes the most recent decode's value. - llama_mtp_relay_h_self(ctx_mtp, n_rows): on-device copy of the LAST n_rows of t_mtp_out into the FIRST n_rows of t_inp_h. Same machinery as llama_mtp_relay_h, just self-source. - common_speculative_state_mtp::draft chains n_max calls. Step 0 relays from ctx_target's t_h_pre_norm (existing). Steps 1..K-1 self-relay from ctx_mtp's previous t_mtp_out. Each step argmaxes its logits and feeds the result to the next. - accept(n_accepted) trims any rejected trailing draft positions from ctx_mtp's KV via seq_rm so the next draft writes K/V at the right slots. Tracks last_n_drafted to know how many to potentially drop. Smoke results on Qwen3.6-q8_0-mtp.gguf, --spec-draft-n-max 2: fibonacci: K=1 → 13.17 tok/s, 100% accept K=2 → 15.40 tok/s, 75% accept (12/16) K=2 wins because the prompt is highly canonical and even chain step 1 stays accepted most of the time. send_req: K=1 → 11.44 tok/s, 83.9% accept (182/217) K=2 → 9.48 tok/s, 29.7% accept (148/499) K=2 loses on dense code: chain step 1 accept falls off a cliff because Qwen3.6's MTP head is trained one-step-ahead and feeding it its own previous output is out-of-distribution (the FastMTP problem; also discussed in DeepSeek V3 paper). The infrastructure works correctly; the model doesn't benefit without retraining. Practical guidance: keep --spec-draft-n-max 1 for code/dense workloads. K > 1 only helps when the head was either trained for chain prediction (FastMTP-style) or when the workload is canonical enough that vanilla self-rolling stays in-distribution.
1 parent 51f799b commit 17d47df

6 files changed

Lines changed: 149 additions & 86 deletions

File tree

common/speculative.cpp

Lines changed: 70 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -630,6 +630,10 @@ struct common_speculative_state_mtp : public common_speculative_state {
630630
// next draft writes at pos+1, etc. Reset by begin().
631631
llama_pos mtp_pos = 0;
632632

633+
// How many tokens the most recent draft() pushed into ctx_mtp. accept()
634+
// uses this to compute how many trailing positions to roll back.
635+
uint16_t last_n_drafted = 0;
636+
633637
common_speculative_state_mtp(enum common_speculative_type type,
634638
llama_context * ctx_tgt,
635639
llama_context * ctx_mtp)
@@ -701,73 +705,85 @@ struct common_speculative_state_mtp : public common_speculative_state {
701705
const llama_tokens & prompt_tgt,
702706
llama_token id_last,
703707
llama_tokens & draft_tokens) override {
704-
GGML_UNUSED(params);
705708
GGML_UNUSED(prompt_tgt);
706709
draft_tokens.clear();
707710

708-
// Stage h from the target's last decode into ctx_mtp's input buffer.
709-
// For a single-token MTP step we relay the LAST row of t_h_pre_norm
710-
// (the h corresponding to id_last's predecessor) into the FIRST row
711-
// of ctx_mtp's t_inp_h.
712-
const int32_t rc = llama_mtp_relay_h(ctx_tgt, ctx_mtp, /*n_rows=*/ 1);
713-
if (rc != 0) {
714-
LOG_DBG("%s: llama_mtp_relay_h rc=%d; skipping MTP draft\n", __func__, rc);
715-
return;
716-
}
711+
// Chain length K: 1 for plain MTP, 2+ for chained MTP-K. Each step
712+
// runs one MTP forward, takes the argmax as a draft token, and feeds
713+
// its hidden state forward as the next step's t_inp_h.
714+
const int32_t n_max = std::max(1, params.draft.n_max);
717715

718-
// Position: one past whatever's currently in ctx_mtp's KV cache for
719-
// seq 0. M-RoPE asserts Y > X (input start > cache last), so we always
720-
// advance by querying the cache rather than relying on a stale local
721-
// counter. Step 7's prompt prefill will warm the cache; until then it
722-
// starts at -1 (empty) and grows as drafts are accepted.
723-
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
724-
const llama_pos pos = pos_max + 1;
716+
const int32_t n_vocab = (int32_t) logits_buf.size();
717+
llama_token cond_tok = id_last;
718+
719+
for (int32_t k = 0; k < n_max; ++k) {
720+
// Stage h. Step 0: from ctx_tgt's t_h_pre_norm (the h corresponding
721+
// to id_last's position in the trunk). Step k>0: self-relay from
722+
// ctx_mtp's previous t_mtp_out (the MTP block's post-FFN h from
723+
// the previous chain step).
724+
const int32_t rc_relay = (k == 0)
725+
? llama_mtp_relay_h(ctx_tgt, ctx_mtp, /*n_rows=*/ 1)
726+
: llama_mtp_relay_h_self(ctx_mtp, /*n_rows=*/ 1);
727+
if (rc_relay != 0) {
728+
LOG_DBG("%s: relay rc=%d at k=%d; stopping chain\n", __func__, rc_relay, k);
729+
return;
730+
}
725731

726-
// Build the single-token batch.
727-
batch.n_tokens = 1;
728-
batch.token[0] = id_last;
729-
batch.pos[0] = pos;
730-
batch.n_seq_id[0] = 1;
731-
batch.seq_id[0][0] = 0;
732-
batch.logits[0] = 1;
732+
// Position: one past whatever's in ctx_mtp's KV. Always queried
733+
// (M-RoPE requires Y > X strictly).
734+
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
735+
const llama_pos pos = pos_max + 1;
736+
737+
batch.n_tokens = 1;
738+
batch.token[0] = cond_tok;
739+
batch.pos[0] = pos;
740+
batch.n_seq_id[0] = 1;
741+
batch.seq_id[0][0] = 0;
742+
batch.logits[0] = 1;
743+
744+
const int32_t dec_rc = llama_decode(ctx_mtp, batch);
745+
if (dec_rc != 0) {
746+
LOG_DBG("%s: llama_decode rc=%d at k=%d; stopping chain\n", __func__, dec_rc, k);
747+
return;
748+
}
733749

734-
// Run the MTP graph: ctx_mtp consumes (h_input from the relay,
735-
// e(id_last)) and produces draft logits.
736-
const int32_t dec_rc = llama_decode(ctx_mtp, batch);
737-
if (dec_rc != 0) {
738-
LOG_DBG("%s: llama_decode on ctx_mtp rc=%d\n", __func__, dec_rc);
739-
return;
740-
}
750+
const float * logits = llama_get_logits_ith(ctx_mtp, 0);
751+
if (!logits) {
752+
return;
753+
}
741754

742-
const float * logits = llama_get_logits_ith(ctx_mtp, 0);
743-
if (!logits) {
744-
return;
755+
std::memcpy(logits_buf.data(), logits, n_vocab * sizeof(float));
756+
int best = 0;
757+
float bv = logits_buf[0];
758+
for (int i = 1; i < n_vocab; ++i) {
759+
if (logits_buf[i] > bv) { bv = logits_buf[i]; best = i; }
760+
}
761+
draft_tokens.push_back(best);
762+
cond_tok = best;
745763
}
746764

747-
// Greedy argmax draft.
748-
const int32_t n_vocab = (int32_t) logits_buf.size();
749-
std::memcpy(logits_buf.data(), logits, n_vocab * sizeof(float));
750-
int best = 0;
751-
float bv = logits_buf[0];
752-
for (int i = 1; i < n_vocab; ++i) {
753-
if (logits_buf[i] > bv) { bv = logits_buf[i]; best = i; }
754-
}
755-
draft_tokens.push_back(best);
765+
last_n_drafted = (uint16_t) draft_tokens.size();
756766
}
757767

758768
void accept(uint16_t n_accepted) override {
759-
// The previous draft() pushed one token (the drafted token) into
760-
// ctx_mtp at pos+1. If the verifier accepted it, leave it. If not
761-
// (n_accepted=0), trim it back so the next draft picks the slot
762-
// again. We always trim past `pos_max - (n_drafted - n_accepted)`,
763-
// but for K=1 this simplifies to: trim if rejected.
764-
if (n_accepted == 0) {
765-
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
766-
if (pos_max >= 0) {
767-
llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
768-
/*pos_min=*/ pos_max, /*pos_max=*/ -1);
769-
}
769+
// The previous draft() pushed K tokens into ctx_mtp at positions
770+
// [pos_max - K + 1, pos_max]. The verifier accepted the first
771+
// n_accepted of them; the remaining K - n_accepted came from
772+
// chain steps that the verifier rejected. Trim those rejected
773+
// positions from ctx_mtp's KV so the next draft writes K/V at the
774+
// right slots.
775+
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
776+
if (pos_max < 0) {
777+
return;
778+
}
779+
const int32_t n_drafted_last = (int32_t) last_n_drafted;
780+
const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted);
781+
if (n_to_drop > 0) {
782+
const llama_pos drop_from = pos_max - n_to_drop + 1;
783+
llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
784+
/*p0=*/ drop_from, /*p1=*/ -1);
770785
}
786+
last_n_drafted = 0;
771787
}
772788

773789
int32_t n_max(const common_params_speculative & params) const override {

include/llama.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,15 @@ extern "C" {
10011001
struct llama_context * ctx_mtp,
10021002
int32_t n_rows);
10031003

1004+
// Self-relay: copy the LAST n_rows of ctx_mtp's most recent t_mtp_out
1005+
// (the MTP block's post-FFN hidden) into the FIRST n_rows of its own
1006+
// t_inp_h. Used for chained MTP-K drafting (K > 1) where each chain step
1007+
// conditions on the previous step's MTP output rather than the target's
1008+
// pre-output-norm hidden.
1009+
LLAMA_API int32_t llama_mtp_relay_h_self(
1010+
struct llama_context * ctx_mtp,
1011+
int32_t n_rows);
1012+
10041013
// Set abort callback
10051014
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
10061015

src/llama-context.cpp

Lines changed: 53 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3099,6 +3099,10 @@ ggml_tensor * llama_context::get_t_h_pre_norm() const {
30993099
return gf_res_prev ? gf_res_prev->t_h_pre_norm : nullptr;
31003100
}
31013101

3102+
ggml_tensor * llama_context::get_t_mtp_out() const {
3103+
return gf_res_prev ? gf_res_prev->t_mtp_out : nullptr;
3104+
}
3105+
31023106
ggml_tensor * llama_context::get_t_inp_h() const {
31033107
// gf_res_prev->t_inp_h is set by the model's graph builder (e.g.
31043108
// llm_build_qwen35_mtp). After the first real llama_decode it lives there.
@@ -3114,52 +3118,44 @@ ggml_tensor * llama_context::get_t_inp_h() const {
31143118
return nullptr;
31153119
}
31163120

3117-
int32_t llama_mtp_relay_h(
3118-
struct llama_context * ctx_target,
3119-
struct llama_context * ctx_mtp,
3120-
int32_t n_rows) {
3121-
if (!ctx_target || !ctx_mtp) {
3122-
return -1;
3123-
}
3124-
3125-
ggml_tensor * src = ctx_target->get_t_h_pre_norm();
3121+
// Common implementation: copy the LAST n_rows of `src` into the FIRST n_rows
3122+
// of `dst`, on-device via ggml_backend_tensor_copy_async. ctx_src/ctx_dst are
3123+
// used to look up backends per tensor and to synchronize the source.
3124+
static int32_t llama_mtp_relay_impl(
3125+
struct llama_context * ctx_src,
3126+
struct llama_context * ctx_dst,
3127+
ggml_tensor * src,
3128+
ggml_tensor * dst,
3129+
int32_t n_rows,
3130+
const char * fn) {
31263131
if (!src) {
3127-
LLAMA_LOG_ERROR("%s: ctx_target's last decode did not produce t_h_pre_norm\n", __func__);
3132+
LLAMA_LOG_ERROR("%s: src tensor missing\n", fn);
31283133
return -2;
31293134
}
3130-
3131-
ggml_tensor * dst = ctx_mtp->get_t_inp_h();
31323135
if (!dst) {
3133-
LLAMA_LOG_ERROR("%s: ctx_mtp has no t_inp_h (graph not built or wrong arch)\n", __func__);
3136+
LLAMA_LOG_ERROR("%s: dst tensor missing (graph not built or wrong arch)\n", fn);
31343137
return -3;
31353138
}
3136-
31373139
if (src->ne[0] != dst->ne[0]) {
31383140
LLAMA_LOG_ERROR("%s: shape mismatch: src n_embd=%" PRId64 ", dst n_embd=%" PRId64 "\n",
3139-
__func__, src->ne[0], dst->ne[0]);
3141+
fn, src->ne[0], dst->ne[0]);
31403142
return -4;
31413143
}
3142-
31433144
if (n_rows <= 0 || n_rows > src->ne[1] || n_rows > dst->ne[1]) {
31443145
LLAMA_LOG_ERROR("%s: n_rows=%d out of range (src cap=%" PRId64 ", dst cap=%" PRId64 ")\n",
3145-
__func__, n_rows, src->ne[1], dst->ne[1]);
3146+
fn, n_rows, src->ne[1], dst->ne[1]);
31463147
return -5;
31473148
}
31483149

3149-
// Copy the LAST n_rows of src into the FIRST n_rows of dst.
3150-
const int32_t src_first = (int32_t) src->ne[1] - n_rows;
3151-
const int32_t dst_first = 0;
3152-
3153-
// Wait for ctx_target's last compute to finish before reading t_h_pre_norm.
3154-
ctx_target->synchronize();
3150+
// Wait for the source's compute to finish before reading.
3151+
ctx_src->synchronize();
31553152

3156-
// Build views for the row range we want to copy. ggml_view_2d does NOT
3157-
// propagate the parent's backend buffer to the view tensor (it sets
3158-
// view->buffer = NULL and only forwards view->data + offset), so we have
3159-
// to wire the buffer manually before passing the views to copy_async —
3160-
// otherwise the backend's copy path hits a null buffer and aborts inside
3161-
// ggml_backend_buffer_get_type.
3153+
// Build views for the row range. ggml_view_2d does not propagate the
3154+
// parent's backend buffer to the view (it sets view->buffer = NULL and
3155+
// only forwards view->data + offset), so wire the buffer manually before
3156+
// passing the views to copy_async.
31623157
const size_t row_size = src->nb[1];
3158+
const int32_t src_first = (int32_t) src->ne[1] - n_rows; // last n_rows of src
31633159
const size_t src_offset = (size_t) src_first * row_size;
31643160

31653161
ggml_context_ptr view_ctx;
@@ -3182,20 +3178,45 @@ int32_t llama_mtp_relay_h(
31823178
src_view->buffer = src->buffer;
31833179
dst_view->buffer = dst->buffer;
31843180

3185-
auto * sched_src = ctx_target->get_sched();
3186-
auto * sched_dst = ctx_mtp->get_sched();
3181+
auto * sched_src = ctx_src->get_sched();
3182+
auto * sched_dst = ctx_dst->get_sched();
31873183
auto * backend_src = ggml_backend_sched_get_tensor_backend(sched_src, src);
31883184
auto * backend_dst = ggml_backend_sched_get_tensor_backend(sched_dst, dst);
31893185
if (!backend_src || !backend_dst) {
31903186
LLAMA_LOG_ERROR("%s: backend resolve failed (src=%p dst=%p)\n",
3191-
__func__, (void *) backend_src, (void *) backend_dst);
3187+
fn, (void *) backend_src, (void *) backend_dst);
31923188
return -8;
31933189
}
31943190

31953191
ggml_backend_tensor_copy_async(backend_src, backend_dst, src_view, dst_view);
31963192
return 0;
31973193
}
31983194

3195+
int32_t llama_mtp_relay_h(
3196+
struct llama_context * ctx_target,
3197+
struct llama_context * ctx_mtp,
3198+
int32_t n_rows) {
3199+
if (!ctx_target || !ctx_mtp) {
3200+
return -1;
3201+
}
3202+
return llama_mtp_relay_impl(ctx_target, ctx_mtp,
3203+
ctx_target->get_t_h_pre_norm(),
3204+
ctx_mtp->get_t_inp_h(),
3205+
n_rows, __func__);
3206+
}
3207+
3208+
int32_t llama_mtp_relay_h_self(
3209+
struct llama_context * ctx_mtp,
3210+
int32_t n_rows) {
3211+
if (!ctx_mtp) {
3212+
return -1;
3213+
}
3214+
return llama_mtp_relay_impl(ctx_mtp, ctx_mtp,
3215+
ctx_mtp->get_t_mtp_out(),
3216+
ctx_mtp->get_t_inp_h(),
3217+
n_rows, __func__);
3218+
}
3219+
31993220
void llama_synchronize(llama_context * ctx) {
32003221
ctx->synchronize();
32013222
}

src/llama-context.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ struct llama_context {
8080
// writes into this tensor before llama_decode runs on ctx_mtp.
8181
ggml_tensor * get_t_inp_h() const;
8282

83+
// For LLM_ARCH_QWEN35_MTP contexts: the MTP block's post-FFN output from
84+
// the most recent decode. Used by chained MTP-K drafting (K > 1) — the
85+
// self-relay copies this into t_inp_h for the next chain step.
86+
ggml_tensor * get_t_mtp_out() const;
87+
8388
llama_token * get_sampled_tokens() const;
8489
llama_token get_sampled_token_ith(int32_t idx);
8590

src/llama-graph.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,13 @@ class llm_graph_result {
706706
// each llama_decode via llama_mtp_relay_h.
707707
ggml_tensor * t_inp_h = nullptr; // [n_embd, n_tokens]
708708

709+
// For LLM_ARCH_QWEN35_MTP: the MTP block's post-FFN output, before the
710+
// shared LM head. Used by chained MTP-K drafting (K > 1): the speculative
711+
// wrapper relays this back into t_inp_h for the next draft step so the
712+
// chain conditions on the previous step's hidden state, matching how the
713+
// MTP head was trained.
714+
ggml_tensor * t_mtp_out = nullptr; // [n_embd, n_tokens]
715+
709716
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
710717
std::map<llama_seq_id, ggml_tensor*> t_candidates;
711718
std::map<llama_seq_id, ggml_tensor*> t_sampled;

src/models/qwen35_mtp.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ llm_build_qwen35_mtp::llm_build_qwen35_mtp(const llama_model & model, const llm_
150150
cur = ggml_add(ctx0, cur, ffn_residual);
151151
cb(cur, "mtp_post_ffn", il);
152152

153+
// Snapshot the MTP block's post-FFN hidden — this is what gets fed back
154+
// as the next chain step's t_inp_h for MTP-K drafting (K > 1). Lives on
155+
// the output buffer alongside t_logits.
156+
res->t_mtp_out = cur;
157+
153158
// Shared final norm + LM head. The MTP block carries its own
154159
// shared_head_norm; if absent (some converted variants), fall back to the
155160
// model's output_norm. The LM head is the model's output (or tok_embd if

0 commit comments

Comments
 (0)