@@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
905905
906906 int32_t n_embd = 0 ;
907907
908- bool is_mem_shared = false ;
908+ // One MTP draft driver, three modes (set once in the ctor):
909+ // is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
910+ // chain_heads (step35): n_mtp_layers trained heads, one per draft step.
911+ // neither (qwen35 / qwen35moe): a single trained MTP head.
912+ int32_t n_mtp_layers = 1 ;
913+ bool is_mem_shared = false ; // gemma4
914+ bool chain_heads = false ; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
909915
910916 // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
911917 // The last h-row of one process() call needs the first token of the NEXT
@@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
920926 std::vector<std::vector<float >> verify_h;
921927 std::vector<int32_t > verify_h_rows;
922928
923- // Per-seq draft length from the last draft() call, used in accept() to
924- // roll back ctx_dft's recurrent state past the AR draft's redundant
925- // pre-advancement before process() mirrored the verify batch.
926- std::vector<uint16_t > last_n_drafted;
929+ std::vector<int > i_last;
930+ std::vector<std::vector<float >> chain_h;
927931
928932 common_speculative_impl_draft_mtp (const common_params_speculative & params, uint32_t n_seq)
929933 : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP , n_seq)
@@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
936940 n_embd = llama_model_n_embd_out (llama_get_model (ctx_dft));
937941 GGML_ASSERT (n_embd == llama_model_n_embd (llama_get_model (ctx_tgt)) &&
938942 " MTP input row width must match the target h_nextn width" );
943+ n_mtp_layers = std::max (1 , (int ) llama_model_n_layer_nextn (llama_get_model (ctx_dft)));
939944
940945 LOG_INF (" %s: adding speculative implementation 'draft-mtp'\n " , __func__);
941946 LOG_INF (" %s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n " , __func__, this ->params .n_max , this ->params .n_min , this ->params .p_min , n_embd, (int ) this ->params .backend_sampling );
@@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
982987 llama_set_embeddings_nextn (ctx_dft, true , /* masked*/ true );
983988
984989 is_mem_shared = llama_get_ctx_other (ctx_dft) == ctx_tgt;
990+ chain_heads = n_mtp_layers > 1 && !is_mem_shared;
991+
992+ if (chain_heads) {
993+ this ->params .n_max = std::min (this ->params .n_max , n_mtp_layers);
994+
995+ chain_h.assign (n_seq, {});
996+ for (auto & c : chain_h) {
997+ c.reserve ((size_t ) (this ->params .n_max + 1 ) * n_embd);
998+ }
999+ }
9851000
9861001 pending_h.assign (n_seq, std::vector<float >(n_embd, 0 .0f ));
9871002
1003+ i_last.assign (n_seq, -1 );
9881004 i_batch_beg.assign (n_seq, -1 );
9891005 i_batch_end.assign (n_seq, -1 );
9901006
9911007 verify_h.assign (n_seq, {});
9921008 verify_h_rows.assign (n_seq, 0 );
993-
994- last_n_drafted.assign (n_seq, 0 );
9951009 }
9961010
9971011 ~common_speculative_impl_draft_mtp () override {
@@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
10971111 set_h (i_batch_beg[seq_id], pending_h[seq_id].data ());
10981112 }
10991113
1100- const int32_t rc = llama_decode (ctx_dft, batch);
1101- if (rc != 0 ) {
1102- LOG_ERR (" %s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n " , __func__, (int ) rc, (int ) batch_in.pos [0 ]);
1114+ auto * mem_dft = llama_get_memory (ctx_dft);
1115+
1116+ bool ok = true ;
1117+ for (int head = 0 ; head < n_mtp_layers; ++head) {
1118+ if (chain_heads) {
1119+ // ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
1120+ for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1121+ if (i_batch_beg[seq_id] < 0 ) {
1122+ continue ;
1123+ }
1124+ llama_memory_seq_rm (mem_dft, seq_id, batch_in.pos [i_batch_beg[seq_id]], -1 );
1125+ }
1126+ llama_set_nextn_layer_offset (ctx_dft, head);
1127+ }
1128+
1129+ const int32_t rc = llama_decode (ctx_dft, batch);
1130+ if (rc != 0 ) {
1131+ LOG_ERR (" %s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n " ,
1132+ __func__, head, (int ) rc, (int ) batch_in.pos [0 ]);
1133+ ok = false ;
1134+ break ;
1135+ }
1136+ }
1137+
1138+ if (chain_heads) {
1139+ llama_set_nextn_layer_offset (ctx_dft, 0 ); // restore default for non-draft decodes
1140+ }
1141+ if (!ok) {
11031142 return false ;
11041143 }
11051144 }
@@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11341173 int n_drafting = 0 ;
11351174 std::vector<bool > drafting (n_seq);
11361175
1137- const float * h_row = nullptr ;
11381176 const size_t row_bytes = (size_t ) n_embd * sizeof (float );
11391177
11401178 for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11491187 common_sampler_reset (smpls[seq_id].get ());
11501188
11511189 common_batch_add (batch, dp.id_last , dp.n_past , { seq_id }, true );
1190+ std::memcpy (batch.embd + (size_t ) (batch.n_tokens - 1 ) * n_embd, pending_h[seq_id].data (), row_bytes);
11521191
1153- h_row = pending_h[seq_id].data ();
1154- std::memcpy (batch.embd + n_embd*(batch.n_tokens - 1 ), h_row, row_bytes);
1155- }
1192+ i_last[seq_id] = batch.n_tokens - 1 ;
11561193
1157- int ret = llama_decode (ctx_dft, batch);
1158- if (ret != 0 ) {
1159- LOG_WRN (" %s: llama_decode returned %d\n " , __func__, ret);
1160- return ;
1194+ if (chain_heads) {
1195+ chain_h[seq_id].assign (pending_h[seq_id].begin (), pending_h[seq_id].end ());
1196+ }
11611197 }
11621198
11631199 int i = 0 ;
11641200
11651201 while (n_drafting > 0 ) {
1166- int i_batch = 0 ;
1202+ // each step decodes under a different head, i.e. a different decoder layer, and
1203+ // KV is per layer. process() filled this layer's KV only for positions < n_past
1204+ // (prompt + accepted prefix) — nothing in the draft region yet. so reset the
1205+ // draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
1206+ // and select head i so it rebuilds its own layer's KV there; decoding just the
1207+ // latest token would leave its attention reading cells only another head wrote.
1208+ if (chain_heads) {
1209+ auto * mem_dft = llama_get_memory (ctx_dft);
1210+ for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1211+ if (drafting[seq_id]) {
1212+ llama_memory_seq_rm (mem_dft, seq_id, dparams[seq_id].n_past , -1 );
1213+ }
1214+ }
1215+ llama_set_nextn_layer_offset (ctx_dft, i);
1216+ }
11671217
1218+ int ret = llama_decode (ctx_dft, batch);
1219+ if (ret != 0 ) {
1220+ LOG_WRN (" %s: llama_decode[%d] returned %d\n " , __func__, i, ret);
1221+ break ;
1222+ }
1223+
1224+ // rebuild the batch for the next step: the growing-KV paths re-add only the
1225+ // new token (the KV already holds the prefix), while chained heads re-add the
1226+ // whole prefix at the next head. dropped sequences are simply not re-added.
11681227 common_batch_clear (batch);
11691228
11701229 for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11741233
11751234 auto * smpl = smpls[seq_id].get ();
11761235
1177- common_sampler_sample (smpl, ctx_dft, i_batch, true );
1178- h_row = llama_get_embeddings_nextn_ith (ctx_dft, i_batch);
1179- ++i_batch;
1236+ common_sampler_sample (smpl, ctx_dft, i_last[seq_id], true );
1237+ const float * h_row = llama_get_embeddings_nextn_ith (ctx_dft, i_last[seq_id]);
11801238
11811239 const auto * cur_p = common_sampler_get_candidates (smpl, true );
11821240
@@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
12101268 continue ;
12111269 }
12121270
1213- if (is_mem_shared) {
1271+ if (chain_heads) {
1272+ // ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
1273+ chain_h[seq_id].insert (chain_h[seq_id].end (), h_row, h_row + n_embd);
1274+
1275+ const int n_rows = (int ) result.size () + 1 ; // id_last + tokens drafted so far
1276+ for (int t = 0 ; t < n_rows; ++t) {
1277+ const llama_token tok = (t == 0 ) ? dp.id_last : result[t - 1 ];
1278+ common_batch_add (batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1 );
1279+ std::memcpy (batch.embd + (size_t ) (batch.n_tokens - 1 ) * n_embd,
1280+ chain_h[seq_id].data () + (size_t ) t * n_embd, row_bytes);
1281+ }
1282+ } else if (is_mem_shared) {
12141283 // note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
12151284 // ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
12161285 common_batch_add (batch, id, dp.n_past , { seq_id }, true );
1286+ std::memcpy (batch.embd + (size_t ) (batch.n_tokens - 1 ) * n_embd, h_row, row_bytes);
12171287 } else {
12181288 common_batch_add (batch, id, dp.n_past + i + 1 , { seq_id }, true );
1289+ std::memcpy (batch.embd + (size_t ) (batch.n_tokens - 1 ) * n_embd, h_row, row_bytes);
12191290 }
1220- std::memcpy (batch.embd + n_embd*(batch.n_tokens - 1 ), h_row, row_bytes);
1221- }
12221291
1223- if (batch.n_tokens == 0 ) {
1224- break ;
1292+ i_last[seq_id] = batch.n_tokens - 1 ;
12251293 }
12261294
1227- // evaluate the drafted tokens on the draft model
1228- ret = llama_decode (ctx_dft, batch);
1229- if (ret != 0 ) {
1230- LOG_WRN (" %s: llama_decode[%d] returned %d\n " , __func__, i, ret);
1295+ if (batch.n_tokens == 0 ) {
12311296 break ;
12321297 }
12331298
12341299 ++i;
12351300 }
12361301
1302+ if (chain_heads) {
1303+ llama_set_nextn_layer_offset (ctx_dft, 0 ); // restore default for non-draft decodes
1304+ }
1305+
12371306 for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
12381307 auto & dp = dparams[seq_id];
12391308 if (!dp.drafting ) {
@@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
12431312 if (dp.result ->size () < (size_t ) params.n_min ) {
12441313 dp.result ->clear ();
12451314 }
1246-
1247- last_n_drafted[seq_id] = (uint16_t ) dp.result ->size ();
12481315 }
12491316 }
12501317
@@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
18571924
18581925 bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE ));
18591926 bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 )) && params.draft .ctx_dft != nullptr ;
1860- bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP )) && params.draft .ctx_dft != nullptr ;
1927+ bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP )) && params.draft .ctx_dft != nullptr ;
18611928
18621929
18631930
@@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
18951962 if (has_draft_eagle3) {
18961963 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3 , params));
18971964 }
1898- if (has_mtp ) {
1965+ if (has_draft_mtp ) {
18991966 configs.push_back (common_speculative_config (COMMON_SPECULATIVE_TYPE_DRAFT_MTP , params));
19001967 }
19011968 }
0 commit comments