Skip to content

Commit 5b4a8f1

Browse files
committed
Chain MTP
1 parent d69a72d commit 5b4a8f1

5 files changed

Lines changed: 58 additions & 20 deletions

File tree

common/speculative.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,12 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
416416

417417
int32_t n_embd = 0;
418418

419+
// Stacked-MTP draft chain length. Each MTP block k is trained to take the
420+
// (k-1)-th block's hidden state and predict the (k+1)-th token ahead, so
421+
// we can draft at most `n_mtp_layers` tokens per round before going
422+
// out-of-distribution.
423+
int32_t n_mtp_layers = 0;
424+
419425
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
420426
// The last h-row of one process() call needs the first token of the NEXT
421427
// call to pair with, so it's stashed here until that next call fires.
@@ -442,7 +448,9 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
442448
auto * ctx_dft = this->params.ctx_dft;
443449
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
444450

445-
n_embd = llama_model_n_embd(llama_get_model(ctx_dft));
451+
n_embd = llama_model_n_embd (llama_get_model(ctx_dft));
452+
n_mtp_layers = llama_model_n_nextn(llama_get_model(ctx_dft));
453+
GGML_ASSERT(n_mtp_layers > 0 && "MTP draft requires the draft model to declare nextn_predict_layers > 0");
446454

447455
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
448456
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd);
@@ -635,9 +643,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
635643
std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes);
636644
}
637645

638-
// First draft step uses the first MTP block (step 0). Archs with a
639-
// single MTP block ignore this; multi-block archs (Step-3.5-Flash) use
640-
// it to round-robin across their N MTP layers.
646+
// Stacked-MTP is a *chain*: block k consumes block (k-1)'s hidden
647+
// state and predicts one token further than block (k-1). Step k uses
648+
// MTP block k; we cannot exceed `n_mtp_layers` steps without going
649+
// out of training distribution.
641650
llama_set_mtp_step(ctx_dft, 0);
642651

643652
int ret = llama_decode(ctx_dft, batch);
@@ -648,6 +657,10 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
648657

649658
int i = 0;
650659

