Skip to content

Commit 2952d83

Browse files
committed
merged draft_multi_head into draft()
1 parent 9a0ff26 commit 2952d83

1 file changed

Lines changed: 42 additions & 153 deletions

File tree

common/speculative.cpp

Lines changed: 42 additions & 153 deletions
Original file line numberDiff line numberDiff line change
@@ -846,11 +846,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
846846
std::vector<std::vector<float>> verify_h;
847847
std::vector<int32_t> verify_h_rows;
848848

849-
// Per-seq draft length from the last draft() call, used in accept() to
850-
// roll back ctx_dft's recurrent state past the AR draft's redundant
851-
// pre-advancement before process() mirrored the verify batch.
852-
std::vector<uint16_t> last_n_drafted;
853-
854849
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
855850
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
856851
, params(params.draft)
@@ -925,8 +920,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
925920

926921
verify_h.assign(n_seq, {});
927922
verify_h_rows.assign(n_seq, 0);
928-
929-
last_n_drafted.assign(n_seq, 0);
930923
}
931924

932925
~common_speculative_impl_draft_mtp() override {
@@ -1093,11 +1086,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
10931086
void draft(common_speculative_draft_params_vec & dparams) override {
10941087
auto & ctx_dft = params.ctx_dft;
10951088

1096-
if (chain_heads) {
1097-
draft_multi_head(dparams);
1098-
return;
1099-
}
1100-
11011089
common_batch_clear(batch);
11021090

11031091
// keep track of which sequences are still drafting
@@ -1107,6 +1095,11 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11071095
const float * h_row = nullptr;
11081096
const size_t row_bytes = (size_t) n_embd * sizeof(float);
11091097

1098+
// chained heads accumulate the prefix into the batch across draft
1099+
// steps instead of rebuilding it, so each sequence samples from its last
1100+
// appended slot; i_last tracks that slot.
1101+
std::vector<int> i_last(n_seq, -1);
1102+
11101103
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
11111104
auto & dp = dparams[seq_id];
11121105

@@ -1122,20 +1115,38 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11221115

11231116
h_row = pending_h[seq_id].data();
11241117
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
1125-
}
11261118

1127-
int ret = llama_decode(ctx_dft, batch);
1128-
if (ret != 0) {
1129-
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
1130-
return;
1119+
i_last[seq_id] = batch.n_tokens - 1;
11311120
}
11321121

11331122
int i = 0;
11341123

11351124
while (n_drafting > 0) {
1136-
int i_batch = 0;
1125+
// chained heads: every trained head keeps its own KV, so reset
1126+
// each sequence's draft region and select head `i` before re-decoding the
1127+
// accumulated prefix (head 0 sees just id_last). the single-head (qwen)
1128+
// and mem-shared (gemma4) paths leave their growing KV untouched.
1129+
if (chain_heads) {
1130+
auto * mem_dft = llama_get_memory(ctx_dft);
1131+
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1132+
if (drafting[seq_id]) {
1133+
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
1134+
}
1135+
}
1136+
llama_set_mtp_layer_offset(ctx_dft, i);
1137+
}
11371138

1138-
common_batch_clear(batch);
1139+
int ret = llama_decode(ctx_dft, batch);
1140+
if (ret != 0) {
1141+
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
1142+
break;
1143+
}
1144+
1145+
// the growing-KV paths rebuild the batch with only the new tokens (the KV
1146+
// already holds the prefix); chained heads keep accumulating into it
1147+
if (!chain_heads) {
1148+
common_batch_clear(batch);
1149+
}
11391150

11401151
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
11411152
if (!drafting[seq_id]) {
@@ -1144,9 +1155,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11441155

11451156
auto * smpl = smpls[seq_id].get();
11461157

1147-
common_sampler_sample(smpl, ctx_dft, i_batch, true);
1148-
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
1149-
++i_batch;
1158+
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
1159+
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
11501160

11511161
const auto * cur_p = common_sampler_get_candidates(smpl, true);
11521162

@@ -1180,6 +1190,12 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11801190
continue;
11811191
}
11821192

1193+
// chained heads: only the newly appended slot keeps its logits, so the
1194+
// slot we just sampled stops producing an output on the next head
1195+
if (chain_heads) {
1196+
batch.logits[i_last[seq_id]] = false;
1197+
}
1198+
11831199
if (is_mem_shared) {
11841200
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
11851201
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
@@ -1188,154 +1204,27 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11881204
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
11891205
}
11901206
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
1191-
}
11921207

1193-
if (batch.n_tokens == 0) {
1194-
break;
1208+
i_last[seq_id] = batch.n_tokens - 1;
11951209
}
11961210

1197-
// evaluate the drafted tokens on the draft model
1198-
ret = llama_decode(ctx_dft, batch);
1199-
if (ret != 0) {
1200-
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
1211+
if (batch.n_tokens == 0) {
12011212
break;
12021213
}
12031214

12041215
++i;
12051216
}
12061217

