Skip to content

Commit 183a99c

Browse files
committed
mtp: copy correct row of t_h_pre_norm based on prior n_accepted
Real bug fix. Previously llama_mtp_relay_h copied the LAST row of ctx_target's t_h_pre_norm into ctx_mtp's t_inp_h. That is only correct when the verifier accepts ALL drafts in the previous round; on partial acceptance, the row whose hidden produced the next id_last is row n_accepted, not the last row. For a verify batch [sampled, d0, ..., d_{K-1}] at positions [p..p+K]: - bonus = verifier's sample at row n_accepted (rejected position, or the last row if all K drafts accepted) - next id_last lives at position p + n_accepted + 1 - MTP needs h at position p + n_accepted = ROW n_accepted of t_h_pre_norm The bug was invisible at K=1 in canonical paths (most rounds full- accept → row K-1 = last row = correct) but degraded acceptance whenever a draft was rejected. At K>=2, partial-accept dominates and MTP cascades on wrong h, collapsing acceptance to ~30%. Changes: - llama_mtp_relay_h signature: int32_t n_rows → int32_t src_row. Copies a single row at the specified index from src into row 0 of dst. Caller picks the row. - llama_mtp_relay_h_self unchanged in semantics — t_mtp_out has only the one row produced by the previous chain step's single-token decode. - common_speculative_state_mtp: track last_n_accepted (set by accept(), consumed by next draft()'s k=0 relay). begin() resets it to -1, which the relay maps to row 0 (only the prompt's last position is in the trunk's outputs after prefill). Measured on Qwen3.6-q8_0-mtp.gguf, send_req.sh (dense Python code, 400 tokens, temp=1, seed=42): before fix after fix K=1 84% accept, 11.4 tok/s 88% accept, 12.5 tok/s K=2 30% accept, 9.5 tok/s 86% accept, 16.9 tok/s (+78%) K=3 not viable 73% accept, 17.5 tok/s K=2 now matches vLLM's documented sweet spot for Qwen3.6 / DeepSeek MTP on code workloads. K=3 is a marginal win on top. Architecture confirmation: an independent walk of vLLM's chain code (SpecDecodeBaseProposer.propose, qwen3_5_mtp.forward) confirms vLLM's K>1 chain is a pure self-roll on the MTP block's post-residual hidden with hnorm reapplied each step — the same mechanism this codebase already implements; the only delta vs vLLM was the row-selection bug fixed here.
1 parent 17d47df commit 183a99c

3 files changed

Lines changed: 60 additions & 25 deletions

File tree

common/speculative.cpp

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,14 @@ struct common_speculative_state_mtp : public common_speculative_state {
634634
// uses this to compute how many trailing positions to roll back.
635635
uint16_t last_n_drafted = 0;
636636

637+
// # of drafts the verifier accepted on the most recent round. Used by the
638+
// NEXT draft()'s k=0 relay: the row of ctx_tgt's t_h_pre_norm whose hidden
639+
// produced the new id_last is exactly `last_n_accepted` (the bonus came
640+
// from that row's logits). Using the last row instead silently corrupts
641+
// MTP whenever last_n_accepted < n_drafts. -1 = first draft after begin(),
642+
// where ctx_tgt's t_h_pre_norm has only the prompt's last-position row.
643+
int32_t last_n_accepted = -1;
644+
637645
common_speculative_state_mtp(enum common_speculative_type type,
638646
llama_context * ctx_tgt,
639647
llama_context * ctx_mtp)
@@ -696,6 +704,7 @@ struct common_speculative_state_mtp : public common_speculative_state {
696704
LOG_WRN("%s: ctx_mtp seed decode rc=%d\n", __func__, rc);
697705
}
698706
mtp_pos = 1;
707+
last_n_accepted = -1; // signal "first draft of this generation"
699708

700709
GGML_UNUSED(prompt);
701710
}
@@ -717,13 +726,21 @@ struct common_speculative_state_mtp : public common_speculative_state {
717726
llama_token cond_tok = id_last;
718727

719728
for (int32_t k = 0; k < n_max; ++k) {
720-
// Stage h. Step 0: from ctx_tgt's t_h_pre_norm (the h corresponding
721-
// to id_last's position in the trunk). Step k>0: self-relay from
722-
// ctx_mtp's previous t_mtp_out (the MTP block's post-FFN h from
723-
// the previous chain step).
724-
const int32_t rc_relay = (k == 0)
725-
? llama_mtp_relay_h(ctx_tgt, ctx_mtp, /*n_rows=*/ 1)
726-
: llama_mtp_relay_h_self(ctx_mtp, /*n_rows=*/ 1);
729+
// Stage h. Step 0: from ctx_tgt's t_h_pre_norm at the row whose
730+
// hidden produced id_last. After a previous verify [sampled, d0,
731+
// ..., d_{K-1}] with `last_n_accepted` drafts accepted, the bonus
732+
// (= new id_last) was sampled from h at row `last_n_accepted`
733+
// (rows 0..K of t_h_pre_norm correspond to those K+1 positions).
734+
// For the very first draft of a generation (last_n_accepted=-1)
735+
// ctx_tgt only computed the prompt's last position → row 0.
736+
// Step k>0: self-relay from ctx_mtp's previous t_mtp_out.
737+
int32_t rc_relay;
738+
if (k == 0) {
739+
const int32_t src_row = (last_n_accepted < 0) ? 0 : last_n_accepted;
740+
rc_relay = llama_mtp_relay_h(ctx_tgt, ctx_mtp, src_row);
741+
} else {
742+
rc_relay = llama_mtp_relay_h_self(ctx_mtp, /*n_rows=*/ 1);
743+
}
727744
if (rc_relay != 0) {
728745
LOG_DBG("%s: relay rc=%d at k=%d; stopping chain\n", __func__, rc_relay, k);
729746
return;
@@ -774,6 +791,7 @@ struct common_speculative_state_mtp : public common_speculative_state {
774791
// right slots.
775792
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
776793
if (pos_max < 0) {
794+
last_n_accepted = (int32_t) n_accepted;
777795
return;
778796
}
779797
const int32_t n_drafted_last = (int32_t) last_n_drafted;
@@ -784,6 +802,9 @@ struct common_speculative_state_mtp : public common_speculative_state {
784802
/*p0=*/ drop_from, /*p1=*/ -1);
785803
}
786804
last_n_drafted = 0;
805+
// Record so the NEXT draft()'s k=0 relay knows which row of ctx_tgt's
806+
// t_h_pre_norm to copy.
807+
last_n_accepted = (int32_t) n_accepted;
787808
}
788809

789810
int32_t n_max(const common_params_speculative & params) const override {

include/llama.h

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -988,18 +988,24 @@ extern "C" {
988988
// hidden state plus a token batch to produce draft logits, with its own KV
989989
// cache populated by build_attn the same way any other layer's is.
990990
//
991-
// Copies the LAST n_rows of ctx_target's t_h_pre_norm into the FIRST n_rows
992-
// of ctx_mtp's t_inp_h. Typical use: n_rows=1 to feed a single-token MTP
993-
// draft step with the most recently produced h row. Both backends must be
994-
// able to issue a copy between each other (typical case: same device, fast
995-
// on-device copy).
991+
// Copies a single row at index `src_row` of ctx_target's t_h_pre_norm into
992+
// row 0 of ctx_mtp's t_inp_h. Both backends must be able to issue a copy
993+
// between each other (typical case: same device, fast on-device copy).
994+
//
995+
// The right `src_row` for MTP drafting is the row whose hidden produced the
996+
// verifier sample that becomes the next draft's id_last. After a verify
997+
// batch [sampled, d0, ..., d_{K-1}] with `n_accepted` drafts accepted, that
998+
// is `src_row = n_accepted` (the bonus token was sampled from h at row
999+
// n_accepted). Using the last row instead silently corrupts MTP whenever
1000+
// n_accepted < K; the bug is invisible at K=1 most of the time but tanks
1001+
// K>=2.
9961002
//
9971003
// Returns 0 on success; negative on error (e.g. ctx_target's last decode
998-
// didn't produce t_h_pre_norm, n_rows out of range, shape mismatch).
1004+
// didn't produce t_h_pre_norm, src_row out of range, shape mismatch).
9991005
LLAMA_API int32_t llama_mtp_relay_h(
10001006
struct llama_context * ctx_target,
10011007
struct llama_context * ctx_mtp,
1002-
int32_t n_rows);
1008+
int32_t src_row);
10031009

10041010
// Self-relay: copy the LAST n_rows of ctx_mtp's most recent t_mtp_out
10051011
// (the MTP block's post-FFN hidden) into the FIRST n_rows of its own

src/llama-context.cpp

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3118,15 +3118,15 @@ ggml_tensor * llama_context::get_t_inp_h() const {
31183118
return nullptr;
31193119
}
31203120

3121-
// Common implementation: copy the LAST n_rows of `src` into the FIRST n_rows
3121+
// Common implementation: copy a single row at `src_row` of `src` into row 0
31223122
// of `dst`, on-device via ggml_backend_tensor_copy_async. ctx_src/ctx_dst are
31233123
// used to look up backends per tensor and to synchronize the source.
31243124
static int32_t llama_mtp_relay_impl(
31253125
struct llama_context * ctx_src,
31263126
struct llama_context * ctx_dst,
31273127
ggml_tensor * src,
31283128
ggml_tensor * dst,
3129-
int32_t n_rows,
3129+
int32_t src_row,
31303130
const char * fn) {
31313131
if (!src) {
31323132
LLAMA_LOG_ERROR("%s: src tensor missing\n", fn);
@@ -3141,9 +3141,9 @@ static int32_t llama_mtp_relay_impl(
31413141
fn, src->ne[0], dst->ne[0]);
31423142
return -4;
31433143
}
3144-
if (n_rows <= 0 || n_rows > src->ne[1] || n_rows > dst->ne[1]) {
3145-
LLAMA_LOG_ERROR("%s: n_rows=%d out of range (src cap=%" PRId64 ", dst cap=%" PRId64 ")\n",
3146-
fn, n_rows, src->ne[1], dst->ne[1]);
3144+
if (src_row < 0 || src_row >= src->ne[1] || dst->ne[1] < 1) {
3145+
LLAMA_LOG_ERROR("%s: src_row=%d out of range (src cap=%" PRId64 ", dst cap=%" PRId64 ")\n",
3146+
fn, src_row, src->ne[1], dst->ne[1]);
31473147
return -5;
31483148
}
31493149

@@ -3155,8 +3155,8 @@ static int32_t llama_mtp_relay_impl(
31553155
// only forwards view->data + offset), so wire the buffer manually before
31563156
// passing the views to copy_async.
31573157
const size_t row_size = src->nb[1];
3158-
const int32_t src_first = (int32_t) src->ne[1] - n_rows; // last n_rows of src
3159-
const size_t src_offset = (size_t) src_first * row_size;
3158+
const int32_t n_rows = 1;
3159+
const size_t src_offset = (size_t) src_row * row_size;
31603160

31613161
ggml_context_ptr view_ctx;
31623162
{
@@ -3195,14 +3195,14 @@ static int32_t llama_mtp_relay_impl(
31953195
int32_t llama_mtp_relay_h(
31963196
struct llama_context * ctx_target,
31973197
struct llama_context * ctx_mtp,
3198-
int32_t n_rows) {
3198+
int32_t src_row) {
31993199
if (!ctx_target || !ctx_mtp) {
32003200
return -1;
32013201
}
32023202
return llama_mtp_relay_impl(ctx_target, ctx_mtp,
32033203
ctx_target->get_t_h_pre_norm(),
32043204
ctx_mtp->get_t_inp_h(),
3205-
n_rows, __func__);
3205+
src_row, __func__);
32063206
}
32073207

32083208
int32_t llama_mtp_relay_h_self(
@@ -3211,10 +3211,18 @@ int32_t llama_mtp_relay_h_self(
32113211
if (!ctx_mtp) {
32123212
return -1;
32133213
}
3214+
// Self-relay: t_mtp_out has shape [n_embd, n_tokens] from the previous
3215+
// single-token decode, so n_tokens=1 and the only row is 0.
3216+
GGML_UNUSED(n_rows);
3217+
ggml_tensor * src = ctx_mtp->get_t_mtp_out();
3218+
if (!src) {
3219+
return -2;
3220+
}
3221+
const int32_t src_row = (int32_t) src->ne[1] - 1;
32143222
return llama_mtp_relay_impl(ctx_mtp, ctx_mtp,
3215-
ctx_mtp->get_t_mtp_out(),
3223+
src,
32163224
ctx_mtp->get_t_inp_h(),
3217-
n_rows, __func__);
3225+
src_row, __func__);
32183226
}
32193227

32203228
void llama_synchronize(llama_context * ctx) {

0 commit comments

Comments
 (0)