Skip to content

Commit a428b01

Browse files
committed
spec: support MTP
1 parent db8e326 commit a428b01

25 files changed

Lines changed: 1213 additions & 43 deletions

common/arg.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3550,12 +3550,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
35503550
}
35513551
).set_spec().set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_SPEC_DRAFT_MODEL"));
35523552
add_opt(common_arg(
3553-
{"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
3553+
{"--spec-type"}, "[none|mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
35543554
string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
35553555
common_speculative_type_to_str(params.speculative.type).c_str()),
35563556
[](common_params & params, const std::string & value) {
35573557
if (value == "none") {
35583558
params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE;
3559+
} else if (value == "mtp") {
3560+
params.speculative.type = COMMON_SPECULATIVE_TYPE_MTP;
35593561
} else if (value == "ngram-cache") {
35603562
params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE;
35613563
} else if (value == "ngram-simple") {

common/common.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ enum common_speculative_type {
159159
COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding
160160
COMMON_SPECULATIVE_TYPE_DRAFT, // draft model
161161
COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model
162+
COMMON_SPECULATIVE_TYPE_MTP, // multi-token prediction head loaded from the target GGUF
162163
COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding
163164
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only
164165
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values

common/speculative.cpp

Lines changed: 336 additions & 0 deletions
Large diffs are not rendered by default.

convert_hf_to_gguf.py

Lines changed: 59 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5518,13 +5518,70 @@ def set_gguf_parameters(self):
55185518
self.gguf_writer.add_rope_dimension_sections(self._QWEN35_DEFAULT_MROPE_SECTION)
55195519

55205520

5521+
class _Qwen35MtpMixin:
    """Shared MTP wiring for Qwen3.5/3.6 text variants. The HF config carries
    the MTP block under `mtp_num_hidden_layers` and the tensors under
    `mtp.*`; we extend block_count, emit the nextn metadata key, and remap
    `mtp.*` to the standard layer-indexed nextn naming so the existing
    tensor_map handles them."""

    # Class-level annotations so the type checker understands the attributes
    # available on the concrete subclasses in the MRO
    hparams: dict[str, Any]
    model_arch: gguf.MODEL_ARCH
    gguf_writer: gguf.GGUFWriter
    block_count: int
    tensor_map: gguf.TensorNameMap

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Extend the block count by the MTP depth so the tensor name map
        # accepts the extra layer indices used for the nextn tensors.
        self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)

    def set_gguf_parameters(self):
        """Emit base parameters plus the nextn_predict_layers key when the
        HF config declares an MTP block."""
        super().set_gguf_parameters()  # ty: ignore[unresolved-attribute]
        if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
            self.gguf_writer.add_nextn_predict_layers(n)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Strip multimodal prefixes and remap `mtp.*` tensors to the
        layer-indexed nextn naming before delegating to the base class."""
        # Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
        if name.startswith("model.language_model."):
            name = "model." + name[len("model.language_model."):]
        elif name.startswith("language_model."):
            name = name[len("language_model."):]

        # Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
        # HF: mtp.layers.0.* (transformer block at MTP slot 0)
        #     mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
        if name.startswith("mtp."):
            n_layer = self.hparams["num_hidden_layers"]
            if "layers." in name:
                assert bid is not None
                # Shift the MTP transformer block past the regular layers.
                name = name.replace(f"mtp.layers.{bid}", f"model.layers.{bid + n_layer}")
            else:
                remapper = {
                    "mtp.fc": "model.layers.{bid}.eh_proj",
                    "mtp.pre_fc_norm_embedding": "model.layers.{bid}.enorm",
                    "mtp.pre_fc_norm_hidden": "model.layers.{bid}.hnorm",
                    "mtp.norm": "model.layers.{bid}.shared_head.norm",
                }
                # e.g. "mtp.fc.weight" -> stem "mtp.fc", suffix ".weight"
                path = Path(name)
                tmpl = remapper[path.stem] + path.suffix
                # Shared (non-layer) MTP tensors are replicated into every
                # MTP slot appended after the regular layers.
                for b in range(n_layer, self.block_count):
                    yield from super().modify_tensors(data_torch, tmpl.format(bid=b), b)  # ty: ignore[unresolved-attribute]
                return

        yield from super().modify_tensors(data_torch, name, bid)  # ty: ignore[unresolved-attribute]
5576+
5577+
55215578
@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM")
5522-
class Qwen3_5TextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
5579+
class Qwen3_5TextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
55235580
model_arch = gguf.MODEL_ARCH.QWEN35
55245581

55255582

55265583
@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM")
5527-
class Qwen3_5MoeTextModel(_Qwen35MRopeMixin, _LinearAttentionVReorderBase):
5584+
class Qwen3_5MoeTextModel(_Qwen35MtpMixin, _Qwen35MRopeMixin, _LinearAttentionVReorderBase):
55285585
model_arch = gguf.MODEL_ARCH.QWEN35MOE
55295586

55305587

gguf-py/gguf/constants.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2109,7 +2109,14 @@ class MODEL_TENSOR(IntEnum):
21092109
MODEL_TENSOR.SSM_NORM,
21102110
MODEL_TENSOR.SSM_BETA,
21112111
MODEL_TENSOR.SSM_ALPHA,
2112-
MODEL_TENSOR.SSM_OUT
2112+
MODEL_TENSOR.SSM_OUT,
2113+
# NextN/MTP tensors - preserved but unused
2114+
MODEL_TENSOR.NEXTN_EH_PROJ,
2115+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
2116+
MODEL_TENSOR.NEXTN_ENORM,
2117+
MODEL_TENSOR.NEXTN_HNORM,
2118+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
2119+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
21132120
],
21142121
MODEL_ARCH.QWEN35MOE: [
21152122
MODEL_TENSOR.TOKEN_EMBD,
@@ -2140,7 +2147,14 @@ class MODEL_TENSOR(IntEnum):
21402147
MODEL_TENSOR.SSM_NORM,
21412148
MODEL_TENSOR.SSM_BETA,
21422149
MODEL_TENSOR.SSM_ALPHA,
2143-
MODEL_TENSOR.SSM_OUT
2150+
MODEL_TENSOR.SSM_OUT,
2151+
# NextN/MTP tensors - preserved but unused
2152+
MODEL_TENSOR.NEXTN_EH_PROJ,
2153+
MODEL_TENSOR.NEXTN_EMBED_TOKENS,
2154+
MODEL_TENSOR.NEXTN_ENORM,
2155+
MODEL_TENSOR.NEXTN_HNORM,
2156+
MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD,
2157+
MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM,
21442158
],
21452159
MODEL_ARCH.PLAMO: [
21462160
MODEL_TENSOR.TOKEN_EMBD,

include/llama.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,9 @@ extern "C" {
310310
// override key-value pairs of the model meta data
311311
const struct llama_model_kv_override * kv_overrides;
312312

313+
// override architecture from GGUF (e.g. load the MTP head of a Qwen3.5 GGUF as "qwen35_mtp")
314+
const char * override_arch;
315+
313316
// Keep the booleans together to avoid misalignment during copy-by-value.
314317
bool vocab_only; // only load the vocabulary, no weights
315318
bool use_mmap; // use mmap if possible

src/llama-arch.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
4141
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
4242
{ LLM_ARCH_QWEN35, "qwen35" },
4343
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
44+
{ LLM_ARCH_QWEN35_MTP, "qwen35_mtp" },
45+
{ LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" },
4446
{ LLM_ARCH_PHI2, "phi2" },
4547
{ LLM_ARCH_PHI3, "phi3" },
4648
{ LLM_ARCH_PHIMOE, "phimoe" },
@@ -757,14 +759,15 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
757759
{LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
758760
{LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
759761
{LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
760-
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
761-
// These tensors only exist in the last layer(s) and are treated as output tensors
762-
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
763-
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
764-
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
765-
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
766-
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
767-
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}},
762+
// NextN/MTP tensors are stored per-block (blk.%d.nextn.*) even though only the
763+
// last nextn_predict_layers blocks carry them. Classify as LAYER_REPEATING so
764+
// the model loader doesn't fault on the block index.
765+
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
766+
{LLM_TENSOR_NEXTN_EMBED_TOKENS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
767+
{LLM_TENSOR_NEXTN_ENORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_GET_ROWS}},
768+
{LLM_TENSOR_NEXTN_HNORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
769+
{LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
770+
{LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
768771
// Nemotron 3 Super
769772
{LLM_TENSOR_FFN_LATENT_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
770773
{LLM_TENSOR_FFN_LATENT_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},

src/llama-arch.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ enum llm_arch {
4545
LLM_ARCH_QWEN3VLMOE,
4646
LLM_ARCH_QWEN35,
4747
LLM_ARCH_QWEN35MOE,
48+
LLM_ARCH_QWEN35_MTP,
49+
LLM_ARCH_QWEN35MOE_MTP,
4850
LLM_ARCH_PHI2,
4951
LLM_ARCH_PHI3,
5052
LLM_ARCH_PHIMOE,

0 commit comments

Comments
 (0)