Skip to content

Commit fd9bc15

Browse files
yrougyclaude
andcommitted
merge : integrate PR ggml-org#22673 (MTP speculative decoding for qwen35/qwen35moe)
Adds llama_model_qwen35_mtp / llama_model_qwen35moe_mtp architectures that the server auto-loads from the same GGUF with override_arch when --spec-type mtp is requested. The MTP block runs as a separate draft context for speculative decoding, yielding ~2-3× throughput increase. Conflict resolutions vs our local changes: - qwen35.cpp / qwen35moe.cpp: removed our duplicate nextn_predict_layers block (now handled in the merged load_arch_hparams); kept TENSOR_SKIP for MTP-layer tensors in the base model to avoid loading ~200 MiB of unused weights into VRAM; extended TENSOR_SKIP block with the two new nextn tensors (embed_tokens, shared_head_head) using TENSOR_NOT_REQUIRED|TENSOR_SKIP so GUFs without them still work. - convert_hf_to_gguf.py: kept both _Qwen35MRopeMixin (mrope_section default) and _Qwen35MtpMixin (MTP block count/tensor remapping) as separate mixins; both classes now inherit from both. - tests/test-backend-ops.cpp: merged ggml_set_name + ggml_l2_norm from our side with the new keep_intermediates parameter from the PR. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2 parents dd42d51 + 5d5f1b4 commit fd9bc15

45 files changed

Lines changed: 1596 additions & 165 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

