Skip to content

Commit d789527

Browse files
forforever73am17an
andauthored
spec : Support Step3.5/3.7 flash mtp3 (#24340)
* add mtp_layer_offset + include nextn flags in graph reuse * add llama_set_mtp_layer_offset + llama_model_n_nextn_layer API * offset head select + require all MTP blocks * speculative multi-head process() * speculative multi-head draft() * gather outputs via inp_out_ids * cleanup * fix core * minor cleanup * merged draft_multi_head into draft() * mtp rename nextn * Apply suggestions from code review Co-authored-by: Aman Gupta <amangupta052@gmail.com> * clean-up comments * fix for multi seq * apply suggestions && chain-heads comment * add a reference for chain_heads discussion --------- Co-authored-by: Aman Gupta <amangupta052@gmail.com>
1 parent 063d9c1 commit d789527

9 files changed

Lines changed: 168 additions & 74 deletions

File tree

common/speculative.cpp

Lines changed: 102 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -905,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
905905

906906
int32_t n_embd = 0;
907907

908-
bool is_mem_shared = false;
908+
// One MTP draft driver, three modes (set once in the ctor):
909+
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
910+
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
911+
// neither (qwen35 / qwen35moe): a single trained MTP head.
912+
int32_t n_mtp_layers = 1;
913+
bool is_mem_shared = false; // gemma4
914+
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
909915

910916
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
911917
// The last h-row of one process() call needs the first token of the NEXT
@@ -920,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
920926
std::vector<std::vector<float>> verify_h;
921927
std::vector<int32_t> verify_h_rows;
922928

923-
// Per-seq draft length from the last draft() call, used in accept() to
924-
// roll back ctx_dft's recurrent state past the AR draft's redundant
925-
// pre-advancement before process() mirrored the verify batch.
926-
std::vector<uint16_t> last_n_drafted;
929+
std::vector<int> i_last;
930+
std::vector<std::vector<float>> chain_h;
927931

928932
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
929933
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
@@ -936,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
936940
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
937941
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
938942
"MTP input row width must match the target h_nextn width");
943+
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
939944

940945
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
941946
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
@@ -982,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
982987
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
983988

984989
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
990+
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
991+
992+
if (chain_heads) {
993+
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
994+
995+
chain_h.assign(n_seq, {});
996+
for (auto & c : chain_h) {
997+
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
998+
}
999+
}
9851000

9861001
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
9871002

1003+
i_last.assign(n_seq, -1);
9881004
i_batch_beg.assign(n_seq, -1);
9891005
i_batch_end.assign(n_seq, -1);
9901006

9911007
verify_h.assign(n_seq, {});
9921008
verify_h_rows.assign(n_seq, 0);
993-
994-
last_n_drafted.assign(n_seq, 0);
9951009
}
9961010

9971011
~common_speculative_impl_draft_mtp() override {
@@ -1097,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
10971111
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
10981112
}
10991113

1100-
const int32_t rc = llama_decode(ctx_dft, batch);
1101-
if (rc != 0) {
1102-
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]);
1114+
auto * mem_dft = llama_get_memory(ctx_dft);
1115+
1116+
bool ok = true;
1117+
for (int head = 0; head < n_mtp_layers; ++head) {
1118+
if (chain_heads) {
1119+
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
1120+
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1121+
if (i_batch_beg[seq_id] < 0) {
1122+
continue;
1123+
}
1124+
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
1125+
}
1126+
llama_set_nextn_layer_offset(ctx_dft, head);
1127+
}
1128+
1129+
const int32_t rc = llama_decode(ctx_dft, batch);
1130+
if (rc != 0) {
1131+
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
1132+
__func__, head, (int) rc, (int) batch_in.pos[0]);
1133+
ok = false;
1134+
break;
1135+
}
1136+
}
1137+
1138+
if (chain_heads) {
1139+
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
1140+
}
1141+
if (!ok) {
11031142
return false;
11041143
}
11051144
}
@@ -1134,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11341173
int n_drafting = 0;
11351174
std::vector<bool> drafting(n_seq);
11361175

1137-
const float * h_row = nullptr;
11381176
const size_t row_bytes = (size_t) n_embd * sizeof(float);
11391177

11401178
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1149,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11491187
common_sampler_reset(smpls[seq_id].get());
11501188

11511189
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
1190+
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
11521191

1153-
h_row = pending_h[seq_id].data();
1154-
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
1155-
}
1192+
i_last[seq_id] = batch.n_tokens - 1;
11561193

1157-
int ret = llama_decode(ctx_dft, batch);
1158-
if (ret != 0) {
1159-
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
1160-
return;
1194+
if (chain_heads) {
1195+
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
1196+
}
11611197
}
11621198

11631199
int i = 0;
11641200

11651201
while (n_drafting > 0) {
1166-
int i_batch = 0;
1202+
// each step decodes under a different head, i.e. a different decoder layer, and
1203+
// KV is per layer. process() filled this layer's KV only for positions < n_past
1204+
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
1205+
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
1206+
// and select head i so it rebuilds its own layer's KV there; decoding just the
1207+
// latest token would leave its attention reading cells only another head wrote.
1208+
if (chain_heads) {
1209+
auto * mem_dft = llama_get_memory(ctx_dft);
1210+
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
1211+
if (drafting[seq_id]) {
1212+
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
1213+
}
1214+
}
1215+
llama_set_nextn_layer_offset(ctx_dft, i);
1216+
}
11671217

