
Commit 02a65ab

llama + spec: MTP support
1 parent 44fc110 commit 02a65ab

25 files changed: 843 additions & 39 deletions (5 files shown)

common/arg.cpp

Lines changed: 3 additions & 1 deletion
```diff
@@ -3562,12 +3562,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
     add_opt(common_arg(
-        {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
+        {"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
         string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
             common_speculative_type_to_str(params.speculative.type).c_str()),
         [](common_params & params, const std::string & value) {
             if (value == "none") {
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
+            } else if (value == "mtp") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
             } else if (value == "ngram-cache") {
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
             } else if (value == "ngram-simple") {
```
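With this parsing change, MTP speculation is selected the same way as the other draft-free speculators, e.g. `--spec-type mtp`, in the examples that register the flag (speculative, server, CLI); the rest of the commit supplies the `COMMON_SPECULATIVE_TYPE_MTP` plumbing that the new branch assigns.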

common/common.cpp

Lines changed: 0 additions & 1 deletion
```diff
@@ -1500,7 +1500,6 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.n_ctx     = params.n_ctx;
     cparams.n_seq_max = params.n_parallel;
     {
-        // TODO: add for MTP
         const bool has_spec = (params.speculative.type != COMMON_SPECULATIVE_TYPE_NONE)
             || params.speculative.has_dft();
         cparams.n_rollback_max = has_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
```
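Note the surrounding (unchanged) logic: once `params.speculative.type` is anything other than `COMMON_SPECULATIVE_TYPE_NONE`, `has_spec` is true and the context reserves rollback capacity for `draft.n_max` tokens. MTP therefore gets rollback room with no MTP-specific handling, which is why the TODO can be dropped.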

common/common.h

Lines changed: 11 additions & 0 deletions
```diff
@@ -159,6 +159,7 @@ enum common_speculative_type {
     COMMON_SPECULATIVE_TYPE_NONE,          // no speculative decoding
     COMMON_SPECULATIVE_TYPE_DRAFT,         // draft model
     COMMON_SPECULATIVE_TYPE_EAGLE3,        // eagle draft model
+    COMMON_SPECULATIVE_TYPE_MTP,           // multi-token prediction
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -347,11 +348,17 @@ struct common_params_speculative_ngram_cache {
     std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
 };
 
+struct common_params_speculative_mtp {
+    llama_model * model = nullptr;
+    llama_context_params cparams;
+};
+
 struct common_params_speculative {
     // TODO: become a vector in order to support "chains of speculators"
     common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
 
     common_params_speculative_draft draft;
+    common_params_speculative_mtp   mtp;
 
     common_params_speculative_ngram_mod ngram_mod;
     common_params_speculative_ngram_map ngram_simple;
@@ -363,6 +370,10 @@ struct common_params_speculative {
     bool has_dft() const {
         return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty();
     }
+
+    bool has_mtp() const {
+        return type == COMMON_SPECULATIVE_TYPE_MTP && mtp.model != nullptr;
+    }
 };
 
 struct common_params_vocoder {
```
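For orientation, a minimal sketch of how a caller might populate the new fields (hypothetical helper, not part of this commit; the MTP head model is assumed to be loaded elsewhere):

```cpp
#include <cassert>

#include "common.h"
#include "llama.h"

// Hypothetical helper: enable MTP speculation given an already-loaded MTP
// head model. Context params for the MTP head travel alongside the model.
static void enable_mtp(common_params & params, llama_model * mtp_model) {
    params.speculative.type        = COMMON_SPECULATIVE_TYPE_MTP;
    params.speculative.mtp.model   = mtp_model;
    params.speculative.mtp.cparams = llama_context_default_params();

    // has_mtp() requires both the selected type and a non-null model
    assert(params.speculative.has_mtp());
}
```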

common/speculative.cpp

Lines changed: 189 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
     COMMON_SPECULATIVE_TYPE_NONE,
     COMMON_SPECULATIVE_TYPE_DRAFT,
     COMMON_SPECULATIVE_TYPE_EAGLE3,
+    COMMON_SPECULATIVE_TYPE_MTP,
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
     {"none",          COMMON_SPECULATIVE_TYPE_NONE},
     {"draft",         COMMON_SPECULATIVE_TYPE_DRAFT},
     {"eagle3",        COMMON_SPECULATIVE_TYPE_EAGLE3},
+    {"mtp",           COMMON_SPECULATIVE_TYPE_MTP},
     {"ngram_simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
     {"ngram_map_k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
     {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -608,6 +610,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
     }
 };
 
+struct common_speculative_state_mtp : public common_speculative_state {
+    llama_context * ctx_tgt = nullptr;
+    llama_context * ctx_mtp = nullptr;
+
+    llama_batch batch; // single token draft step
+    common_sampler * smpl = nullptr;
+    int32_t n_embd = 0;
+
+    uint16_t last_n_drafted  = 0;
+    int32_t  last_n_accepted = -1;
+
+    common_speculative_state_mtp(enum common_speculative_type type,
+            llama_context * ctx_tgt,
+            llama_context * ctx_mtp)
+        : common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
+        GGML_ASSERT(ctx_tgt && ctx_mtp);
+        const llama_model * model_mtp = llama_get_model(ctx_mtp);
+        n_embd = llama_model_n_embd(model_mtp);
+
+        {
+            common_params_sampling sparams;
+            sparams.no_perf  = false;
+            sparams.top_k    = 1;
+            sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
+            smpl = common_sampler_init(model_mtp, sparams);
+        }
+
+        // TODO: multiple seq support
+        batch = llama_batch_init(/*n_tokens=*/ 1, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
+        batch.token        = (llama_token *) malloc(sizeof(llama_token));
+        batch.n_tokens     = 1;
+        batch.n_seq_id[0]  = 1;
+        batch.seq_id[0][0] = 0;
+        batch.logits[0]    = 1;
+
+        llama_set_mtp(ctx_tgt, ctx_mtp);
+    }
+
+    ~common_speculative_state_mtp() override {
+        llama_set_mtp(ctx_tgt, nullptr);
+        llama_batch_free(batch);
+        common_sampler_free(smpl);
+        if (ctx_mtp) {
+            llama_free(ctx_mtp);
+        }
+    }
+
+    void begin(const llama_tokens & prompt) override {
+        last_n_accepted = -1;
+        last_n_drafted  = 0;
+
+        const int32_t N = (int32_t) prompt.size();
+        if (N <= 0) {
+            return;
+        }
+        const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+        if (pos_max < N - 1) {
+            LOG_WRN("%s: ctx_mtp pos_max=%d < N-1=%d — "
+                    "streaming hook may not be registered or not all prefill rows "
+                    "have logits=true. Drafts may degrade.\n",
+                    __func__, (int) pos_max, N - 1);
+        }
+    }
+
+    void draft(
+            const common_params_speculative & params,
+            const llama_tokens & prompt_tgt,
+            llama_token id_last,
+            llama_tokens & draft_tokens) override {
+        GGML_UNUSED(prompt_tgt);
+        draft_tokens.clear();
+
+        // accept with no-accepts (i.e. 0 accepts) returns early, but we still need to remove from the MTP kv-cache
+        // TODO: check if bug in other spec states
+        if (last_n_drafted > 0) {
+            const int32_t n_to_drop = (int32_t) last_n_drafted - 1;
+            if (n_to_drop > 0) {
+                const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+                if (pos_max >= 0) {
+                    const llama_pos drop_from = pos_max - n_to_drop + 1;
+                    llama_memory_seq_rm(llama_get_memory(ctx_mtp), 0, drop_from, -1);
+                }
+            }
+            last_n_drafted  = 0;
+            last_n_accepted = 0;
+        }
+
+        const int32_t n_max     = std::max(1, params.draft.n_max);
+        const size_t  row_bytes = (size_t) n_embd * sizeof(float);
+
+        llama_token cond_tok = id_last;
+        llama_pos   pos      = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0) + 1;
+
+        // auto-regressive loop for MTP
+        for (int32_t k = 0; k < n_max; ++k) {
+            ggml_tensor * src;
+            int32_t src_row;
+            if (k == 0) {
+                src = llama_context_get_t_h_pre_norm(ctx_tgt);
+                if (last_n_accepted < 0) {
+                    // First draft after begin(): trunk's most recent decode is
+                    // the last prefill ubatch; its last row is h_{N-1}.
+                    src_row = (src && src->ne[1] > 0) ? (int32_t) src->ne[1] - 1 : 0;
+                } else {
+                    src_row = last_n_accepted;
+                }
+                llama_synchronize(ctx_tgt);
+            } else {
+                // for the AR path get the mtp_out from the mtp ctx
+                src = llama_context_get_t_mtp_out(ctx_mtp);
+                src_row = src ? (int32_t) src->ne[1] - 1 : 0;
+                llama_synchronize(ctx_mtp);
+            }
+            if (!src) {
+                LOG_WRN("%s: missing source tensor at k=%d; stopping chain\n", __func__, k);
+                return;
+            }
+            ggml_backend_tensor_get(src, batch.embd,
+                    (size_t) src_row * row_bytes, row_bytes);
+
+            batch.token[0] = cond_tok;
+            batch.pos[0]   = pos;
+
+            const int32_t dec_rc = llama_decode(ctx_mtp, batch);
+            if (dec_rc != 0) {
+                LOG_DBG("%s: llama_decode rc=%d at k=%d; stopping chain\n", __func__, dec_rc, k);
+                return;
+            }
+
+            const llama_token best = common_sampler_sample(smpl, ctx_mtp, 0);
+            common_sampler_accept(smpl, best, /*accept_grammar=*/ false);
+            draft_tokens.push_back(best);
+            cond_tok = best;
+            ++pos;
+        }
+
+        last_n_drafted = (uint16_t) draft_tokens.size();
+    }
+
+    void accept(uint16_t n_accepted) override {
+        const llama_pos pos_max        = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
+        const int32_t   n_drafted_last = (int32_t) last_n_drafted;
+        const int32_t   n_to_drop      = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
+        if (pos_max < 0) {
+            last_n_accepted = (int32_t) n_accepted;
+            return;
+        }
+        if (n_to_drop > 0) {
+            const llama_pos drop_from = pos_max - n_to_drop + 1;
+            llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
+                    /*p0=*/ drop_from, /*p1=*/ -1);
+        }
+        last_n_drafted  = 0;
+        last_n_accepted = (int32_t) n_accepted;
+    }
+
+    int32_t n_max(const common_params_speculative & params) const override {
+        return std::max(1, params.draft.n_max);
+    }
+
+    int32_t n_min(const common_params_speculative & params) const override {
+        return std::max(1, params.draft.n_min);
+    }
+};
+
 // state of self-speculation (simple implementation, not ngram-map)
 struct common_speculative_state_ngram_simple : public common_speculative_state {
     common_ngram_simple_config config;
```
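The trim logic in `draft()`/`accept()` above follows from how the chain fills the MTP cache: each single-token decode writes one row, first for `id_last` and then for every draft token except the last, so after `n_accepted` drafts are accepted only `n_accepted + 1` of those rows remain valid history. A standalone sketch of that arithmetic (plain C++, not the llama.cpp API):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

// After drafting, the MTP cache holds one row per decode: id_last,
// draft[0], ..., draft[n_drafted - 2]. The rows for id_last plus the
// accepted drafts stay valid; everything later must be evicted.
static int32_t mtp_drop_from(int32_t pos_max, int32_t n_drafted, int32_t n_accepted) {
    const int32_t n_to_drop = std::max(0, n_drafted - n_accepted - 1);
    return n_to_drop > 0 ? pos_max - n_to_drop + 1 : -1; // -1: nothing to drop
}

int main() {
    // drafted 4, target accepted 1: keep id_last + draft[0], evict positions 98..99
    std::printf("%d\n", mtp_drop_from(/*pos_max=*/ 99, /*n_drafted=*/ 4, /*n_accepted=*/ 1)); // 98
    // drafted 4, all 4 accepted: nothing to evict
    std::printf("%d\n", mtp_drop_from(99, 4, 4)); // -1
    return 0;
}
```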
```diff
@@ -963,6 +1130,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
         case COMMON_SPECULATIVE_TYPE_NONE:          return "none";
         case COMMON_SPECULATIVE_TYPE_DRAFT:         return "draft";
         case COMMON_SPECULATIVE_TYPE_EAGLE3:        return "eagle3";
+        case COMMON_SPECULATIVE_TYPE_MTP:           return "mtp";
         case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:  return "ngram_simple";
         case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:   return "ngram_map_k";
         case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@@ -994,11 +1162,24 @@ common_speculative * common_speculative_init(
         }
     }
 
+    llama_context * ctx_mtp = nullptr;
+    if (params.has_mtp()) {
+        ctx_mtp = llama_init_from_model(params.mtp.model, params.mtp.cparams);
+        if (ctx_mtp == nullptr) {
+            LOG_ERR("%s", "failed to create MTP context\n");
+            if (ctx_dft) {
+                llama_free(ctx_dft);
+            }
+            return nullptr;
+        }
+    }
+
     // Compute the implementations to use based on the config and their order of preference
     std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
     {
         bool has_draft        = !params.draft.mparams.path.empty();
         bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
+        bool has_mtp          = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr);
 
         bool has_ngram_cache  = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
         bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1045,6 +1226,9 @@ common_speculative * common_speculative_init(
         if (has_draft_eagle3) {
             configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
         }
+        if (has_mtp) {
+            configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
+        }
     }
 
     std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1069,6 +1253,11 @@ common_speculative * common_speculative_init(
                 impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
                 break;
             }
+            case COMMON_SPECULATIVE_TYPE_MTP: {
+                impls.push_back(std::make_unique<common_speculative_state_mtp>(
+                        config.type, ctx_tgt, ctx_mtp));
+                break;
+            }
             case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
                 common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
 
```

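Conceptually, the new state's `draft()` loop is a chain: step 0 copies the trunk's last pre-norm hidden row (`llama_context_get_t_h_pre_norm`) into `batch.embd` and conditions on `id_last`; every later step copies the MTP head's own output row (`llama_context_get_t_mtp_out`) and conditions on the token it just sampled. A toy simulation of that data flow (stand-in types only, not the real API):

```cpp
#include <cstdio>
#include <vector>

// Stand-in for the MTP head: the real code runs llama_decode on a separate
// context and reads its output tensor; here a single float stands in for
// the hidden row that gets copied into batch.embd.
struct toy_mtp_head {
    float decode(int tok, float h) const { return 0.9f * h + 0.001f * (float) tok; }
    int   sample(float h) const { return (int) (h * 1000.0f) % 32000; } // greedy (top_k = 1)
};

int main() {
    const toy_mtp_head head;

    float hidden   = 1.0f; // k == 0: the trunk's last pre-norm row, h_{N-1}
    int   cond_tok = 42;   // id_last: the token the target just committed

    std::vector<int> draft;
    for (int k = 0; k < 4; ++k) {                 // n_max single-token decodes
        hidden   = head.decode(cond_tok, hidden); // k > 0 reuses the head's own output
        cond_tok = head.sample(hidden);
        draft.push_back(cond_tok);
    }
    for (int t : draft) { std::printf("%d ", t); }
    std::printf("\n");
    return 0;
}
```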
convert_hf_to_gguf.py

Lines changed: 51 additions & 2 deletions
```diff
@@ -5415,13 +5415,62 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield from super().modify_tensors(data_torch, name, bid)
 
 
+class _Qwen35MtpMixin:
+    """Shared MTP wiring for Qwen3.5/3.6 text variants. The HF config carries
+    the MTP block under `mtp_num_hidden_layers` and the tensors under
+    `mtp.*`; we extend block_count, emit the nextn metadata key, and remap
+    `mtp.*` to the standard layer-indexed nextn naming so the existing
+    tensor_map handles them."""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
+            self.gguf_writer.add_nextn_predict_layers(n)
+
+    def modify_tensors(self, data_torch, name: str, bid):
+        # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
+        if name.startswith("model.language_model."):
+            name = "model." + name[len("model.language_model."):]
+        elif name.startswith("language_model."):
+            name = name[len("language_model."):]
+
+        # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
+        # HF: mtp.layers.0.* (transformer block at MTP slot 0)
+        #     mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
+        if name.startswith("mtp."):
+            n_layer = self.hparams["num_hidden_layers"]
+            if name.find("layers.") != -1:
+                assert bid is not None
+                name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
+            else:
+                remapper = {
+                    "mtp.fc":                    "model.layers.{bid}.eh_proj",
+                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
+                    "mtp.pre_fc_norm_hidden":    "model.layers.{bid}.hnorm",
+                    "mtp.norm":                  "model.layers.{bid}.shared_head.norm",
+                }
+                stem = Path(name).stem
+                suffix = Path(name).suffix
+                tmpl = remapper[stem] + suffix
+                for b in range(n_layer, self.block_count):
+                    yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b)
+                return
+
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
-class Qwen3_5TextModel(_LinearAttentionVReorderBase):
+class Qwen3_5TextModel(_Qwen35MtpMixin, _LinearAttentionVReorderBase):
     model_arch = gguf.MODEL_ARCH.QWEN35
 
 
 @ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
-class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase):
+class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _LinearAttentionVReorderBase):
     model_arch = gguf.MODEL_ARCH.QWEN35MOE
 
```

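The net effect of the remapping is purely positional: MTP transformer blocks are appended after the trunk's `num_hidden_layers` blocks, and the four loose MTP tensors are renamed into that appended slot via a fixed table. A small C++ sketch of the same index arithmetic (illustrative only; the converter is the Python above, and the sample tensor suffix is made up):

```cpp
#include <cstdio>
#include <map>
#include <string>

int main() {
    const int n_layer = 48; // hypothetical trunk depth; MTP blocks append after it

    // transformer blocks inside the MTP head: shift the block index by n_layer
    const int bid = 0;
    const std::string prefix = "mtp.layers." + std::to_string(bid);
    std::string name = prefix + ".self_attn.q_proj.weight"; // made-up suffix
    name = "model.layers." + std::to_string(bid + n_layer) + name.substr(prefix.size());
    std::printf("%s\n", name.c_str()); // model.layers.48.self_attn.q_proj.weight

    // loose MTP tensors: fixed table, instantiated once per appended block
    const std::map<std::string, std::string> remap = {
        { "mtp.fc",                    "eh_proj"          },
        { "mtp.pre_fc_norm_embedding", "enorm"            },
        { "mtp.pre_fc_norm_hidden",    "hnorm"            },
        { "mtp.norm",                  "shared_head.norm" },
    };
    for (const auto & kv : remap) {
        std::printf("%s -> model.layers.%d.%s.weight\n",
                kv.first.c_str(), n_layer, kv.second.c_str());
    }
    return 0;
}
```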