Commit b9be095

mtp: stage h relay through set_input for graph-rebuild safety

Refactor the h relay path so the device-to-device copy from ctx_target's
t_h_pre_norm into ctx_mtp's t_inp_h happens during set_input on the next
decode, rather than immediately at the relay call. This prepares the ground
for batched MTP prompt prefill, where ctx_mtp has to switch between
n_tokens=N (prefill) and n_tokens=1 (chain step) graphs. The old "copy now"
path could not survive that rebuild: t_inp_h's tensor identity changes when
the graph rebuilds, but the relay had already written to the prior graph's
tensor.

Mechanism:
- llama_context gains an mtp_h_source_t staging slot (ctx_src + src tensor +
  row range), set by llama_mtp_relay_h{,_self} and consumed during set_inputs
  on the next decode.
- llm_graph_input_h_pre_norm now holds a llama_context* and reads the staged
  source in its set_input. The actual ggml_backend_tensor_copy_async lives
  there (synchronizes ctx_src, builds row views with manually-wired buffers,
  sched-resolves backends per side, then async copies). After the copy the
  staging is cleared so a stray decode without a fresh relay call doesn't
  replay stale data.
- llm_graph_params carries a llama_context* so the graph builder can wire it
  onto the input class. graph_params() in llama-context.cpp passes `this`.
- llama_mtp_relay_h gains an n_rows parameter (1 by default for per-step
  drafting; N for an upcoming batched-prefill caller).

No behavior change at K=1/K=2: relay still fires every draft step and still
copies the same rows. Verified send_req on Qwen3.6-q8_0-mtp:
  K=1: 88.2% accept (187/212), 12.0 tok/s (was 88%, 12.5)
  K=2: 85.7% accept (252/294), 16.2 tok/s (was 86%, 16.9)
Within noise. The slight tok/s dip is the extra synchronize + view allocation
per set_input call; trivially recoverable later.

Why this matters: with the relay flowing through set_input, the next commit
can do batched MTP prompt prefill (single n_tokens=N decode) followed by the
existing single-token chain steps without the t_inp_h identity gymnastics.
That fixes the long-context issue where MTP's KV currently holds only
[BOS, draft_1, ..., draft_M] and MTP attention cannot see prompt context,
plus the position drift where MTP applies RoPE at local positions 1..M+1
while the trunk is at absolute position N..N+M (for a 4K prompt those
rotations diverge enough to wreck attention quality).
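To give a rough sense of the position drift mentioned in the commit message, here is a minimal sketch of the standard RoPE angle formula, angle = pos * base^(-2i/d). The base of 10000 and head dimension of 128 used in the usage note are illustrative assumptions, not values taken from this commit or from the model, and `rope_angle` / `rope_angle_gap` are hypothetical helper names.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Rotation angle (radians) that standard RoPE applies to dimension pair
// `pair` of a head with `head_dim` dimensions at token position `pos`.
double rope_angle(int64_t pos, int pair, int head_dim, double base) {
    assert(head_dim > 0 && head_dim % 2 == 0);
    assert(pair >= 0 && pair < head_dim / 2);
    const double inv_freq = std::pow(base, -2.0 * pair / head_dim);
    return (double) pos * inv_freq;
}

// Difference between the angle at the absolute position and the angle at the
// local position, for one dimension pair.
double rope_angle_gap(int64_t pos_abs, int64_t pos_local,
                      int pair, int head_dim, double base) {
    return rope_angle(pos_abs,   pair, head_dim, base)
         - rope_angle(pos_local, pair, head_dim, base);
}
```

Under those assumptions, `rope_angle_gap(4096, 1, 0, 128, 10000.0)` is 4095 radians for the fastest-rotating pair, which is several hundred full turns. How much this hurts attention depends on which keys the query is compared against; the sketch only illustrates the size of the gap.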
1 parent 183a99c commit b9be095

7 files changed

Lines changed: 149 additions & 97 deletions

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
@@ -737,7 +737,7 @@ struct common_speculative_state_mtp : public common_speculative_state {
     int32_t rc_relay;
     if (k == 0) {
         const int32_t src_row = (last_n_accepted < 0) ? 0 : last_n_accepted;
-        rc_relay = llama_mtp_relay_h(ctx_tgt, ctx_mtp, src_row);
+        rc_relay = llama_mtp_relay_h(ctx_tgt, ctx_mtp, src_row, /*n_rows=*/ 1);
     } else {
         rc_relay = llama_mtp_relay_h_self(ctx_mtp, /*n_rows=*/ 1);
     }
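The branch above picks the relay source per draft step. A minimal, self-contained restatement of that choice is sketched below. `relay_plan` and `plan_relay` are hypothetical names used only for illustration; they are not part of llama.cpp.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical description of one relay request; not a llama.cpp type.
struct relay_plan {
    bool    from_target; // true: read ctx_target's t_h_pre_norm; false: self-relay
    int32_t src_row;     // source row, only meaningful when from_target is true
    int32_t n_rows;      // number of rows to stage
};

// Mirrors the branch in the hunk: the first draft step (k == 0) reads the
// target's row n_accepted (clamped to 0 when nothing has been accepted yet),
// later steps chain from the MTP context's own previous output.
relay_plan plan_relay(int32_t k, int32_t last_n_accepted) {
    assert(k >= 0);
    if (k == 0) {
        const int32_t src_row = (last_n_accepted < 0) ? 0 : last_n_accepted;
        return { true, src_row, 1 };
    }
    return { false, 0, 1 };
}
```

For example, `plan_relay(0, 2)` asks for row 2 of the target, while `plan_relay(1, 2)` asks for a self-relay regardless of how many drafts were accepted.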

include/llama.h

Lines changed: 14 additions & 13 deletions
@@ -988,24 +988,25 @@ extern "C" {
     // hidden state plus a token batch to produce draft logits, with its own KV
     // cache populated by build_attn the same way any other layer's is.
     //
-    // Copies a single row at index `src_row` of ctx_target's t_h_pre_norm into
-    // row 0 of ctx_mtp's t_inp_h. Both backends must be able to issue a copy
-    // between each other (typical case: same device, fast on-device copy).
+    // Stages a copy of n_rows of ctx_target's t_h_pre_norm starting at index
+    // `src_row` into rows [0, n_rows) of ctx_mtp's t_inp_h. The copy is
+    // deferred to the next llama_decode on ctx_mtp — by then the destination
+    // graph has been built and t_inp_h has stable identity. Calling this
+    // function only records the source; the actual device-to-device copy
+    // happens during set_inputs on the next decode.
     //
-    // The right `src_row` for MTP drafting is the row whose hidden produced the
-    // verifier sample that becomes the next draft's id_last. After a verify
-    // batch [sampled, d0, ..., d_{K-1}] with `n_accepted` drafts accepted, that
-    // is `src_row = n_accepted` (the bonus token was sampled from h at row
-    // n_accepted). Using the last row instead silently corrupts MTP whenever
-    // n_accepted < K; the bug is invisible at K=1 most of the time but tanks
-    // K>=2.
+    // For per-step drafting use n_rows=1 with src_row = n_accepted (the row
+    // whose hidden produced the verifier sample that became id_last). For
+    // batched MTP prompt prefill use src_row=0 and n_rows = N (the prompt
+    // length, requiring ctx_target's prompt prefill to have logits=true on
+    // every position so t_h_pre_norm carries all rows).
     //
-    // Returns 0 on success; negative on error (e.g. ctx_target's last decode
-    // didn't produce t_h_pre_norm, src_row out of range, shape mismatch).
+    // Returns 0 on success; negative on error.
     LLAMA_API int32_t llama_mtp_relay_h(
             struct llama_context * ctx_target,
             struct llama_context * ctx_mtp,
-            int32_t src_row);
+            int32_t src_row,
+            int32_t n_rows);
 
     // Self-relay: copy the LAST n_rows of ctx_mtp's most recent t_mtp_out
     // (the MTP block's post-FFN hidden) into the FIRST n_rows of its own

src/llama-context.cpp

Lines changed: 35 additions & 72 deletions
@@ -2165,6 +2165,7 @@ llm_graph_params llama_context::graph_params(
         /*.gtype =*/ gtype,
         /*.sched =*/ sched.get(),
         /*.backend_cpu =*/ backend_cpu,
+        /*.ctx =*/ const_cast<llama_context *>(this),
         /*.cvec =*/ cvec.get(),
         /*.loras =*/ loras.get(),
         /*.mctx =*/ mctx,
@@ -3103,6 +3104,14 @@ ggml_tensor * llama_context::get_t_mtp_out() const {
     return gf_res_prev ? gf_res_prev->t_mtp_out : nullptr;
 }
 
+void llama_context::set_mtp_h_source(struct llama_context * ctx_src, ggml_tensor * src,
+        int32_t row_first, int32_t n_rows) {
+    mtp_h_staging.ctx_src = ctx_src;
+    mtp_h_staging.src = src;
+    mtp_h_staging.row_first = row_first;
+    mtp_h_staging.n_rows = n_rows;
+}
+
 ggml_tensor * llama_context::get_t_inp_h() const {
     // gf_res_prev->t_inp_h is set by the model's graph builder (e.g.
     // llm_build_qwen35_mtp). After the first real llama_decode it lives there.
@@ -3118,91 +3127,45 @@ ggml_tensor * llama_context::get_t_inp_h() const {
     return nullptr;
 }
 
-// Common implementation: copy a single row at `src_row` of `src` into row 0
-// of `dst`, on-device via ggml_backend_tensor_copy_async. ctx_src/ctx_dst are
-// used to look up backends per tensor and to synchronize the source.
-static int32_t llama_mtp_relay_impl(
+// Helper: validate the source tensor + row range, then stage on ctx_mtp.
+// The actual device-to-device copy is deferred to llm_graph_input_h_pre_norm::
+// set_input on the next decode — by then ctx_mtp's graph is built and t_inp_h
+// has stable identity. Doing the copy immediately would race with later graph
+// rebuilds (e.g. between an n_tokens=N prefill and an n_tokens=1 chain step).
+static int32_t llama_mtp_stage_source(
         struct llama_context * ctx_src,
-        struct llama_context * ctx_dst,
+        struct llama_context * ctx_mtp,
         ggml_tensor * src,
-        ggml_tensor * dst,
-        int32_t src_row,
+        int32_t row_first,
+        int32_t n_rows,
         const char * fn) {
     if (!src) {
         LLAMA_LOG_ERROR("%s: src tensor missing\n", fn);
         return -2;
     }
-    if (!dst) {
-        LLAMA_LOG_ERROR("%s: dst tensor missing (graph not built or wrong arch)\n", fn);
+    if (n_rows <= 0) {
+        LLAMA_LOG_ERROR("%s: n_rows=%d must be > 0\n", fn, n_rows);
         return -3;
     }
-    if (src->ne[0] != dst->ne[0]) {
-        LLAMA_LOG_ERROR("%s: shape mismatch: src n_embd=%" PRId64 ", dst n_embd=%" PRId64 "\n",
-                fn, src->ne[0], dst->ne[0]);
+    if (row_first < 0 || row_first + n_rows > src->ne[1]) {
+        LLAMA_LOG_ERROR("%s: row range [%d, %d) out of src cap %" PRId64 "\n",
+                fn, row_first, row_first + n_rows, src->ne[1]);
         return -4;
     }
-    if (src_row < 0 || src_row >= src->ne[1] || dst->ne[1] < 1) {
-        LLAMA_LOG_ERROR("%s: src_row=%d out of range (src cap=%" PRId64 ", dst cap=%" PRId64 ")\n",
-                fn, src_row, src->ne[1], dst->ne[1]);
-        return -5;
-    }
-
-    // Wait for the source's compute to finish before reading.
-    ctx_src->synchronize();
-
-    // Build views for the row range. ggml_view_2d does not propagate the
-    // parent's backend buffer to the view (it sets view->buffer = NULL and
-    // only forwards view->data + offset), so wire the buffer manually before
-    // passing the views to copy_async.
-    const size_t row_size = src->nb[1];
-    const int32_t n_rows = 1;
-    const size_t src_offset = (size_t) src_row * row_size;
-
-    ggml_context_ptr view_ctx;
-    {
-        ggml_init_params params = {
-            /*.mem_size =*/ ggml_tensor_overhead() * 2,
-            /*.mem_buffer =*/ nullptr,
-            /*.no_alloc =*/ true,
-        };
-        view_ctx.reset(ggml_init(params));
-        if (!view_ctx) {
-            return -7;
-        }
-    }
-
-    ggml_tensor * src_view = ggml_view_2d(view_ctx.get(), src,
-            src->ne[0], n_rows, src->nb[1], src_offset);
-    ggml_tensor * dst_view = ggml_view_2d(view_ctx.get(), dst,
-            dst->ne[0], n_rows, dst->nb[1], /*offset=*/ 0);
-    src_view->buffer = src->buffer;
-    dst_view->buffer = dst->buffer;
-
-    auto * sched_src = ctx_src->get_sched();
-    auto * sched_dst = ctx_dst->get_sched();
-    auto * backend_src = ggml_backend_sched_get_tensor_backend(sched_src, src);
-    auto * backend_dst = ggml_backend_sched_get_tensor_backend(sched_dst, dst);
-    if (!backend_src || !backend_dst) {
-        LLAMA_LOG_ERROR("%s: backend resolve failed (src=%p dst=%p)\n",
-                fn, (void *) backend_src, (void *) backend_dst);
-        return -8;
-    }
-
-    ggml_backend_tensor_copy_async(backend_src, backend_dst, src_view, dst_view);
+    ctx_mtp->set_mtp_h_source(ctx_src, src, row_first, n_rows);
     return 0;
 }
 
 int32_t llama_mtp_relay_h(
         struct llama_context * ctx_target,
         struct llama_context * ctx_mtp,
-        int32_t src_row) {
+        int32_t src_row,
+        int32_t n_rows) {
     if (!ctx_target || !ctx_mtp) {
         return -1;
     }
-    return llama_mtp_relay_impl(ctx_target, ctx_mtp,
-            ctx_target->get_t_h_pre_norm(),
-            ctx_mtp->get_t_inp_h(),
-            src_row, __func__);
+    return llama_mtp_stage_source(ctx_target, ctx_mtp,
+            ctx_target->get_t_h_pre_norm(), src_row, n_rows, __func__);
 }
 
 int32_t llama_mtp_relay_h_self(
@@ -3211,18 +3174,18 @@ int32_t llama_mtp_relay_h_self(
     if (!ctx_mtp) {
         return -1;
     }
-    // Self-relay: t_mtp_out has shape [n_embd, n_tokens] from the previous
-    // single-token decode, so n_tokens=1 and the only row is 0.
     GGML_UNUSED(n_rows);
+    // Self-relay sources from the LAST row of t_mtp_out (the most recent
+    // chain step's post-FFN hidden). Single row only — t_mtp_out has
+    // shape [n_embd, n_tokens] of the prior decode and we always want row
+    // n_tokens-1 here.
     ggml_tensor * src = ctx_mtp->get_t_mtp_out();
     if (!src) {
         return -2;
     }
-    const int32_t src_row = (int32_t) src->ne[1] - 1;
-    return llama_mtp_relay_impl(ctx_mtp, ctx_mtp,
-            src,
-            ctx_mtp->get_t_inp_h(),
-            src_row, __func__);
+    const int32_t row_first = (int32_t) src->ne[1] - 1;
+    return llama_mtp_stage_source(ctx_mtp, ctx_mtp,
+            src, row_first, /*n_rows=*/ 1, __func__);
 }
 
 void llama_synchronize(llama_context * ctx) {
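The row-range validation in llama_mtp_stage_source can be exercised in isolation. The sketch below reuses the -3 and -4 return codes from the diff. `check_row_range` is a hypothetical helper name, and widening to 64 bits before the addition is an addition of this sketch (it keeps `row_first + n_rows` from overflowing `int32_t` on extreme inputs); the diff itself adds the two `int32_t` values directly.

```cpp
#include <cassert>
#include <cstdint>

// Returns 0 when [row_first, row_first + n_rows) fits inside a tensor that has
// src_rows rows; otherwise a negative code matching the ones used in the diff.
int32_t check_row_range(int32_t row_first, int32_t n_rows, int64_t src_rows) {
    assert(src_rows >= 0);
    if (n_rows <= 0) {
        return -3; // nothing to stage
    }
    // Widen first so the sum cannot overflow int32_t (sketch-only hardening).
    const int64_t row_end = (int64_t) row_first + (int64_t) n_rows;
    if (row_first < 0 || row_end > src_rows) {
        return -4; // range falls outside the source tensor
    }
    return 0;
}
```

With a single-row source, `check_row_range(0, 1, 1)` succeeds, while `check_row_range(1, 1, 1)` reports -4 because row 1 does not exist.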

src/llama-context.h

Lines changed: 21 additions & 0 deletions
@@ -85,6 +85,25 @@ struct llama_context {
     // self-relay copies this into t_inp_h for the next chain step.
     ggml_tensor * get_t_mtp_out() const;
 
+    // MTP h staging — set by llama_mtp_relay_h{,_self}, consumed during the
+    // NEXT llama_decode by llm_graph_input_h_pre_norm::set_input. Stable across
+    // graph rebuilds (lives on the context, not on a per-decode graph result),
+    // which is what lets us survive the n_tokens=N → n_tokens=1 transition
+    // between batched prompt prefill and single-token chain steps. The relay
+    // stages instead of copying immediately because the destination tensor
+    // (t_inp_h) only exists once the next decode builds its graph.
+    struct mtp_h_source_t {
+        struct llama_context * ctx_src = nullptr; // for synchronize() + sched lookup
+        ggml_tensor * src = nullptr; // tensor to copy rows from
+        int32_t row_first = 0; // first source row
+        int32_t n_rows = 0; // 0 = no staging, set_input is a no-op
+    };
+
+    void set_mtp_h_source(struct llama_context * ctx_src, ggml_tensor * src,
+            int32_t row_first, int32_t n_rows);
+    mtp_h_source_t get_mtp_h_source() const { return mtp_h_staging; }
+    void clear_mtp_h_source() { mtp_h_staging = {}; }
+
     llama_token * get_sampled_tokens() const;
     llama_token get_sampled_token_ith(int32_t idx);
 
@@ -362,4 +381,6 @@ struct llama_context {
     mutable int32_t n_eval = 0; // number of eval calls
 
     mutable int32_t n_reused = 0; // number of times the previous graph was reused
+
+    mtp_h_source_t mtp_h_staging;
 };
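The stage / consume / clear lifecycle of the staging slot can be illustrated without ggml. The following is a minimal sketch under simplifying assumptions: `mock_tensor`, `staged_source` and `mock_context` are hypothetical stand-ins, a plain `std::copy` replaces ggml_backend_tensor_copy_async, and there is no backend synchronization.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a 2D tensor: row-major floats, n_embd per row.
struct mock_tensor {
    int64_t n_embd = 0;
    int64_t n_rows = 0;
    std::vector<float> data;
};

// Plays the role of mtp_h_source_t, minus the source context pointer.
struct staged_source {
    const mock_tensor * src = nullptr;
    int32_t row_first = 0;
    int32_t n_rows    = 0; // 0 = nothing staged
};

struct mock_context {
    staged_source staging;

    // Relay call: only records where to read from.
    void stage(const mock_tensor * src, int32_t row_first, int32_t n_rows) {
        staging = { src, row_first, n_rows };
    }

    // Called once per decode, after the destination tensor for the current
    // graph exists. Returns the number of rows copied (0 if nothing staged).
    int32_t set_input(mock_tensor & dst) {
        if (!staging.src || staging.n_rows <= 0) {
            return 0;
        }
        const mock_tensor & src = *staging.src;
        assert(src.n_embd == dst.n_embd);
        assert(staging.row_first >= 0 && staging.row_first + staging.n_rows <= src.n_rows);
        assert(staging.n_rows <= dst.n_rows);

        const auto count = (std::ptrdiff_t) staging.n_rows * (std::ptrdiff_t) src.n_embd;
        const auto first = src.data.begin()
                         + (std::ptrdiff_t) staging.row_first * (std::ptrdiff_t) src.n_embd;
        std::copy(first, first + count, dst.data.begin());

        const int32_t copied = staging.n_rows;
        staging = {}; // consume, so a later decode does not replay stale rows
        return copied;
    }
};
```

Because the destination is passed in at `set_input` time, the relay does not care whether the graph was rebuilt between `stage` and the copy; a second `set_input` without a fresh `stage` is a no-op.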

src/llama-graph.cpp

Lines changed: 58 additions & 0 deletions
@@ -3,6 +3,7 @@
 #include "llama-impl.h"
 #include "llama-model.h"
 #include "llama-batch.h"
+#include "llama-context.h"
 #include "llama-cparams.h"
 
 #include "llama-kv-cache.h"
@@ -97,6 +98,63 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) {
     return res;
 }
 
+void llm_graph_input_h_pre_norm::set_input(const llama_ubatch * /*ubatch*/) {
+    // Read the staged source from the owning context. The relay function
+    // (llama_mtp_relay_h{,_self}) only RECORDS the source; the actual copy
+    // happens here, after the current graph has been built and `h` has stable
+    // identity. This is what lets us survive graph rebuilds between batched
+    // prefill (n_tokens=N) and single-token chain steps.
+    if (!ctx_mtp || !h) {
+        return;
+    }
+    auto staged = ctx_mtp->get_mtp_h_source();
+    if (!staged.src || staged.n_rows <= 0) {
+        return; // no relay staged for this decode
+    }
+
+    GGML_ASSERT(staged.src->ne[0] == h->ne[0] && "h embd dim mismatch");
+    GGML_ASSERT(staged.row_first >= 0 && staged.row_first + staged.n_rows <= staged.src->ne[1]);
+    GGML_ASSERT(staged.n_rows <= h->ne[1] && "staged n_rows exceeds h capacity");
+
+    // The source ctx may be ctx_mtp itself (self-relay during chained
+    // drafting) or a separate ctx_target. Sync it so its compute is done.
+    staged.ctx_src->synchronize();
+
+    // ggml_view_2d does not propagate the parent's backend buffer onto the
+    // view — it leaves view->buffer == NULL. Wire it manually before passing
+    // the views to copy_async / sched_get_tensor_backend.
+    const size_t row_size = staged.src->nb[1];
+    const size_t src_off = (size_t) staged.row_first * row_size;
+
+    ggml_init_params init = {
+        /*.mem_size =*/ ggml_tensor_overhead() * 2,
+        /*.mem_buffer =*/ nullptr,
+        /*.no_alloc =*/ true,
+    };
+    ggml_context_ptr view_ctx;
+    view_ctx.reset(ggml_init(init));
+    GGML_ASSERT(view_ctx);
+
+    ggml_tensor * src_view = ggml_view_2d(view_ctx.get(), staged.src,
+            staged.src->ne[0], staged.n_rows, row_size, src_off);
+    ggml_tensor * dst_view = ggml_view_2d(view_ctx.get(), h,
+            h->ne[0], staged.n_rows, h->nb[1], /*offset=*/ 0);
+    src_view->buffer = staged.src->buffer;
+    dst_view->buffer = h->buffer;
+
+    auto * sched_src = staged.ctx_src->get_sched();
+    auto * sched_dst = ctx_mtp->get_sched();
+    auto * backend_src = ggml_backend_sched_get_tensor_backend(sched_src, staged.src);
+    auto * backend_dst = ggml_backend_sched_get_tensor_backend(sched_dst, h);
+    GGML_ASSERT(backend_src && backend_dst && "MTP h relay: backend resolve failed");
+
+    ggml_backend_tensor_copy_async(backend_src, backend_dst, src_view, dst_view);
+
+    // Consume the staging so a subsequent decode without a fresh relay call
+    // doesn't accidentally re-copy stale rows.
+    ctx_mtp->clear_mtp_h_source();
+}
+
 void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;
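The two views built in set_input describe a strided row copy: source rows start at byte offset `row_first * nb[1]` and the destination rows start at offset 0, each side using its own row stride. A minimal host-memory sketch of that arithmetic follows. `copy_rows` is a hypothetical helper, and `std::memcpy` stands in for the backend copy; nothing here calls ggml.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Copies n_rows rows of row_bytes bytes each. Source rows begin at row_first
// and are src_nb1 bytes apart; destination rows begin at row 0 and are dst_nb1
// bytes apart. The strides play the role of ggml's nb[1].
void copy_rows(const uint8_t * src, size_t src_nb1,
               uint8_t * dst, size_t dst_nb1,
               size_t row_bytes, int32_t row_first, int32_t n_rows) {
    assert(row_first >= 0 && n_rows >= 0);
    assert(row_bytes <= src_nb1 && row_bytes <= dst_nb1);
    // Same arithmetic as src_off in the diff: first row index times row stride.
    const size_t src_off = (size_t) row_first * src_nb1;
    for (int32_t r = 0; r < n_rows; ++r) {
        std::memcpy(dst + (size_t) r * dst_nb1,
                    src + src_off + (size_t) r * src_nb1,
                    row_bytes);
    }
}
```

Keeping separate strides for the two sides matters whenever source and destination rows are padded differently, which is why the diff passes `row_size` for the source view and `h->nb[1]` for the destination view.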

src/llama-graph.h

Lines changed: 19 additions & 10 deletions
@@ -18,6 +18,7 @@ struct ggml_tensor;
 
 struct llama_cparams;
 struct llama_layer;
+struct llama_context;
 
 struct llama_memory_context_i;
 
@@ -120,23 +121,26 @@ class llm_graph_input_embd : public llm_graph_input_i {
     const int64_t n_embd = 0;
 };
 
-// Graph input for the trunk's pre-output-norm hidden state, passed into a
-// separate ctx_mtp from a target ctx via ggml_backend_tensor_copy_async on
-// each draft step. set_input is a no-op — the tensor data is populated
-// externally (the relay function), not from the ubatch.
+// Graph input for the trunk's pre-output-norm hidden state. Populated during
+// set_input by reading the staged source on the owning llama_context (set by
+// llama_mtp_relay_h{,_self}) and doing a device-to-device copy from those
+// rows into this->h. Going through set_input rather than an immediate copy in
+// the relay function is what lets the relay survive a graph rebuild between
+// different n_tokens (e.g. n_tokens=N prompt prefill → n_tokens=1 chain step):
+// the staging lives on the context, but the destination this->h is whatever
+// tensor the current graph just allocated.
 class llm_graph_input_h_pre_norm : public llm_graph_input_i {
 public:
-    llm_graph_input_h_pre_norm(int64_t n_embd) : n_embd(n_embd) {}
+    llm_graph_input_h_pre_norm(int64_t n_embd, llama_context * ctx_mtp)
+        : n_embd(n_embd), ctx_mtp(ctx_mtp) {}
     virtual ~llm_graph_input_h_pre_norm() = default;
 
-    void set_input(const llama_ubatch * /*ubatch*/) override {
-        // The h tensor is populated by the speculative wrapper before
-        // llama_decode via ggml_backend_tensor_copy_async. Nothing to do here.
-    }
+    void set_input(const llama_ubatch * ubatch) override;
 
     ggml_tensor * h = nullptr; // F32 [n_embd, n_batch]
 
-    const int64_t n_embd = 0;
+    const int64_t   n_embd  = 0;
+    llama_context * ctx_mtp = nullptr; // not owned; used to read staged source
 };
 
 class llm_graph_input_pos : public llm_graph_input_i {
@@ -559,6 +563,11 @@ struct llm_graph_params {
     ggml_backend_sched_t sched;
     ggml_backend_t backend_cpu;
 
+    // Owning context. Currently only consumed by llm_graph_input_h_pre_norm,
+    // which needs to read MTP h-staging state on every set_input. Not used by
+    // can_reuse / allow_reuse — same context across decodes by construction.
+    llama_context * ctx = nullptr;
+
     const llama_adapter_cvec * cvec;
     const llama_adapter_loras * loras;
     const llama_memory_context_i * mctx;
src/models/qwen35_mtp.cpp

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ llm_build_qwen35_mtp::llm_build_qwen35_mtp(const llama_model & model, const llm_
     // input buffer. Populated externally (no ubatch source) by the speculative
     // wrapper via ggml_backend_tensor_copy_async, using ctx_target's
     // t_h_pre_norm graph output as the source.
-    auto h_inp = std::make_unique<llm_graph_input_h_pre_norm>(hparams.n_embd);
+    auto h_inp = std::make_unique<llm_graph_input_h_pre_norm>(hparams.n_embd, params.ctx);
     h_inp->h = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, hparams.n_embd, n_tokens);
     ggml_set_input(h_inp->h);
     ggml_set_name(h_inp->h, "mtp_h_input");