1218+
int ret = llama_decode(ctx_dft, batch);
1219+
if (ret != 0) {
1220+
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
1221+
break;
1222+
}
1223+
1224+
// rebuild the batch for the next step: the growing-KV paths re-add only the
1225+
// new token (the KV already holds the prefix), while chained heads re-add the
1226+
// whole prefix at the next head. dropped sequences are simply not re-added.
11681227
common_batch_clear(batch);
11691228

11701229
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
@@ -1174,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
11741233

11751234
auto * smpl = smpls[seq_id].get();
11761235

1177-
common_sampler_sample(smpl, ctx_dft, i_batch, true);
1178-
h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
1179-
++i_batch;
1236+
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
1237+
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
11801238

11811239
const auto * cur_p = common_sampler_get_candidates(smpl, true);
11821240

@@ -1210,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
12101268
continue;
12111269
}
12121270

1213-
if (is_mem_shared) {
1271+
if (chain_heads) {
1272+
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
1273+
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
1274+
1275+
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
1276+
for (int t = 0; t < n_rows; ++t) {
1277+
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
1278+
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
1279+
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
1280+
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
1281+
}
1282+
} else if (is_mem_shared) {
12141283
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
12151284
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
12161285
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
1286+
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
12171287
} else {
12181288
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
1289+
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
12191290
}
1220-
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
1221-
}
12221291

1223-
if (batch.n_tokens == 0) {
1224-
break;
1292+
i_last[seq_id] = batch.n_tokens - 1;
12251293
}
12261294

1227-
// evaluate the drafted tokens on the draft model
1228-
ret = llama_decode(ctx_dft, batch);
1229-
if (ret != 0) {
1230-
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
1295+
if (batch.n_tokens == 0) {
12311296
break;
12321297
}
12331298

12341299
++i;
12351300
}
12361301

1302+
if (chain_heads) {
1303+
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
1304+
}
1305+
12371306
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
12381307
auto & dp = dparams[seq_id];
12391308
if (!dp.drafting) {
@@ -1243,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
12431312
if (dp.result->size() < (size_t) params.n_min) {
12441313
dp.result->clear();
12451314
}
1246-
1247-
last_n_drafted[seq_id] = (uint16_t) dp.result->size();
12481315
}
12491316
}
12501317

@@ -1857,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
18571924

18581925
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
18591926
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
1860-
bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
1927+
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
18611928

18621929

18631930

@@ -1895,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
18951962
if (has_draft_eagle3) {
18961963
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
18971964
}
1898-
if (has_mtp) {
1965+
if (has_draft_mtp) {
18991966
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
19001967
}
19011968
}

include/llama.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -558,14 +558,15 @@ extern "C" {
558558
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
559559
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
560560

561-
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
562-
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
563-
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
564-
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
565-
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
566-
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
567-
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
568-
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
561+
LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model);
562+
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
563+
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
564+
LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model);
565+
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
566+
LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model);
567+
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
568+
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
569+
LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model);
569570

570571
// Get the model's RoPE frequency scaling factor
571572
LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);

src/llama-context.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) {
11561156
sched_need_reserve = true;
11571157
}
11581158

1159+
void llama_context::set_nextn_layer_offset(int32_t offset) {
1160+
cparams.nextn_layer_offset = offset;
1161+
}
1162+
11591163
void llama_context::set_causal_attn(bool value) {
11601164
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
11611165

@@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu
36993703
ctx->set_embeddings_layer_inp(lid, value);
37003704
}
37013705

3706+
void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) {
3707+
ctx->set_nextn_layer_offset(offset);
3708+
}
3709+
37023710
llama_memory_t llama_get_memory(const struct llama_context * ctx) {
37033711
if (!ctx) {
37043712
return nullptr;

src/llama-context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ struct llama_context {
115115
void set_embeddings (bool value);
116116
void set_embeddings_nextn(bool value, bool masked);
117117
void set_embeddings_layer_inp(uint32_t lid, bool enable);
118+
void set_nextn_layer_offset(int32_t offset);
118119
void set_causal_attn(bool value);
119120
void set_warmup(bool value);
120121

src/llama-cparams.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ struct llama_cparams {
1818
int32_t n_threads; // number of threads to use for generation
1919
int32_t n_threads_batch; // number of threads to use for batch processing
2020

21+
int32_t nextn_layer_offset = 0;
22+
2123
float rope_freq_base;
2224
float rope_freq_scale;
2325

src/llama-ext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c
9595
// If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits
9696
LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked);
9797

98+
// Select which appended NextN block the DECODER_MTP graph runs (offset past
99+
// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to
100+
// chain multiple trained NextN heads. Default 0 (first head).
101+
LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset);
102+
98103
// mirrors:
99104
// LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
100105
LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx);

src/llama-graph.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -682,9 +682,16 @@ struct llm_graph_params {
682682
}
683683
}
684684

685+
// TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
686+
if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
687+
return false;
688+
}
689+
685690
return
686-
cparams.embeddings == other.cparams.embeddings &&
687-
cparams.causal_attn == other.cparams.causal_attn &&
691+
cparams.embeddings == other.cparams.embeddings &&
692+
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
693+
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
694+
cparams.causal_attn == other.cparams.causal_attn &&
688695
arch == other.arch &&
689696
gtype == other.gtype &&
690697
cvec == other.cvec &&

src/llama-model.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) {
23122312
return model->hparams.n_layer();
23132313
}
23142314

2315+
int32_t llama_model_n_layer_nextn(const llama_model * model) {
2316+
return model->hparams.n_layer_nextn;
2317+
}
2318+
23152319
int32_t llama_model_n_head(const llama_model * model) {
23162320
return model->hparams.n_head();
23172321
}

0 commit comments

Comments
 (0)