common/arg.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3568,12 +3568,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
35683568
}
35693569
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
35703570
add_opt(common_arg(
3571-
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
3571+
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
35723572
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
35733573
common_speculative_type_to_str(params.speculative.type).c_str()),
35743574
[](common_params & params, const std::string & value) {
35753575
if (value == "none") {
35763576
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
3577+
} else if (value == "mtp") {
3578+
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
35773579
} else if (value == "ngram-cache") {
35783580
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
35793581
} else if (value == "ngram-simple") {

common/common.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1420,6 +1420,11 @@ common_context_seq_rm_type common_context_can_seq_rm(llama_context * ctx) {
14201420
goto done;
14211421
}
14221422

1423+
if (llama_n_rs_seq(ctx) > 0) {
1424+
res = COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED;
1425+
goto done;
1426+
}
1427+
14231428
// try to remove the last tokens
14241429
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
14251430
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
@@ -1490,6 +1495,12 @@ struct llama_context_params common_context_params_to_llama(const common_params &
14901495

14911496
cparams.n_ctx = params.n_ctx;
14921497
cparams.n_seq_max = params.n_parallel;
1498+
{
1499+
// enable partial rollback only for MTP, each recurrent slot requires memory
1500+
// and MTP uses max 3-4 slots vs other techniques
1501+
const bool has_mtp_spec = params.speculative.type == COMMON_SPECULATIVE_TYPE_MTP;
1502+
cparams.n_rs_seq = has_mtp_spec ? (uint32_t) params.speculative.draft.n_max : 0u;
1503+
}
14931504
cparams.n_batch = params.n_batch;
14941505
cparams.n_ubatch = params.n_ubatch;
14951506
cparams.n_threads = params.cpuparams.n_threads;

common/common.h

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ enum common_speculative_type {
159159
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
160160
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
161161
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
162+
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction
162163
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
163164
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
164165
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
@@ -347,11 +348,17 @@ struct common_params_speculative_ngram_cache {
347348
std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding
348349
};
349350

351+
struct common_params_speculative_mtp {
352+
llama_model * model = nullptr;
353+
llama_context_params cparams;
354+
};
355+
350356
struct common_params_speculative {
351357
// TODO: become a vector in order to support "chains of speculators"
352358
common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE;
353359

354360
common_params_speculative_draft draft;
361+
common_params_speculative_mtp mtp;
355362

356363
common_params_speculative_ngram_mod ngram_mod;
357364
common_params_speculative_ngram_map ngram_simple;
@@ -879,9 +886,10 @@ std::string common_get_model_endpoint();
879886
//
880887

881888
enum common_context_seq_rm_type {
882-
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
883-
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
884-
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
889+
COMMON_CONTEXT_SEQ_RM_TYPE_NO = 0, // seq_rm not supported (e.g. no memory module)
890+
COMMON_CONTEXT_SEQ_RM_TYPE_PART = 1, // can seq_rm partial sequences
891+
COMMON_CONTEXT_SEQ_RM_TYPE_FULL = 2, // can seq_rm full sequences only
892+
COMMON_CONTEXT_SEQ_RM_TYPE_PART_BOUNDED = 3, // can seq_rm partial sequences, bounded by n_rs_seq
885893
};
886894

887895
// check if the llama_context can remove sequences

common/speculative.cpp

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
2222
COMMON_SPECULATIVE_TYPE_NONE,
2323
COMMON_SPECULATIVE_TYPE_DRAFT,
2424
COMMON_SPECULATIVE_TYPE_EAGLE3,
25+
COMMON_SPECULATIVE_TYPE_MTP,
2526
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
2627
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
2728
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
@@ -33,6 +34,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
3334
{"none", COMMON_SPECULATIVE_TYPE_NONE},
3435
{"draft", COMMON_SPECULATIVE_TYPE_DRAFT},
3536
{"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3},
37+
{"mtp", COMMON_SPECULATIVE_TYPE_MTP},
3638
{"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
3739
{"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
3840
{"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
@@ -599,6 +601,171 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
599601
}
600602
};
601603

604+
struct common_speculative_state_mtp : public common_speculative_state {
605+
llama_context * ctx_tgt = nullptr;
606+
llama_context * ctx_mtp = nullptr;
607+
608+
llama_batch batch; // single token draft step
609+
common_sampler * smpl = nullptr;
610+
int32_t n_embd = 0;
611+
612+
uint16_t last_n_drafted = 0;
613+
int32_t last_n_accepted = -1;
614+
615+
common_speculative_state_mtp(enum common_speculative_type type,
616+
llama_context * ctx_tgt,
617+
llama_context * ctx_mtp)
618+
: common_speculative_state(type), ctx_tgt(ctx_tgt), ctx_mtp(ctx_mtp) {
619+
GGML_ASSERT(ctx_tgt && ctx_mtp);
620+
const llama_model * model_mtp = llama_get_model(ctx_mtp);
621+
n_embd = llama_model_n_embd(model_mtp);
622+
623+
{
624+
common_params_sampling sparams;
625+
sparams.no_perf = false;
626+
sparams.top_k = 1;
627+
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
628+
smpl = common_sampler_init(model_mtp, sparams);
629+
}
630+
631+
// TODO: multiple seq support
632+
batch = llama_batch_init(/*n_tokens=*/ 1, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
633+
batch.token = (llama_token *) malloc(sizeof(llama_token));
634+
batch.n_tokens = 1;
635+
batch.n_seq_id[0] = 1;
636+
batch.seq_id[0][0] = 0;
637+
batch.logits[0] = 1;
638+
639+
llama_set_mtp(ctx_tgt, ctx_mtp);
640+
}
641+
642+
~common_speculative_state_mtp() override {
643+
llama_set_mtp(ctx_tgt, nullptr);
644+
llama_batch_free(batch);
645+
common_sampler_free(smpl);
646+
if (ctx_mtp) {
647+
llama_free(ctx_mtp);
648+
}
649+
}
650+
651+
void begin(const llama_tokens & prompt) override {
652+
last_n_accepted = -1;
653+
last_n_drafted = 0;
654+
655+
const int32_t N = (int32_t) prompt.size();
656+
if (N <= 0) {
657+
return;
658+
}
659+
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
660+
if (pos_max < N - 1) {
661+
LOG_WRN("%s: ctx_mtp pos_max=%d < N-1=%d — "
662+
"streaming hook may not be registered or not all prefill rows "
663+
"have logits=true. Drafts may degrade.\n",
664+
__func__, (int) pos_max, N - 1);
665+
}
666+
}
667+
668+
void draft(
669+
const common_params_speculative & params,
670+
const llama_tokens & prompt_tgt,
671+
llama_token id_last,
672+
llama_tokens & draft_tokens) override {
673+
GGML_UNUSED(prompt_tgt);
674+
draft_tokens.clear();
675+
676+
// accept with no-accepts (i.e. 0 accepts) returns early, but we still need to remove from the MTP kv-cache
677+
// TODO: check if bug in other spec states
678+
if (last_n_drafted > 0) {
679+
const int32_t n_to_drop = (int32_t) last_n_drafted - 1;
680+
if (n_to_drop > 0) {
681+
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
682+
if (pos_max >= 0) {
683+
const llama_pos drop_from = pos_max - n_to_drop + 1;
684+
llama_memory_seq_rm(llama_get_memory(ctx_mtp), 0, drop_from, -1);
685+
}
686+
}
687+
last_n_drafted = 0;
688+
last_n_accepted = 0;
689+
}
690+
691+
const int32_t n_max = std::max(1, params.draft.n_max);
692+
const size_t row_bytes = (size_t) n_embd * sizeof(float);
693+
694+
llama_token cond_tok = id_last;
695+
llama_pos pos = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0) + 1;
696+
697+
// auto-regressive loop for MTP
698+
for (int32_t k = 0; k < n_max; ++k) {
699+
ggml_tensor * src;
700+
int32_t src_row;
701+
if (k == 0) {
702+
src = llama_context_get_t_h_pre_norm(ctx_tgt);
703+
if (last_n_accepted < 0) {
704+
// First draft after begin(): trunk's most recent decode is
705+
// the last prefill ubatch; its last row is h_{N-1}.
706+
src_row = (src && src->ne[1] > 0) ? (int32_t) src->ne[1] - 1 : 0;
707+
} else {
708+
src_row = last_n_accepted;
709+
}
710+
llama_synchronize(ctx_tgt);
711+
} else {
712+
// for the AR path get the mtp_out from the mtp ctx
713+
src = llama_context_get_t_mtp_out(ctx_mtp);
714+
src_row = src ? (int32_t) src->ne[1] - 1 : 0;
715+
llama_synchronize(ctx_mtp);
716+
}
717+
if (!src) {
718+
LOG_WRN("%s: missing source tensor at k=%d; stopping chain\n", __func__, k);
719+
return;
720+
}
721+
ggml_backend_tensor_get(src, batch.embd,
722+
(size_t) src_row * row_bytes, row_bytes);
723+
724+
batch.token[0] = cond_tok;
725+
batch.pos[0] = pos;
726+
727+
const int32_t dec_rc = llama_decode(ctx_mtp, batch);
728+
if (dec_rc != 0) {
729+
LOG_DBG("%s: llama_decode rc=%d at k=%d; stopping chain\n", __func__, dec_rc, k);
730+
return;
731+
}
732+
733+
const llama_token best = common_sampler_sample(smpl, ctx_mtp, 0);
734+
common_sampler_accept(smpl, best, /*accept_grammar=*/ false);
735+
draft_tokens.push_back(best);
736+
cond_tok = best;
737+
++pos;
738+
}
739+
740+
last_n_drafted = (uint16_t) draft_tokens.size();
741+
}
742+
743+
void accept(uint16_t n_accepted) override {
744+
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_mtp), 0);
745+
const int32_t n_drafted_last = (int32_t) last_n_drafted;
746+
const int32_t n_to_drop = std::max(0, n_drafted_last - (int32_t) n_accepted - 1);
747+
if (pos_max < 0) {
748+
last_n_accepted = (int32_t) n_accepted;
749+
return;
750+
}
751+
if (n_to_drop > 0) {
752+
const llama_pos drop_from = pos_max - n_to_drop + 1;
753+
llama_memory_seq_rm(llama_get_memory(ctx_mtp), /*seq_id=*/ 0,
754+
/*p0=*/ drop_from, /*p1=*/ -1);
755+
}
756+
last_n_drafted = 0;
757+
last_n_accepted = (int32_t) n_accepted;
758+
}
759+
760+
int32_t n_max(const common_params_speculative & params) const override {
761+
return std::max(1, params.draft.n_max);
762+
}
763+
764+
int32_t n_min(const common_params_speculative & params) const override {
765+
return std::max(1, params.draft.n_min);
766+
}
767+
};
768+
602769
// state of self-speculation (simple implementation, not ngram-map)
603770
struct common_speculative_state_ngram_simple : public common_speculative_state {
604771
common_ngram_simple_config config;
@@ -952,6 +1119,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
9521119
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
9531120
case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft";
9541121
case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3";
1122+
case COMMON_SPECULATIVE_TYPE_MTP: return "mtp";
9551123
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
9561124
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
9571125
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
@@ -983,11 +1151,24 @@ common_speculative * common_speculative_init(
9831151
}
9841152
}
9851153

1154+
llama_context * ctx_mtp = nullptr;
1155+
if (params.type == COMMON_SPECULATIVE_TYPE_MTP) {
1156+
ctx_mtp = llama_init_from_model(params.mtp.model, params.mtp.cparams);
1157+
if (ctx_mtp == nullptr) {
1158+
LOG_ERR("%s", "failed to create MTP context\n");
1159+
if (ctx_dft) {
1160+
llama_free(ctx_dft);
1161+
}
1162+
return nullptr;
1163+
}
1164+
}
1165+
9861166
// Compute the implementations to use based on the config and their order of preference
9871167
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
9881168
{
9891169
bool has_draft = !params.draft.mparams.path.empty();
9901170
bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3
1171+
bool has_mtp = (params.type == COMMON_SPECULATIVE_TYPE_MTP) && (ctx_mtp != nullptr);
9911172

9921173
bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE);
9931174
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
@@ -1034,6 +1215,9 @@ common_speculative * common_speculative_init(
10341215
if (has_draft_eagle3) {
10351216
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params));
10361217
}
1218+
if (has_mtp) {
1219+
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_MTP, params));
1220+
}
10371221
}
10381222

10391223
std::vector<std::unique_ptr<common_speculative_state>> impls = {};
@@ -1058,6 +1242,11 @@ common_speculative * common_speculative_init(
10581242
impls.push_back(std::make_unique<common_speculative_state_eagle3>(config.type));
10591243
break;
10601244
}
1245+
case COMMON_SPECULATIVE_TYPE_MTP: {
1246+
impls.push_back(std::make_unique<common_speculative_state_mtp>(
1247+
config.type, ctx_tgt, ctx_mtp));
1248+
break;
1249+
}
10611250
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
10621251
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
10631252

convert_hf_to_gguf.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5521,13 +5521,70 @@ def set_gguf_parameters(self):
55215521
self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION)
55225522

55235523

5524+
class _Qwen35MtpMixin:
5525+
"""Shared MTP wiring for Qwen3.5/3.6 text variants. The HF config carries
5526+
the MTP block under `mtp_num_hidden_layers` and the tensors under
5527+
`mtp.*`; we extend block_count, emit the nextn metadata key, and remap
5528+
`mtp.*` to the standard layer-indexed nextn naming so the existing
5529+
tensor_map handles them."""
5530+
5531+
# Class-level annotations so the type checker understands the attributes
5532+
# available on the concrete subclasses in the MRO
5533+
hparams: dict[str, Any]
5534+
model_arch: gguf.MODEL_ARCH
5535+
gguf_writer: gguf.GGUFWriter
5536+
block_count: int
5537+
tensor_map: gguf.TensorNameMap
5538+
5539+
def __init__(self, *args, **kwargs):
5540+
super().__init__(*args, **kwargs)
5541+
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
5542+
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
5543+
5544+
def set_gguf_parameters(self):
5545+
super().set_gguf_parameters() # ty: ignore[unresolved-attribute]
5546+
if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
5547+
self.gguf_writer.add_nextn_predict_layers(n)
5548+
5549+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
5550+
# Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
5551+
if name.startswith("model.language_model."):
5552+
name = "model." + name[len("model.language_model."):]
5553+
elif name.startswith("language_model."):
5554+
name = name[len("language_model."):]
5555+
5556+
# Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
5557+
# HF: mtp.layers.0.* (transformer block at MTP slot 0)
5558+
# mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
5559+
if name.startswith("mtp."):
5560+
n_layer = self.hparams["num_hidden_layers"]
5561+
if name.find("layers.") != -1:
5562+
assert bid is not None
5563+
name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
5564+
else:
5565+
remapper = {
5566+
"mtp.fc": "model.layers.{bid}.eh_proj",
5567+
"mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
5568+
"mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
5569+
"mtp.norm": "model.layers.{bid}.shared_head.norm",
5570+
}
5571+
stem = Path(name).stem
5572+
suffix = Path(name).suffix
5573+
tmpl = remapper[stem] + suffix
5574+
for b in range(n_layer, self.block_count):
5575+
yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b) # ty: ignore[unresolved-attribute]
5576+
return
5577+
5578+
yield from super().modify_tensors(data_torch, name, bid) # ty: ignore[unresolved-attribute]
5579+
5580+
55245581
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
5525-
class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
5582+
class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
55265583
model_arch = gguf.MODEL_ARCH.QWEN35
55275584

55285585

55295586
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
5530-
class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
5587+
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
55315588
model_arch = gguf.MODEL_ARCH.QWEN35MOE
55325589

55335590

0 commit comments

Comments
 (0)