Support Step3.5/3.7 flash mtp3#24340
Conversation
|
I think we want at least @ggerganov and @am17an here for the discussion about how to solve multi-layer MTP in core. |
|
Yes, I propose an initial approach here. I think it's semantically correct while keeping the changes relatively small. |
|
Yeah, I purposefully wanted the original StepFun MTP PR to be small because I did a little foray into implementing the full MTP and then saw it would be quite a challenging task, think it's good to discuss this :) |
|
@ggerganov @am17an Would you have some time to take a look ? |
|
From what I understand, this can be achieved if we fix the draft length (via |
|
@am17an Yeah, i guess the ctx_dft state you mean is the mtp_layer_offset I added. getting 3-layer mtp to run is the easy part; the rest of the diff is all about keeping the KV cache correct. Unlike gemma4 (all nextn layers in one graph, shared target KV, single position), step35's heads are chained with a sample in between and run on their own kv. So I can't just bump the layer per step on the single-token loop, or head 46 attends back to a cell only head 45 ever wrote on its layer and reads garbage. That's why each step has to seq_rm the round and re-decode the accumulated prefix on the current head's layer. The correct semantics in vLLM architecture can reference mtp3. I think it can be done more simply under llama.cpp's architecture. |
|
I think you can still optionally add the |
|
Right, seq_rm is part of it. And also need to re-decode the whole accumulated prefix [id_last, draft_1, …] on the current head's layer each step, so that head writes its own layer's kv for every position. |
|
I see, so you have to keep the draft_tokens and embeddings to copy them in each subsequent draft round. I think you can keep these two vectors host side and add them while rebuilding the batch. As an aside, this would not be able to use CUDA graphs as the topology for the draft will keep changing (i.e. batch size goes from 1 to 2 to 3 etc). |
Yep, that's exactly what the current implementation does.
I'm aware. But each one needs the token the previous head sampled so is hard to avoid. And the perf cost is contained, draft is at most n_nextn tiny decodes. |
|
Yes but I think the current implementation can be simplified. |
| // Each slot's embd is the hidden produced by the PREVIOUS head for that token | ||
| // (slot 0 is always pending_h = trunk h). Per-step seq_rm keeps each head's KV | ||
| // on a clean, position-aligned slot set. | ||
| void draft_multi_head(common_speculative_draft_params_vec & dparams) { |
There was a problem hiding this comment.
this should be part of draft rather a separate function. The only difference is that you need to add embd + token of the last sampled head to the batch
There was a problem hiding this comment.
right, have merged
cffdd9a to
2952d83
Compare
am17an
left a comment
There was a problem hiding this comment.
You can clean-up the comments a bit to follow the rest of the repo. If the code is self-explanatory we prefer not to add comments in cpp files (.h files is encouraged). If something is a bit non-intuitive (like seq_rm in this PR) then it takes sense to add a comment to explain. You should also check MTP performance/correctness of Qwen3.6 and Gemma4
| LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); | ||
| LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); | ||
| LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); | ||
| // Number of appended NextN/MTP prediction blocks (0 if the model has none) |
There was a problem hiding this comment.
| // Number of appended NextN/MTP prediction blocks (0 if the model has none) |
Also need to fix the alignment
| // MTP (multi-token prediction): which appended NextN/MTP block the | ||
| // DECODER_MTP graph runs, as an offset past the trunk (il = n_layer() + offset). | ||
| // 0 selects the first MTP head; the speculative driver bumps it per draft step. | ||
| int32_t mtp_layer_offset = 0; |
There was a problem hiding this comment.
replace occurences of "mtp" with "nextn" to be make it more consistent
| void set_embeddings (bool value); | ||
| void set_embeddings_nextn(bool value, bool masked); | ||
| void set_embeddings_layer_inp(uint32_t lid, bool enable); | ||
| void set_mtp_layer_offset(int32_t offset); |
There was a problem hiding this comment.
| void set_mtp_layer_offset(int32_t offset); | |
| void set_nextn_layer_offset(int32_t offset); |
Co-authored-by: Aman Gupta <amangupta052@gmail.com>
|
Test MTP performance/correctness of Qwen3.6 and Gemma4 on H800, with parameter 📜 performance of Qwen3.6 and Gemma4.Qwen3.6 master Qwen3.6 new Gemma4 master Gemma4 new 📜 correctness of Qwen3.6 and Gemma4.Qwen3.6 master Qwen3.6 new Gemma4 master Gemma4 new |
|
@CISC could you take a look :) |
LGTM, but @ggerganov should sign off on the API. |
ggerganov
left a comment
There was a problem hiding this comment.
Looks incorrect with multiple sequences.
|
You're right. Have fixed it (matching the eagle3 pattern), verified single-sequence output unchanged and concurrent runs now hit 0 decode failures on both unified and non-unified caches |
| auto * mem_dft = llama_get_memory(ctx_dft); | ||
|
|
||
| bool ok = true; | ||
| for (int head = 0; head < n_mtp_layers; ++head) { | ||
| if (chain_heads) { | ||
| for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { | ||
| if (i_batch_beg[seq_id] < 0) { | ||
| continue; | ||
| } | ||
| llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1); | ||
| } | ||
| llama_set_nextn_layer_offset(ctx_dft, head); | ||
| } | ||
|
|
||
| const int32_t rc = llama_decode(ctx_dft, batch); | ||
| if (rc != 0) { | ||
| LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n", | ||
| __func__, head, (int) rc, (int) batch_in.pos[0]); | ||
| ok = false; | ||
| break; | ||
| } | ||
| } | ||
|
|
||
| if (chain_heads) { | ||
| llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes | ||
| } |
There was a problem hiding this comment.
I don't understand the logic here - seems incorrect. Every head iteration will basically erase the result of the previous iteration.
There was a problem hiding this comment.
Each head runs a different layer, set_nextn_layer_offset(head) makes graph_mtp build layer n_layer()+head, i.e. 45/46/47. So each head writes its own k_l[il]/v_l[il].
seq_rm here doesn't drop any KV data — it just clears the cell metadata so find_slot hands back the same cells at the same positions for every head. Without it, head 46/47 would land on fresh cells and we'd get duplicate positions in v_cells .
So after the loop those cells hold valid KV for all three MTP layers at once. it's the teacher-forcing catch-up that seeds each head's layer so the next draft() round attends to a correct, target-aligned cache.
There was a problem hiding this comment.
Ok, got it. That's interesting.
|
@ggerganov Hey, just a quick ping on this pr when you have a chance :) |
|
Here is a minor patch I wanted to push, but don't have the permission: diff --git a/common/speculative.cpp b/common/speculative.cpp
index fd0cf138f..d7a177b7b 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -1027,6 +1027,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
bool ok = true;
for (int head = 0; head < n_mtp_layers; ++head) {
if (chain_heads) {
+ // ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
@@ -1837,7 +1838,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
- bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
+ bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
@@ -1875,7 +1876,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
- if (has_mtp) {
+ if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
} |
| if (chain_heads) { | ||
| chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd); | ||
|
|
||
| const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far | ||
| for (int t = 0; t < n_rows; ++t) { | ||
| const llama_token tok = (t == 0) ? dp.id_last : result[t - 1]; | ||
| common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1); | ||
| std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, | ||
| chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes); | ||
| } |
There was a problem hiding this comment.
This seems incorrect - every next draft, we decode all previous tokens again. Why is that? Normally, we should decode just the latest token.
There was a problem hiding this comment.
You're right that every draft step re-decodes the whole prefix — that's intentional. This path isn't a normal AR draft; it's an accumulating-batch + per-step-seq_rm flow, because each step runs a different head.
After process(), every head's KV is already filled for the prompt + accepted prefix (positions < n_past), and the per-step seq_rm lower bound is n_past, so we never re-decode the prompt. What gets replayed each step is only the draft region — id_last plus the few tokens drafted so far (capped at n_max=3).
Why replay it under each head: since we switch heads per step, the current head hasn't written its own KV for any draft-region position yet this round. If we only decoded the latest token, its attention over the earlier draft positions would read cells that only another head ever wrote → garbage. So each step we seq_rm the draft region (so find_slot reuses the same slots and positions stay aligned), switch the head, and replay the accumulated prefix so this head fills its own KV.
For contrast, the other two branches reuse a single head across steps, so their prefix KV is already valid and they just append the latest token
There was a problem hiding this comment.
Ah yes. I'm still not used to that approach, but it seems correct. Add a reference to this explanation in the code.
There was a problem hiding this comment.
Agreed it's not intuitive, it took me quite a while to design this too :) Added a comment and committed the other suggestions as well.
There was a problem hiding this comment.
Add a reference:
diff --git a/common/speculative.cpp b/common/speculative.cpp
index d7a177b7b..f8a6287c2 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -1184,6 +1184,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
}
if (chain_heads) {
+ // ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far|
|
||
| std::vector<int> i_last(n_seq, -1); | ||
|
|
||
| std::vector<std::vector<float>> chain_h; |
There was a problem hiding this comment.
Should be allocated and reserved once at construction time.
| cparams.embeddings == other.cparams.embeddings && | ||
| cparams.embeddings_nextn == other.cparams.embeddings_nextn && | ||
| cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked && | ||
| cparams.nextn_layer_offset == other.cparams.nextn_layer_offset && |
There was a problem hiding this comment.
This is correct, but effectively it will disable graph reuse during drafting. However, there isn't a better way to do it for now as we don't have a mechanism to do layer selection at compute time. It's something to think about in the future.
There was a problem hiding this comment.
Add TODO to not forget:
diff --git a/src/llama-graph.h b/src/llama-graph.h
index d2a1b39d4..ac00d6cc6 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -682,11 +682,15 @@ struct llm_graph_params {
}
}
+ // TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248
+ if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) {
+ return false;
+ }
+
return
cparams.embeddings == other.cparams.embeddings &&
cparams.embeddings_nextn == other.cparams.embeddings_nextn &&
cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked &&
- cparams.nextn_layer_offset == other.cparams.nextn_layer_offset &&
cparams.causal_attn == other.cparams.causal_attn &&
arch == other.arch &&
gtype == other.gtype &&
ggerganov
left a comment
There was a problem hiding this comment.
Not sure why the EditorConfig check is failing - seems like a false-positive. Should be good to merge.
For some reason the line number is the would-be-merged line number, so sometimes it's a bit off. |
Yes, though I merged master into my working copy of this branch and it didn't result into whitespaces. Not sure what caused this. |
|
A big thank you for this PR folks! I've been from ~18tok/s to ~30tok/s with coding tasks on a Strix Halo (with ROCm, the |
Overview
follow-up to #23274.(cc @pwilkin )
📜 Full data-flow trace — couldn't think of a good way to draw this, so I wrote it all down instead. It's long, but every byte is load-bearing.
Notation:
token@pos/h(pos)— positions are explicit (0-indexed)h_tgt(p)— target NextN hidden at p (before the output norm)h45(p)/h46(p)— head 45/46 output hidden, chained between heads while draftingpending_h = h_tgt(pos of id_last − 1)— always the trunk h, regardless of chainingExample: a 4-token prompt at positions 0–3.
Core strategy
Each MTP head is its own decoder layer with its own KV, and the driver runs one head per
llama_decode. Aseq_rmbefore each head clears the range it re-decodes, so it reuses the same slots (find_slotis deterministic) instead of stacking duplicate positions;find_slot/apply_ubatchare untouched. The two phases differ only in what feeds the heads:process()h_tgt, right-shifted by one (not the inter-head hidden)draft()pending_h)Only the trunk
h_tgtcrosses rounds, sopending_h/verify_hstay single-layer.MTP Block Selection Strategy
cparams.mtp_layer_offset(src/llama-cparams.h) — picks which appended MTP block theDECODER_MTPgraph runs:il = n_layer() + offset. Default 0.graph_mtpselects the head by offset (il = n_layer() + cparams.mtp_layer_offset, was a hardcodedn_layer()).graph_mtpnow gathers its output rows viabuild_inp_out_ids(), like the trunk graph. The fix that makes chaining work: from step 1 on, the output is the last batch row, not row 0, so without it heads 46/47 read the wrong row. Identity gather whenn_outputs == n_tokens, so the single-head path is unchanged.n_layer_nextnMTP blocks.n_maxis clamped to the head count when chaining (each head used once).Results
The command is identical on both machines; only
--spec-draft-n-maxand the build change. Before = single-block MTP on master (one head, looped when n-max > 1); after = the three-layer chain.DGX Spark GB10
Before (single-block MTP, master)
--spec-draft-n-max 2--spec-draft-n-max 3After (three-layer MTP, this PR)
--spec-draft-n-max 2--spec-draft-n-max 3Mac Studio M4 Max
Before (single-block MTP, master)
--spec-draft-n-max 2--spec-draft-n-max 3After (three-layer MTP, this PR)
--spec-draft-n-max 2--spec-draft-n-max 3Requirements