660+
// Cap draft depth: never run more MTP chain steps than there are
661+
// trained MTP blocks. `params.n_max` may be larger; we just stop.
662+
const int n_chain_max = n_mtp_layers;
663+
651664
while (n_drafting > 0) {
652665
int i_batch = 0;
653666

@@ -704,8 +717,16 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl {
704717
break;
705718
}
706719

707-
// Step i+1: feed the i-th sampled draft token into the (i+1)-th
708-
// MTP block. Multi-block archs round-robin via mtp_step % N.
720+
// We just sampled the (i+1)-th draft token (T_{i+1}). If the
721+
// chain has no further block (T_{i+2} would need MTP block i+1
722+
// which doesn't exist), stop — sampling, not decoding, is what
723+
// matters for the drafted result.
724+
if (i + 1 >= n_chain_max) {
725+
break;
726+
}
727+
728+
// Step i+1: feed the i-th sampled draft token into MTP block i+1.
729+
// Direct indexing, no modulo — the chain is bounded by n_mtp_layers.
709730
llama_set_mtp_step(ctx_dft, (uint32_t)(i + 1));
710731

711732
// evaluate the drafted tokens on the draft model

src/llama-cparams.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@ struct llama_cparams {
3131
bool embeddings_pre_norm; // also extract the hidden state before the final output norm
3232
bool embeddings_pre_norm_masked; // extract for only rows where batch.logits != 0
3333

34-
// MTP draft-step index, used by archs with num_nextn_predict_layers > 1 to
35-
// round-robin across MTP blocks (matches vllm's spec_step_idx). The graph
36-
// builder selects `il = n_main + (mtp_step % nextn_predict_layers)`.
34+
// MTP draft-chain step index. Stacked-MTP archs (Step-3.5, with multiple
35+
// trained MTP blocks) interpret block `k` as: takes block (k-1)'s output
36+
// hidden state + the embedding of block (k-1)'s sampled token and
37+
// predicts one token further. The chain is bounded — block k is only
38+
// valid for step k, with no wrap-around — so the speculative driver caps
39+
// the AR loop at `nextn_predict_layers`.
3740
uint32_t mtp_step;
3841
bool causal_attn;
3942
bool offload_kqv;

src/llama-ext.h

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ using llama_memory_breakdown = std::map<ggml_backend_buffer_type_t, llama_memory
8585
LLAMA_API int32_t llama_model_n_expert (const struct llama_model * model);
8686
LLAMA_API int32_t llama_model_n_devices(const struct llama_model * model);
8787

88+
// Number of NextN/MTP prediction blocks (0 if the model has none). For
89+
// stacked-MTP architectures this caps the maximum useful speculative draft
90+
// depth: each block is a distinct chain step and they cannot be reused
91+
// because each block expects the previous block's hidden state as input.
92+
LLAMA_API int32_t llama_model_n_nextn (const struct llama_model * model);
93+
8894
LLAMA_API ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i);
8995

9096
LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_context * ctx);
@@ -109,8 +115,10 @@ LLAMA_API float * llama_get_embeddings_pre_norm_ith(struct llama_context * ctx,
109115
// MTP draft-step index (round-robin selector across MTP blocks)
110116
//
111117

112-
// Set the MTP draft-step index for the next llama_decode call. Used by archs
113-
// with num_nextn_predict_layers > 1 to round-robin across their MTP blocks
114-
// (matches vllm's spec_step_idx). Pass step = 0 for the first draft token,
115-
// step = 1 for the second, etc. The graph builder reads cparams.mtp_step.
118+
// Set the MTP draft-chain step index for the next llama_decode call. Stacked
119+
// MTP architectures consume one block per step in chain order — block 0
120+
// produces logits to sample T_{t+1}, block 1 produces T_{t+2}, etc.
121+
// `step` must be in [0, num_nextn_predict_layers): there is no wrap-around,
122+
// because block k expects block (k-1)'s output hidden state as input and is
123+
// out-of-distribution for any other position.
116124
LLAMA_API void llama_set_mtp_step(struct llama_context * ctx, uint32_t step);

src/llama-model.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,6 +2503,10 @@ int32_t llama_model_n_devices(const struct llama_model * model) {
25032503
return (int32_t)model->devices.size();
25042504
}
25052505

2506+
int32_t llama_model_n_nextn(const struct llama_model * model) {
2507+
return (int32_t) model->hparams.nextn_predict_layers;
2508+
}
2509+
25062510
ggml_backend_dev_t llama_model_get_device(const struct llama_model * model, int i) {
25072511
if (i < 0 || i >= (int)model->devices.size()) {
25082512
return nullptr;

src/models/step35.cpp

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -362,13 +362,15 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr
362362
: llm_graph_context(params) {
363363
GGML_ASSERT(hparams.nextn_predict_layers > 0 && "STEP35 MTP requires nextn_predict_layers > 0");
364364

365-
// Round-robin across MTP blocks at draft step boundaries. Matches vllm's
366-
// `current_step_idx = spec_step_idx % num_mtp_layers` (step3p5_mtp.py).
367-
// The first MTP block lives at layer index `n_main`; the speculative
368-
// driver bumps `cparams.mtp_step` between AR iterations.
369-
const int n_main = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
370-
const int step_offset = (int) (cparams.mtp_step % hparams.nextn_predict_layers);
371-
const int il = n_main + step_offset;
365+
// Stacked-MTP is a chain, NOT a round-robin: MTP block k expects the
366+
// hidden state output by block k-1 (or the backbone for k=0). Block k can
367+
// only predict the (k+1)-th token ahead and is undefined input
368+
// distribution for any other step. The speculative driver caps the AR
369+
// loop at `num_mtp_layers`; we assert that here.
370+
GGML_ASSERT(cparams.mtp_step < hparams.nextn_predict_layers &&
371+
"STEP35 MTP: draft step exceeds number of trained MTP blocks (no wrap-around)");
372+
const int n_main = (int) hparams.n_layer - (int) hparams.nextn_predict_layers;
373+
const int il = n_main + (int) cparams.mtp_step;
372374
const auto & layer = model.layers[il];
373375

374376
GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj");

0 commit comments

Comments
 (0)