1207-
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1208-
auto & dp = dparams[seq_id];
1209-
if (!dp.drafting) {
1210-
continue;
1211-
}
1212-
1213-
if (dp.result->size() < (size_t) params.n_min) {
1214-
dp.result->clear();
1215-
}
1216-
1217-
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
1218+
if (chain_heads) {
1219+
llama_set_mtp_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
12181220
}
1219-
}
1220-
1221-
// Multi-head MTP draft: chain heads 0..n_mtp_layers-1. Step s runs head s on
1222-
// the accumulated prefix [id_last, draft_1, .., draft_s] and samples draft_{s+1}.
1223-
// Each slot's embd is the hidden produced by the PREVIOUS head for that token
1224-
// (slot 0 is always pending_h = trunk h). Per-step seq_rm keeps each head's KV
1225-
// on a clean, position-aligned slot set.
1226-
void draft_multi_head(common_speculative_draft_params_vec & dparams) {
1227-
auto * ctx_dft = params.ctx_dft;
1228-
auto * mem_dft = llama_get_memory(ctx_dft);
1229-
const size_t row_bytes = (size_t) n_embd * sizeof(float);
12301221

1231-
// Per-seq accumulated draft state. Positions are implicit: slot k always
1232-
// sits at round_start + k (the prefix is contiguous from id_last), so we
1233-
// don't store a parallel positions vector.
1234-
struct seq_state {
1235-
bool active = false;
1236-
llama_pos round_start = 0; // pos of id_last (slot 0)
1237-
std::vector<llama_token> toks; // [id_last, draft_1, ...]
1238-
std::vector<std::vector<float>> slot_h; // embd per slot (toks[k] <-> slot_h[k])
1239-
};
1240-
std::vector<seq_state> st(n_seq);
1241-
1242-
int n_drafting = 0;
12431222
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
12441223
auto & dp = dparams[seq_id];
12451224
if (!dp.drafting) {
12461225
continue;
12471226
}
1248-
common_sampler_reset(smpls[seq_id].get());
1249-
st[seq_id].active = true;
1250-
st[seq_id].round_start = dp.n_past;
1251-
st[seq_id].toks.push_back(dp.id_last);
1252-
st[seq_id].slot_h.push_back(pending_h[seq_id]); // slot 0 = trunk pending_h
1253-
n_drafting++;
1254-
}
1255-
if (n_drafting == 0) {
1256-
return;
1257-
}
1258-
1259-
const int n_steps = n_mtp_layers; // one head per step, capped at head count
1260-
1261-
for (int step = 0; step < n_steps && n_drafting > 0; ++step) {
1262-
// 1) per-seq KV reset for this head + select the head's layer.
1263-
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1264-
if (!st[seq_id].active) {
1265-
continue;
1266-
}
1267-
llama_memory_seq_rm(mem_dft, seq_id, st[seq_id].round_start, -1);
1268-
}
1269-
llama_set_mtp_layer_offset(ctx_dft, step);
1270-
1271-
// 2) build accumulated batch; logits only on each seq's last slot.
1272-
common_batch_clear(batch);
1273-
std::vector<int> last_i_batch(n_seq, -1);
1274-
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1275-
if (!st[seq_id].active) {
1276-
continue;
1277-
}
1278-
auto & s = st[seq_id];
1279-
const int len = (int) s.toks.size();
1280-
for (int k = 0; k < len; ++k) {
1281-
const bool is_last = (k == len - 1);
1282-
common_batch_add(batch, s.toks[k], s.round_start + k, { seq_id }, is_last);
1283-
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
1284-
s.slot_h[k].data(), row_bytes);
1285-
if (is_last) last_i_batch[seq_id] = batch.n_tokens - 1;
1286-
}
1287-
}
1288-
1289-
const int rc = llama_decode(ctx_dft, batch);
1290-
if (rc != 0) {
1291-
LOG_WRN("%s: multi-head draft decode step=%d rc=%d\n", __func__, step, rc);
1292-
break;
1293-
}
1294-
1295-
// 3) sample next draft token per active seq; extend prefix + slot_h.
1296-
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1297-
if (!st[seq_id].active) {
1298-
continue;
1299-
}
1300-
auto & s = st[seq_id];
1301-
auto & dp = dparams[seq_id];
1302-
const int ib = last_i_batch[seq_id];
1303-
1304-
auto * smpl = smpls[seq_id].get();
1305-
common_sampler_sample(smpl, ctx_dft, ib, true);
1306-
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, ib);
1307-
const auto * cur_p = common_sampler_get_candidates(smpl, true);
1308-
const llama_token id = cur_p->data[0].id;
1309-
1310-
if (cur_p->data[0].p < params.p_min) {
1311-
s.active = false;
1312-
n_drafting--;
1313-
continue;
1314-
}
1315-
1316-
common_sampler_accept(smpl, id, true);
1317-
dp.result->push_back(id);
1318-
1319-
if ((int) dp.result->size() >= params.n_max) {
1320-
s.active = false;
1321-
n_drafting--;
1322-
continue;
1323-
}
1324-
1325-
// next slot's token + its embd (= this head's output at last slot).
1326-
// Its position is implicit: round_start + toks.size().
1327-
s.toks.push_back(id);
1328-
s.slot_h.emplace_back(h_row, h_row + n_embd);
1329-
}
1330-
}
13311227

1332-
llama_set_mtp_layer_offset(ctx_dft, 0); // restore default
1333-
1334-
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1335-
auto & dp = dparams[seq_id];
1336-
if (!dp.drafting) {
1337-
continue;
1338-
}
13391228
if (dp.result->size() < (size_t) params.n_min) {
13401229
dp.result->clear();
13411230
}

0 commit comments

Comments
 (0)