@@ -846,11 +846,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
846846 std::vector<std::vector<float >> verify_h;
847847 std::vector<int32_t > verify_h_rows;
848848
849- // Per-seq draft length from the last draft() call, used in accept() to
850- // roll back ctx_dft's recurrent state past the AR draft's redundant
851- // pre-advancement before process() mirrored the verify batch.
852- std::vector<uint16_t > last_n_drafted;
853-
854849 common_speculative_impl_draft_mtp (const common_params_speculative & params, uint32_t n_seq)
855850 : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP , n_seq)
856851 , params(params.draft)
@@ -925,8 +920,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
925920
926921 verify_h.assign (n_seq, {});
927922 verify_h_rows.assign (n_seq, 0 );
928-
929- last_n_drafted.assign (n_seq, 0 );
930923 }
931924
932925 ~common_speculative_impl_draft_mtp () override {
@@ -1093,11 +1086,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
10931086 void draft (common_speculative_draft_params_vec & dparams) override {
10941087 auto & ctx_dft = params.ctx_dft ;
10951088
1096- if (chain_heads) {
1097- draft_multi_head (dparams);
1098- return ;
1099- }
1100-
11011089 common_batch_clear (batch);
11021090
11031091 // keep track of which sequences are still drafting
@@ -1107,6 +1095,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11071095 const float * h_row = nullptr ;
11081096 const size_t row_bytes = (size_t ) n_embd * sizeof (float );
11091097
1098+ // chained heads accumulate the prefix into the batch across draft
1099+ // steps instead of rebuilding it, so each sequence samples from its last
1100+ // appended slot; i_last tracks that slot.
1101+ std::vector<int > i_last (n_seq, -1 );
1102+
11101103 for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
11111104 auto & dp = dparams[seq_id];
11121105
@@ -1122,20 +1115,38 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11221115
11231116 h_row = pending_h[seq_id].data ();
11241117 std::memcpy (batch.embd + n_embd*(batch.n_tokens - 1 ), h_row, row_bytes);
1125- }
11261118
1127- int ret = llama_decode (ctx_dft, batch);
1128- if (ret != 0 ) {
1129- LOG_WRN (" %s: llama_decode returned %d\n " , __func__, ret);
1130- return ;
1119+ i_last[seq_id] = batch.n_tokens - 1 ;
11311120 }
11321121
11331122 int i = 0 ;
11341123
11351124 while (n_drafting > 0 ) {
1136- int i_batch = 0 ;
1125+ // chained heads: every trained head keeps its own KV, so reset
1126+ // each sequence's draft region and select head `i` before re-decoding the
1127+ // accumulated prefix (head 0 sees just id_last). the single-head (qwen)
1128+ // and mem-shared (gemma4) paths leave their growing KV untouched.
1129+ if (chain_heads) {
1130+ auto * mem_dft = llama_get_memory (ctx_dft);
1131+ for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1132+ if (drafting[seq_id]) {
1133+ llama_memory_seq_rm (mem_dft, seq_id, dparams[seq_id].n_past , -1 );
1134+ }
1135+ }
1136+ llama_set_mtp_layer_offset (ctx_dft, i);
1137+ }
11371138
1138- common_batch_clear (batch);
1139+ int ret = llama_decode (ctx_dft, batch);
1140+ if (ret != 0 ) {
1141+ LOG_WRN (" %s: llama_decode[%d] returned %d\n " , __func__, i, ret);
1142+ break ;
1143+ }
1144+
1145+ // the growing-KV paths rebuild the batch with only the new tokens (the KV
1146+ // already holds the prefix); chained heads keep accumulating into it
1147+ if (!chain_heads) {
1148+ common_batch_clear (batch);
1149+ }
11391150
11401151 for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
11411152 if (!drafting[seq_id]) {
@@ -1144,9 +1155,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11441155
11451156 auto * smpl = smpls[seq_id].get ();
11461157
1147- common_sampler_sample (smpl, ctx_dft, i_batch, true );
1148- h_row = llama_get_embeddings_nextn_ith (ctx_dft, i_batch);
1149- ++i_batch;
1158+ common_sampler_sample (smpl, ctx_dft, i_last[seq_id], true );
1159+ h_row = llama_get_embeddings_nextn_ith (ctx_dft, i_last[seq_id]);
11501160
11511161 const auto * cur_p = common_sampler_get_candidates (smpl, true );
11521162
@@ -1180,6 +1190,12 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11801190 continue ;
11811191 }
11821192
1193+ // chained heads: only the newly appended slot keeps its logits, so the
1194+ // slot we just sampled stops producing an output on the next head
1195+ if (chain_heads) {
1196+ batch.logits [i_last[seq_id]] = false ;
1197+ }
1198+
11831199 if (is_mem_shared) {
11841200 // note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
11851201 // ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
@@ -1188,154 +1204,27 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11881204 common_batch_add (batch, id, dp.n_past + i + 1 , { seq_id }, true );
11891205 }
11901206 std::memcpy (batch.embd + n_embd*(batch.n_tokens - 1 ), h_row, row_bytes);
1191- }
11921207
1193- if (batch.n_tokens == 0 ) {
1194- break ;
1208+ i_last[seq_id] = batch.n_tokens - 1 ;
11951209 }
11961210
1197- // evaluate the drafted tokens on the draft model
1198- ret = llama_decode (ctx_dft, batch);
1199- if (ret != 0 ) {
1200- LOG_WRN (" %s: llama_decode[%d] returned %d\n " , __func__, i, ret);
1211+ if (batch.n_tokens == 0 ) {
12011212 break ;
12021213 }
12031214
12041215 ++i;
12051216 }
12061217
1207- for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1208- auto & dp = dparams[seq_id];
1209- if (!dp.drafting ) {
1210- continue ;
1211- }
1212-
1213- if (dp.result ->size () < (size_t ) params.n_min ) {
1214- dp.result ->clear ();
1215- }
1216-
1217- last_n_drafted[seq_id] = (uint16_t ) dp.result ->size ();
1218+ if (chain_heads) {
1219+ llama_set_mtp_layer_offset (ctx_dft, 0 ); // restore default for non-draft decodes
12181220 }
1219- }
1220-
1221- // Multi-head MTP draft: chain heads 0..n_mtp_layers-1. Step s runs head s on
1222- // the accumulated prefix [id_last, draft_1, .., draft_s] and samples draft_{s+1}.
1223- // Each slot's embd is the hidden produced by the PREVIOUS head for that token
1224- // (slot 0 is always pending_h = trunk h). Per-step seq_rm keeps each head's KV
1225- // on a clean, position-aligned slot set.
1226- void draft_multi_head (common_speculative_draft_params_vec & dparams) {
1227- auto * ctx_dft = params.ctx_dft ;
1228- auto * mem_dft = llama_get_memory (ctx_dft);
1229- const size_t row_bytes = (size_t ) n_embd * sizeof (float );
12301221
1231- // Per-seq accumulated draft state. Positions are implicit: slot k always
1232- // sits at round_start + k (the prefix is contiguous from id_last), so we
1233- // don't store a parallel positions vector.
1234- struct seq_state {
1235- bool active = false ;
1236- llama_pos round_start = 0 ; // pos of id_last (slot 0)
1237- std::vector<llama_token> toks; // [id_last, draft_1, ...]
1238- std::vector<std::vector<float >> slot_h; // embd per slot (toks[k] <-> slot_h[k])
1239- };
1240- std::vector<seq_state> st (n_seq);
1241-
1242- int n_drafting = 0 ;
12431222 for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
12441223 auto & dp = dparams[seq_id];
12451224 if (!dp.drafting ) {
12461225 continue ;
12471226 }
1248- common_sampler_reset (smpls[seq_id].get ());
1249- st[seq_id].active = true ;
1250- st[seq_id].round_start = dp.n_past ;
1251- st[seq_id].toks .push_back (dp.id_last );
1252- st[seq_id].slot_h .push_back (pending_h[seq_id]); // slot 0 = trunk pending_h
1253- n_drafting++;
1254- }
1255- if (n_drafting == 0 ) {
1256- return ;
1257- }
1258-
1259- const int n_steps = n_mtp_layers; // one head per step, capped at head count
1260-
1261- for (int step = 0 ; step < n_steps && n_drafting > 0 ; ++step) {
1262- // 1) per-seq KV reset for this head + select the head's layer.
1263- for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1264- if (!st[seq_id].active ) {
1265- continue ;
1266- }
1267- llama_memory_seq_rm (mem_dft, seq_id, st[seq_id].round_start , -1 );
1268- }
1269- llama_set_mtp_layer_offset (ctx_dft, step);
1270-
1271- // 2) build accumulated batch; logits only on each seq's last slot.
1272- common_batch_clear (batch);
1273- std::vector<int > last_i_batch (n_seq, -1 );
1274- for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1275- if (!st[seq_id].active ) {
1276- continue ;
1277- }
1278- auto & s = st[seq_id];
1279- const int len = (int ) s.toks .size ();
1280- for (int k = 0 ; k < len; ++k) {
1281- const bool is_last = (k == len - 1 );
1282- common_batch_add (batch, s.toks [k], s.round_start + k, { seq_id }, is_last);
1283- std::memcpy (batch.embd + (size_t ) (batch.n_tokens - 1 ) * n_embd,
1284- s.slot_h [k].data (), row_bytes);
1285- if (is_last) last_i_batch[seq_id] = batch.n_tokens - 1 ;
1286- }
1287- }
1288-
1289- const int rc = llama_decode (ctx_dft, batch);
1290- if (rc != 0 ) {
1291- LOG_WRN (" %s: multi-head draft decode step=%d rc=%d\n " , __func__, step, rc);
1292- break ;
1293- }
1294-
1295- // 3) sample next draft token per active seq; extend prefix + slot_h.
1296- for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1297- if (!st[seq_id].active ) {
1298- continue ;
1299- }
1300- auto & s = st[seq_id];
1301- auto & dp = dparams[seq_id];
1302- const int ib = last_i_batch[seq_id];
1303-
1304- auto * smpl = smpls[seq_id].get ();
1305- common_sampler_sample (smpl, ctx_dft, ib, true );
1306- const float * h_row = llama_get_embeddings_nextn_ith (ctx_dft, ib);
1307- const auto * cur_p = common_sampler_get_candidates (smpl, true );
1308- const llama_token id = cur_p->data [0 ].id ;
1309-
1310- if (cur_p->data [0 ].p < params.p_min ) {
1311- s.active = false ;
1312- n_drafting--;
1313- continue ;
1314- }
1315-
1316- common_sampler_accept (smpl, id, true );
1317- dp.result ->push_back (id);
1318-
1319- if ((int ) dp.result ->size () >= params.n_max ) {
1320- s.active = false ;
1321- n_drafting--;
1322- continue ;
1323- }
1324-
1325- // next slot's token + its embd (= this head's output at last slot).
1326- // Its position is implicit: round_start + toks.size().
1327- s.toks .push_back (id);
1328- s.slot_h .emplace_back (h_row, h_row + n_embd);
1329- }
1330- }
13311227
1332- llama_set_mtp_layer_offset (ctx_dft, 0 ); // restore default
1333-
1334- for (llama_seq_id seq_id = 0 ; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1335- auto & dp = dparams[seq_id];
1336- if (!dp.drafting ) {
1337- continue ;
1338- }
13391228 if (dp.result ->size () < (size_t ) params.n_min ) {
13401229 dp.result ->clear ();
13411230 }
0 commit comments