Commit 10829db

llama + spec: MTP support

1 parent c5e0227 commit 10829db

25 files changed, 876 additions and 42 deletions

common/arg.cpp

Lines changed: 3 additions & 1 deletion
@@ -3562,12 +3562,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
     add_opt(common_arg(
-        {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
+        {"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
         string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
             common_speculative_type_to_str(params.speculative.type).c_str()),
         [](common_params & params, const std::string & value) {
             if (value == "none") {
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
+            } else if (value == "mtp") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
             } else if (value == "ngram-cache") {
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
             } else if (value == "ngram-simple") {
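Since the option is registered for the speculative, server, and CLI examples, MTP self-speculation can be selected without a separate draft model, e.g. `llama-server -m model.gguf --spec-type mtp` (an illustrative invocation; the remaining flags, and whether the target model actually carries MTP weights, are outside this diff).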

common/common.cpp

Lines changed: 0 additions & 1 deletion
@@ -1496,7 +1496,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_ctx     = params.n_ctx;
     cparams.n_seq_max = params.n_parallel;
     {
-        // TODO: add for MTP
         const bool has_spec = (params.speculative.type != COMMON_SPECULATIVE_TYPE_NONE)
             || params.speculative.has_dft();
         cparams.n_rs_seq = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
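The `TODO: add for MTP` comment can go because MTP is now a first-class `common_speculative_type`: the `type != COMMON_SPECULATIVE_TYPE_NONE` check already makes `has_spec` true when `--spec-type mtp` is selected, so `n_rs_seq` is sized for drafting without any MTP-specific handling here.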

common/common.h

Lines changed: 11 additions & 0 deletions
@@ -159,6 +159,7 @@ enum common_speculative_type {
     COMMON_SPECULATIVE_TYPE_NONE,   // no speculative decoding
     COMMON_SPECULATIVE_TYPE_DRAFT,  // draft model
     COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
+    COMMON_SPECULATIVE_TYPE_MTP,    // multi-token prediction
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -347,11 +348,17 @@ struct common_params_speculative_ngram_cache {
     std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
 };

+struct common_params_speculative_mtp {
+    llama_model * model = nullptr;
+    llama_context_params cparams;
+};
+
 struct common_params_speculative {
     // TODO: become a vector in order to support "chains of speculators"
     common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;

     common_params_speculative_draft draft;
+    common_params_speculative_mtp   mtp;

     common_params_speculative_ngram_mod ngram_mod;
     common_params_speculative_ngram_map ngram_simple;
@@ -363,6 +370,10 @@ struct common_params_speculative {
     bool has_dft() const {
         return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
     }
+
+    bool has_mtp() const {
+        return type == COMMON_SPECULATIVE_TYPE_MTP && mtp.model != nullptr;
+    }
 };

 struct common_params_vocoder {
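A minimal sketch of how application code might populate the new fields, assuming a `llama_model *` for the MTP head has already been loaded (the variable `mtp_model` and the surrounding setup are illustrative, not part of this diff):

```cpp
// Sketch: enabling MTP speculation via the new params. Only the
// common_params_speculative members come from this commit; everything
// else here is assumed for illustration.
common_params params;
params.speculative.type        = COMMON_SPECULATIVE_TYPE_MTP;
params.speculative.mtp.model   = mtp_model; // loaded llama_model * for the MTP head
params.speculative.mtp.cparams = llama_context_default_params();

// has_mtp() gates the MTP path: it requires both the type and a model.
if (params.speculative.has_mtp()) {
    // common_speculative_init() will create the MTP context from mtp.cparams
}
```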

common/speculative.cpp

Lines changed: 189 additions & 0 deletions
@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
     COMMON_SPECULATIVE_TYPE_NONE,
     COMMON_SPECULATIVE_TYPE_DRAFT,
     COMMON_SPECULATIVE_TYPE_EAGLE3,
+    COMMON_SPECULATIVE_TYPE_MTP,
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
     {"none",          COMMON_SPECULATIVE_TYPE_NONE},
     {"draft",         COMMON_SPECULATIVE_TYPE_DRAFT},
     {"eagle3",        COMMON_SPECULATIVE_TYPE_EAGLE3},
+    {"mtp",           COMMON_SPECULATIVE_TYPE_MTP},
     {"ngram_simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
     {"ngram_map_k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
     {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -599,6 +601,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
     }
 };

+struct common_speculative_state_mtp : public common_speculative_state {
+    llama_context * ctx_tgt = nullptr;
+    llama_context * ctx_mtp = nullptr;
+
+    llama_batch batch; // single token draft step
+    common_sampler * smpl = nullptr;
+    int32_t n_embd = 0;
+
+    uint16_t last_n_drafted  = 0;
+    int32_t  last_n_accepted = -1;
+
+    common_speculative_state_mtp(enum common_speculative_type type,
+            llama_context * ctx_tgt,
+            llama_context * ctx_mtp)
+        : common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
+        GGML_ASSERT(ctx_tgt && ctx_mtp);
+        const llama_model * model_mtp = llama_get_model(ctx_mtp);
+        n_embd = llama_model_n_embd(model_mtp);
+
+        {
+            common_params_sampling sparams;
+            sparams.no_perf  = false;
+            sparams.top_k    = 1;
+            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
+            smpl = common_sampler_init(model_mtp, sparams);
+        }
+
+        // TODO: multiple seq support
+        batch = llama_batch_init(/*n_tokens=*/ 1, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
+        batch.token        = (llama_token *) malloc(sizeof(llama_token));
+        batch.n_tokens     = 1;
+        batch.n_seq_id[0]  = 1;
+        batch.seq_id[0][0] = 0;
+        batch.logits[0]    = 1;
+
+        llama_set_mtp(ctx_tgt, ctx_mtp);
+    }
+
+    ~common_speculative_state_mtp() override {
+        llama_set_mtp(ctx_tgt, nullptr);
+        llama_batch_free(batch);
+        common_sampler_free(smpl);
+        if (ctx_mtp) {
+            llama_free(ctx_mtp);
+        }
+    }
+
+    void begin(const llama_tokens & prompt) override {
+        last_n_accepted = -1;
+        last_n_drafted  = 0;
+
+        const int32_t N = (int32_t) prompt.size();
+        if (N <= 0) {
+            return;
+        }
+        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+        if (pos_max < N - 1) {
+            LOG_WRN("%s: ctx_mtp pos_max=%d < N-1=%d — "
+                    "streaming hook may not be registered or not all prefill rows "
+                    "have logits=true. Drafts may degrade.\n",
+                    __func__, (int) pos_max, N - 1);
+        }
+    }
+
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & draft_tokens) override {
+        GGML_UNUSED(prompt_tgt);
+        draft_tokens.clear();
+
+        // accept with no-accepts (i.e. 0 accepts) returns early, but we still need to remove from the MTP kv-cache
+        // TODO: check if bug in other spec states
+        if (last_n_drafted > 0) {
+            const int32_t n_to_drop = (int32_t) last_n_drafted - 1;
+            if (n_to_drop > 0) {
+                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+                if (pos_max >= 0) {
+                    const llama_pos drop_from = pos_max - n_to_drop + 1;
+                    llama_memory_seq_rm(llama_get_memory(ctx_mtp), 0, drop_from, -1);
+                }
+            }
+            last_n_drafted  = 0;
+            last_n_accepted = 0;
+        }
+
+        const int32_t n_max     = std::max(1, params.draft.n_max);
+        const size_t  row_bytes = (size_t) n_embd * sizeof(float);
+
+        llama_token cond_tok = id_last;
+        llama_pos   pos      = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0) + 1;
+
+        // auto-regressive loop for MTP
+        for (int32_t k = 0; k < n_max; ++k) {
+            ggml_tensor * src;
+            int32_t src_row;
+            if (k == 0) {
+                src = llama_context_get_t_h_pre_norm(ctx_tgt);
+                if (last_n_accepted < 0) {
+                    // First draft after begin(): trunk's most recent decode is
+                    // the last prefill ubatch; its last row is h_{N-1}.
+                    src_row = (src && src->ne[1] > 0) ? (int32_t) src->ne[1] - 1 : 0;
+                } else {
+                    src_row = last_n_accepted;
+                }
+                llama_synchronize(ctx_tgt);
+            } else {
+                // for the AR path get the mtp_out from the mtp ctx
+                src = llama_context_get_t_mtp_out(ctx_mtp);
+                src_row = src ? (int32_t) src->ne[1] - 1 : 0;
+                llama_synchronize(ctx_mtp);
+            }
+            if (!src) {
+                LOG_WRN("%s: missing source tensor at k=%d; stopping chain\n", __func__, k);
+                return;
+            }
+            ggml_backend_tensor_get(src, batch.embd,
+                    (size_t) src_row * row_bytes, row_bytes);
+
+            batch.token[0] = cond_tok;
+            batch.pos[0]   = pos;
+
+            const int32_t dec_rc = llama_decode(ctx_mtp, batch);
+            if (dec_rc != 0) {
+                LOG_DBG("%s: llama_decode rc=%d at k=%d; stopping chain\n", __func__, dec_rc, k);
+                return;
+            }
+
+            const llama_token best = common_sampler_sample(smpl, ctx_mtp, 0);
+            common_sampler_accept(smpl, best, /*accept_grammar=*/ false);
+            draft_tokens.push_back(best);
+            cond_tok = best;
+            ++pos;
+        }
+
+        last_n_drafted = (uint16_t) draft_tokens.size();
+    }
+
+    void accept(uint16_t n_accepted) override {
+        const llama_pos pos_max        = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+        const int32_t   n_drafted_last = (int32_t) last_n_drafted;
+        const int32_t   n_to_drop      = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
+        if (pos_max < 0) {
+            last_n_accepted = (int32_t) n_accepted;
+            return;
+        }
+        if (n_to_drop > 0) {
+            const llama_pos drop_from = pos_max - n_to_drop + 1;
+            llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
+                    /*p0=*/ drop_from, /*p1=*/ -1);
+        }
+        last_n_drafted  = 0;
+        last_n_accepted = (int32_t) n_accepted;
+    }
+
+    int32_t n_max(const common_params_speculative & params) const override {
+        return std::max(1, params.draft.n_max);
+    }
+
+    int32_t n_min(const common_params_speculative & params) const override {
+        return std::max(1, params.draft.n_min);
+    }
+};
+
 // state of self-speculation (simple implementation, not ngram-map)
 struct common_speculative_state_ngram_simple : public common_speculative_state {
     common_ngram_simple_config config;
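The subtle part of the state above is the KV-cache trim. `draft()` appends one MTP cache cell per drafted token; on the next turn, every cell past the accepted run must be dropped except the one whose decode output conditions the next draft step. The same rule reappears at the top of `draft()` for the zero-accept path, where `accept()` never ran (`n_accepted == 0`, hence `last_n_drafted - 1`). A self-contained sketch of that arithmetic, assuming nothing beyond the numbers in `accept()`:

```cpp
#include <algorithm>
#include <cstdio>

// Mirrors the trim rule in common_speculative_state_mtp::accept():
// keep one cache cell per accepted token plus the cell that produced
// the hidden state for the next draft step; drop the rest.
static int n_to_drop(int n_drafted, int n_accepted) {
    return std::max(0, n_drafted - n_accepted - 1);
}

int main() {
    // drafted 4, accepted 4: everything stays
    std::printf("%d\n", n_to_drop(4, 4)); // 0
    // drafted 4, accepted 2: cells 1-2 (accepted) and cell 3 (conditions
    // the next step) stay; cell 4 is dropped
    std::printf("%d\n", n_to_drop(4, 2)); // 1
    // drafted 4, accepted 0: only the first drafted cell survives
    std::printf("%d\n", n_to_drop(4, 0)); // 3
    return 0;
}
```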
@@ -952,6 +1119,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
         case COMMON_SPECULATIVE_TYPE_NONE:          return "none";
         case COMMON_SPECULATIVE_TYPE_DRAFT:         return "draft";
         case COMMON_SPECULATIVE_TYPE_EAGLE3:        return "eagle3";
+        case COMMON_SPECULATIVE_TYPE_MTP:           return "mtp";
         case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:  return "ngram_simple";
         case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:   return "ngram_map_k";
         case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@@ -983,11 +1151,24 @@ common_speculative * common_speculative_init(
         }
     }

+    llama_context * ctx_mtp = nullptr;
+    if (params.has_mtp()) {
+        ctx_mtp = llama_init_from_model(params.mtp.model, params.mtp.cparams);
+        if (ctx_mtp == nullptr) {
+            LOG_ERR("%s", "failed to create MTP context\n");
+            if (ctx_dft) {
+                llama_free(ctx_dft);
+            }
+            return nullptr;
+        }
+    }
+
     // Compute the implementations to use based on the config and their order of preference
     std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
     {
         bool has_draft        = !params.draft.mparams.path.empty();
         bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
+        bool has_mtp          = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr);

         bool has_ngram_cache  = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
         bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1034,6 +1215,9 @@ common_speculative * common_speculative_init(
         if (has_draft_eagle3) {
             configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
         }
+        if (has_mtp) {
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
+        }
     }

     std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1058,6 +1242,11 @@ common_speculative * common_speculative_init(
                 impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
                 break;
             }
+            case COMMON_SPECULATIVE_TYPE_MTP: {
+                impls.push_back(std::make_unique<common_speculative_state_mtp>(
+                        config.type, ctx_tgt, ctx_mtp));
+                break;
+            }
             case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
                 common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
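One ownership detail worth noting in these hunks: `common_speculative_init()` creates `ctx_mtp` but does not retain it. `common_speculative_state_mtp` takes it over, and its destructor unregisters the hook (`llama_set_mtp(ctx_tgt, nullptr)`) and calls `llama_free(ctx_mtp)`, so callers must not free the MTP context themselves.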

convert_hf_to_gguf.py

Lines changed: 51 additions & 2 deletions
@@ -5446,13 +5446,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)


+class _Qwen35MtpMixin:
+    """Shared MTP wiring for Qwen3.5/3.6 text variants. The HF config carries
+    the MTP block under `mtp_num_hidden_layers` and the tensors under
+    `mtp.*`; we extend block_count, emit the nextn metadata key, and remap
+    `mtp.*` to the standard layer-indexed nextn naming so the existing
+    tensor_map handles them."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
+            self.gguf_writer.add_nextn_predict_layers(n)
+
+    def modify_tensors(self, data_torch, name: str, bid):
+        # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
+        if name.startswith("model.language_model."):
+            name = "model." + name[len("model.language_model."):]
+        elif name.startswith("language_model."):
+            name = name[len("language_model."):]
+
+        # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
+        #   HF: mtp.layers.0.*  (transformer block at MTP slot 0)
+        #       mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
+        if name.startswith("mtp."):
+            n_layer = self.hparams["num_hidden_layers"]
+            if name.find("layers.") != -1:
+                assert bid is not None
+                name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
+            else:
+                remapper = {
+                    "mtp.fc":                    "model.layers.{bid}.eh_proj",
+                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
+                    "mtp.pre_fc_norm_hidden":    "model.layers.{bid}.hnorm",
+                    "mtp.norm":                  "model.layers.{bid}.shared_head.norm",
+                }
+                stem   = Path(name).stem
+                suffix = Path(name).suffix
+                tmpl   = remapper[stem] + suffix
+                for b in range(n_layer, self.block_count):
+                    yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b)
+            return
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
-class Qwen3_5TextModel(_LinearAttentionVReorderBase):
+class Qwen3_5TextModel(_Qwen35MtpMixin, _LinearAttentionVReorderBase):
     model_arch = gguf.MODEL_ARCH.QWEN35


 @ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
-class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase):
+class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _LinearAttentionVReorderBase):
     model_arch = gguf.MODEL_ARCH.QWEN35MOE
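To make the remapping concrete, take illustrative hparams of `num_hidden_layers = 48` and `mtp_num_hidden_layers = 1` (not from this commit): the MTP block lands at layer index 48, and the pre-tensor_map names produced by the mixin are

- `mtp.layers.0.self_attn.q_proj.weight` → `model.layers.48.self_attn.q_proj.weight`
- `mtp.fc.weight` → `model.layers.48.eh_proj.weight`
- `mtp.pre_fc_norm_embedding.weight` → `model.layers.48.enorm.weight`
- `mtp.pre_fc_norm_hidden.weight` → `model.layers.48.hnorm.weight`
- `mtp.norm.weight` → `model.layers.48.shared_head.norm.weight`

after which the existing tensor_map (built with the extended `block_count`) handles them like any other nextn layer.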
