Skip to content

Commit 8ea0fa8

Browse files
committed
llama + spec: MTP support
1 parent 44fc110 commit 8ea0fa8

21 files changed

Lines changed: 870 additions & 37 deletions

common/arg.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3562,12 +3562,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
35623562
}
35633563
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
35643564
add_opt(common_arg(
3565-
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
3565+
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
35663566
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
35673567
common_speculative_type_to_str(params.speculative.type).c_str()),
35683568
[](common_params & params, const std::string & value) {
35693569
if (value == "none") {
35703570
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
3571+
} else if (value == "mtp") {
3572+
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
35713573
} else if (value == "ngram-cache") {
35723574
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
35733575
} else if (value == "ngram-simple") {

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ enum common_speculative_type {
159159
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
160160
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
161161
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
162+
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction
162163
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
163164
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
164165
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values

common/speculative.cpp

Lines changed: 311 additions & 1 deletion
Large diffs are not rendered by default.

common/speculative.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ std::string common_speculative_type_to_str(enum common_speculative_type type);
1616

1717
common_speculative * common_speculative_init(
1818
common_params_speculative & params,
19-
llama_context * ctx_tgt);
19+
llama_context * ctx_tgt,
20+
llama_context * ctx_mtp = nullptr);
2021

2122
void common_speculative_free(common_speculative * spec);
2223

gguf-py/gguf/constants.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,7 +2018,14 @@ class MODEL_TENSOR(IntEnum):
20182018
MODEL_TENSOR.SSM_NORM,
20192019
MODEL_TENSOR.SSM_BETA,
20202020
MODEL_TENSOR.SSM_ALPHA,
2021-
MODEL_TENSOR.SSM_OUT
2021+
MODEL_TENSOR.SSM_OUT,
2022+
# NextN/MTP tensors - preserved but unused
2023+
MODEL_TENSOR.NEXTN_EH_PROJ,
2024+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
2025+
MODEL_TENSOR.NEXTN_ENORM,
2026+
MODEL_TENSOR.NEXTN_HNORM,
2027+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
2028+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
20222029
],
20232030
MODEL_ARCH.QWEN35MOE: [
20242031
MODEL_TENSOR.TOKEN_EMBD,
@@ -2049,7 +2056,14 @@ class MODEL_TENSOR(IntEnum):
20492056
MODEL_TENSOR.SSM_NORM,
20502057
MODEL_TENSOR.SSM_BETA,
20512058
MODEL_TENSOR.SSM_ALPHA,
2052-
MODEL_TENSOR.SSM_OUT
2059+
MODEL_TENSOR.SSM_OUT,
2060+
# NextN/MTP tensors - preserved but unused
2061+
MODEL_TENSOR.NEXTN_EH_PROJ,
2062+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
2063+
MODEL_TENSOR.NEXTN_ENORM,
2064+
MODEL_TENSOR.NEXTN_HNORM,
2065+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
2066+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
20532067
],
20542068
MODEL_ARCH.PLAMO: [
20552069
MODEL_TENSOR.TOKEN_EMBD,

include/llama.h

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ extern "C" {
310310
// override key-value pairs of the model meta data
311311
const struct llama_model_kv_override * kv_overrides;
312312

313+
// override arch from GGUF to load MTP as a separate ctx
314+
const char * override_arch;
315+
313316
// Keep the booleans together to avoid misalignment during copy-by-value.
314317
bool vocab_only; // only load the vocabulary, no weights
315318
bool use_mmap; // use mmap if possible
@@ -967,6 +970,56 @@ extern "C" {
967970
// If true, all model tensors are activated during llama_decode() to load and cache their weights.
968971
LLAMA_API void llama_set_warmup(struct llama_context * ctx, bool warmup);
969972

973+
// Accessors for graph-output tensors used by speculative decoders that
974+
// need intermediate hidden states (e.g. MTP / NextN). Returns nullptr if
975+
// the most recent decode didn't populate the tensor. Call llama_synchronize
976+
// on the source context before reading via ggml_backend_tensor_get.
977+
LLAMA_API struct ggml_tensor * llama_context_get_t_h_pre_norm(struct llama_context * ctx);
978+
LLAMA_API struct ggml_tensor * llama_context_get_t_mtp_out (struct llama_context * ctx);
979+
980+
// Generic post-compute callback fired from inside process_ubatch after
981+
// each ubatch's compute finishes. Speculative decoders register this to
982+
// mirror the trunk's hidden state into a sibling context (e.g. an MTP
983+
// draft head) and decode into its KV. Pass cb = nullptr to clear.
984+
typedef void (*llama_post_ubatch_cb_t)(
985+
struct llama_context * ctx,
986+
int32_t n_tokens,
987+
const llama_token * tokens,
988+
const llama_pos * positions,
989+
struct ggml_tensor * t_h_pre_norm,
990+
void * user_data);
991+
992+
LLAMA_API void llama_set_post_ubatch_cb(
993+
struct llama_context * ctx,
994+
llama_post_ubatch_cb_t cb,
995+
void * user_data);
996+
997+
// Generic post-seq_rm callback fired from inside llama_context_seq_rm
998+
// after the trunk's memory.seq_rm completes. Speculative decoders that
999+
// mirror trunk KV state to a sibling context register this. Pass cb =
1000+
// nullptr to clear.
1001+
typedef void (*llama_post_seq_rm_cb_t)(
1002+
struct llama_context * ctx,
1003+
llama_seq_id seq_id,
1004+
llama_pos p0,
1005+
llama_pos p1,
1006+
void * user_data);
1007+
1008+
LLAMA_API void llama_set_post_seq_rm_cb(
1009+
struct llama_context * ctx,
1010+
llama_post_seq_rm_cb_t cb,
1011+
void * user_data);
1012+
1013+
// seq_rm on the trunk's memory plus dispatch to the registered post-
1014+
// seq_rm callback (no-op if none). Use in place of
1015+
// llama_memory_seq_rm(llama_get_memory(ctx), ...) at trunk-side seq_rm
1016+
// sites so observers (e.g. an MTP context) stay in lockstep.
1017+
LLAMA_API bool llama_context_seq_rm(
1018+
struct llama_context * ctx,
1019+
llama_seq_id seq_id,
1020+
llama_pos p0,
1021+
llama_pos p1);
1022+
9701023
// Set abort callback
9711024
LLAMA_API void llama_set_abort_callback(struct llama_context * ctx, ggml_abort_callback abort_callback, void * abort_callback_data);
9721025

src/llama-arch.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
4141
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
4242
{ LLM_ARCH_QWEN35, "qwen35" },
4343
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
44+
{ LLM_ARCH_QWEN35_MTP, "qwen35_mtp" },
4445
{ LLM_ARCH_PHI2, "phi2" },
4546
{ LLM_ARCH_PHI3, "phi3" },
4647
{ LLM_ARCH_PHIMOE, "phimoe" },
@@ -756,14 +757,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
756757
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
757758
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
758759
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
759-
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
760-
// These tensors only exist in the last layer(s) and are treated as output tensors
761-
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
762-
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
763-
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
764-
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
765-
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
766-
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
760+
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
761+
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
762+
// the model loader doesn't fault on the block index.
763+
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
764+
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
765+
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
766+
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
767+
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
768+
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
767769
// Nemotron 3 Super
768770
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
769771
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

src/llama-arch.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ enum llm_arch {
4545
LLM_ARCH_QWEN3VLMOE,
4646
LLM_ARCH_QWEN35,
4747
LLM_ARCH_QWEN35MOE,
48+
LLM_ARCH_QWEN35_MTP,
4849
LLM_ARCH_PHI2,
4950
LLM_ARCH_PHI3,
5051
LLM_ARCH_PHIMOE,

src/llama-context.cpp

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,13 +1242,24 @@ llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, ll
12421242
return nullptr;
12431243
}
12441244

1245+
// Generic post-ubatch dispatch — speculative decoders register a callback
1246+
// here to mirror the trunk's hidden state into a sibling context.
1247+
if (post_ubatch_cb) {
1248+
post_ubatch_cb(this,
1249+
(int32_t) ubatch.n_tokens,
1250+
ubatch.token,
1251+
ubatch.pos,
1252+
res->t_h_pre_norm,
1253+
post_ubatch_ud);
1254+
}
1255+
12451256
ret = GGML_STATUS_SUCCESS;
12461257

12471258
return res;
12481259
}
12491260

12501261
int llama_context::encode(const llama_batch & batch_inp) {
1251-
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
1262+
GGML_ASSERT(batch_inp.token || batch_inp.embd);
12521263

12531264
if (batch_inp.n_tokens == 0) {
12541265
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
@@ -1538,7 +1549,7 @@ static bool needs_raw_logits(const llama_ubatch & ubatch, const std::map<llama_s
15381549
}
15391550

15401551
int llama_context::decode(const llama_batch & batch_inp) {
1541-
GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT
1552+
GGML_ASSERT(batch_inp.token || batch_inp.embd);
15421553

15431554
if (!memory) {
15441555
LLAMA_LOG_DEBUG("%s: cannot decode batches with this context (calling encode() instead)\n", __func__);
@@ -3095,6 +3106,32 @@ void llama_set_warmup(llama_context * ctx, bool warmup) {
30953106
ctx->set_warmup(warmup);
30963107
}
30973108

3109+
ggml_tensor * llama_context::get_t_h_pre_norm() const {
3110+
return gf_res_prev ? gf_res_prev->t_h_pre_norm : nullptr;
3111+
}
3112+
3113+
ggml_tensor * llama_context_get_t_h_pre_norm(struct llama_context * ctx) {
3114+
return ctx ? ctx->get_t_h_pre_norm() : nullptr;
3115+
}
3116+
3117+
ggml_tensor * llama_context::get_t_mtp_out() const {
3118+
return gf_res_prev ? gf_res_prev->t_mtp_out : nullptr;
3119+
}
3120+
3121+
ggml_tensor * llama_context_get_t_mtp_out(struct llama_context * ctx) {
3122+
return ctx ? ctx->get_t_mtp_out() : nullptr;
3123+
}
3124+
3125+
void llama_set_post_ubatch_cb(struct llama_context * ctx, llama_post_ubatch_cb_t cb, void * user_data) {
3126+
if (!ctx) return;
3127+
ctx->set_post_ubatch_cb(cb, user_data);
3128+
}
3129+
3130+
void llama_set_post_seq_rm_cb(struct llama_context * ctx, llama_post_seq_rm_cb_t cb, void * user_data) {
3131+
if (!ctx) return;
3132+
ctx->set_post_seq_rm_cb(cb, user_data);
3133+
}
3134+
30983135
void llama_synchronize(llama_context * ctx) {
30993136
ctx->synchronize();
31003137
}
@@ -3252,6 +3289,24 @@ bool llama_memory_seq_rm(
32523289
return mem->seq_rm(seq_id, p0, p1);
32533290
}
32543291

3292+
bool llama_context_seq_rm(
3293+
struct llama_context * ctx,
3294+
llama_seq_id seq_id,
3295+
llama_pos p0,
3296+
llama_pos p1) {
3297+
if (!ctx) {
3298+
return true;
3299+
}
3300+
const bool ok = llama_memory_seq_rm(llama_get_memory(ctx), seq_id, p0, p1);
3301+
3302+
// Dispatch to a registered observer (e.g. an MTP context wrapper) so
3303+
// sibling state stays in lockstep with the trunk's KV.
3304+
if (llama_post_seq_rm_cb_t cb = ctx->get_post_seq_rm_cb()) {
3305+
cb(ctx, seq_id, p0, p1, ctx->get_post_seq_rm_ud());
3306+
}
3307+
return ok;
3308+
}
3309+
32553310
void llama_memory_seq_cp(
32563311
llama_memory_t mem,
32573312
llama_seq_id seq_id_src,

src/llama-context.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,25 @@ struct llama_context {
6969
float * get_embeddings_ith(int32_t i);
7070
float * get_embeddings_seq(llama_seq_id seq_id);
7171

72+
// Accessors for graph-output tensors used by speculative decoders that
73+
// need intermediate hidden states. Return nullptr if the most recent
74+
// decode didn't populate them.
75+
ggml_tensor * get_t_h_pre_norm() const;
76+
ggml_tensor * get_t_mtp_out() const;
77+
78+
// Post-ubatch / post-seq_rm callbacks. See llama.h for semantics.
79+
// Pass cb=nullptr to clear.
80+
void set_post_ubatch_cb(llama_post_ubatch_cb_t cb, void * user_data) {
81+
post_ubatch_cb = cb;
82+
post_ubatch_ud = user_data;
83+
}
84+
void set_post_seq_rm_cb(llama_post_seq_rm_cb_t cb, void * user_data) {
85+
post_seq_rm_cb = cb;
86+
post_seq_rm_ud = user_data;
87+
}
88+
llama_post_seq_rm_cb_t get_post_seq_rm_cb() const { return post_seq_rm_cb; }
89+
void * get_post_seq_rm_ud() const { return post_seq_rm_ud; }
90+
7291
llama_token * get_sampled_tokens() const;
7392
llama_token get_sampled_token_ith(int32_t idx);
7493

@@ -253,6 +272,13 @@ struct llama_context {
253272

254273
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
255274

275+
// Generic post-compute / post-seq_rm callbacks. Speculative decoders that
276+
// need to mirror the trunk's state into a sibling context register here.
277+
llama_post_ubatch_cb_t post_ubatch_cb = nullptr;
278+
void * post_ubatch_ud = nullptr;
279+
llama_post_seq_rm_cb_t post_seq_rm_cb = nullptr;
280+
void * post_seq_rm_ud = nullptr;
281+
256282
std::unique_ptr<llama_memory_i> memory;
257283

258284
// decode output (2-dimensional array: [n_outputs][n_vocab])

0 commit comments

Comments
 (0)