Skip to content

Commit 2c40559

Browse files
committed
MTP: clean-up (#9)
* MTP: clean-up * review: use llama_context_type instead of llama_graph_type * review: remove llama_model_has_mtp * review: fix convert issues * convert: fix pycheck * review: formatting * use `mtp-` for identifying mtp models * convert: fix mtp conversion
1 parent c871587 commit 2c40559

20 files changed

Lines changed: 689 additions & 648 deletions

common/arg.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
335335
struct handle_model_result {
336336
bool found_mmproj = false;
337337
common_params_model mmproj;
338+
339+
bool found_mtp = false;
340+
common_params_model mtp;
338341
};
339342

340343
static handle_model_result common_params_handle_model(struct common_params_model & model,
341344
const std::string & bearer_token,
342-
bool offline) {
345+
bool offline,
346+
bool search_mtp = false) {
343347
handle_model_result result;
344348

345349
if (!model.docker_repo.empty()) {
@@ -354,7 +358,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
354358
common_download_opts opts;
355359
opts.bearer_token = bearer_token;
356360
opts.offline = offline;
357-
auto download_result = common_download_model(model, opts, true);
361+
auto download_result = common_download_model(model, opts, true, search_mtp);
358362

359363
if (download_result.model_path.empty()) {
360364
LOG_ERR("error: failed to download model from Hugging Face\n");
@@ -368,6 +372,11 @@ static handle_model_result common_params_handle_model(struct common_params_model
368372
result.found_mmproj = true;
369373
result.mmproj.path = download_result.mmproj_path;
370374
}
375+
376+
if (!download_result.mtp_path.empty()) {
377+
result.found_mtp = true;
378+
result.mtp.path = download_result.mtp_path;
379+
}
371380
} else if (!model.url.empty()) {
372381
if (model.path.empty()) {
373382
auto f = string_split<std::string>(model.url, '#').front();
@@ -436,7 +445,11 @@ static bool parse_bool_value(const std::string & value) {
436445
//
437446

438447
void common_params_handle_models(common_params & params, llama_example curr_ex) {
439-
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
448+
const bool spec_type_mtp = std::find(params.speculative.types.begin(),
449+
params.speculative.types.end(),
450+
COMMON_SPECULATIVE_TYPE_MTP) != params.speculative.types.end();
451+
452+
auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_mtp);
440453
if (params.no_mmproj) {
441454
params.mmproj = {};
442455
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -450,6 +463,14 @@ void common_params_handle_models(common_params & params, llama_example curr_ex)
450463
break;
451464
}
452465
}
466+
// when --spec-type mtp is set and no draft model was provided explicitly,
467+
// fall back to the MTP head discovered alongside the -hf model
468+
if (spec_type_mtp && res.found_mtp &&
469+
params.speculative.draft.mparams.path.empty() &&
470+
params.speculative.draft.mparams.hf_repo.empty() &&
471+
params.speculative.draft.mparams.url.empty()) {
472+
params.speculative.draft.mparams.path = res.mtp.path;
473+
}
453474
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
454475
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
455476
}

common/download.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
566566
return result;
567567
}
568568

569-
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
570-
const std::string & model) {
569+
// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "mtp"),
570+
// preferring deeper shared directory prefix with the model, then closest quantization
571+
static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files,
572+
const std::string & model,
573+
const std::string & keyword) {
571574
hf_cache::hf_file best;
572575
size_t best_depth = 0;
573576
int best_diff = 0;
@@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
579582

580583
for (const auto & f : files) {
581584
if (!string_ends_with(f.path, ".gguf") ||
582-
f.path.find("mmproj") == std::string::npos) {
585+
f.path.find(keyword) == std::string::npos) {
583586
continue;
584587
}
585588

586-
auto mmproj_parts = string_split<std::string>(f.path, '/');
587-
auto mmproj_dir = mmproj_parts.end() - 1;
589+
auto sib_parts = string_split<std::string>(f.path, '/');
590+
auto sib_dir = sib_parts.end() - 1;
588591

589592
auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
590-
mmproj_parts.begin(), mmproj_dir);
591-
if (dir != mmproj_dir) {
593+
sib_parts.begin(), sib_dir);
594+
if (dir != sib_dir) {
592595
continue;
593596
}
594597

595-
size_t depth = dir - mmproj_parts.begin();
598+
size_t depth = dir - sib_parts.begin();
596599
auto bits = extract_quant_bits(f.path);
597600
auto diff = std::abs(bits - model_bits);
598601

@@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
606609
return best;
607610
}
608611

612+
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
613+
const std::string & model) {
614+
return find_best_sibling(files, model, "mmproj");
615+
}
616+
617+
static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files,
618+
const std::string & model) {
619+
return find_best_sibling(files, model, "mtp-");
620+
}
621+
609622
static bool gguf_filename_is_model(const std::string & filepath) {
610623
if (!string_ends_with(filepath, ".gguf")) {
611624
return false;
@@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) {
617630
}
618631

619632
return filename.find("mmproj") == std::string::npos &&
620-
filename.find("imatrix") == std::string::npos;
633+
filename.find("imatrix") == std::string::npos &&
634+
filename.find("mtp-") == std::string::npos;
621635
}
622636

623637
static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
@@ -673,11 +687,13 @@ struct hf_plan {
673687
hf_cache::hf_file primary;
674688
hf_cache::hf_files model_files;
675689
hf_cache::hf_file mmproj;
690+
hf_cache::hf_file mtp;
676691
};
677692

678693
static hf_plan get_hf_plan(const common_params_model & model,
679694
const common_download_opts & opts,
680-
bool download_mmproj) {
695+
bool download_mmproj,
696+
bool download_mtp) {
681697
hf_plan plan;
682698
hf_cache::hf_files all;
683699

@@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model,
723739
plan.mmproj = find_best_mmproj(all, primary.path);
724740
}
725741

742+
if (download_mtp) {
743+
plan.mtp = find_best_mtp(all, primary.path);
744+
}
745+
726746
return plan;
727747
}
728748

@@ -756,21 +776,25 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
756776

757777
common_download_model_result common_download_model(const common_params_model & model,
758778
const common_download_opts & opts,
759-
bool download_mmproj) {
779+
bool download_mmproj,
780+
bool download_mtp) {
760781
common_download_model_result result;
761782
std::vector<download_task> tasks;
762783
hf_plan hf;
763784

764785
bool is_hf = !model.hf_repo.empty();
765786

766787
if (is_hf) {
767-
hf = get_hf_plan(model, opts, download_mmproj);
788+
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
768789
for (const auto & f : hf.model_files) {
769790
tasks.push_back({f.url, f.local_path});
770791
}
771792
if (!hf.mmproj.path.empty()) {
772793
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
773794
}
795+
if (!hf.mtp.path.empty()) {
796+
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
797+
}
774798
} else if (!model.url.empty()) {
775799
tasks = get_url_tasks(model);
776800
} else {
@@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model &
807831
if (!hf.mmproj.path.empty()) {
808832
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
809833
}
834+
835+
if (!hf.mtp.path.empty()) {
836+
result.mtp_path = hf_cache::finalize_file(hf.mtp);
837+
}
810838
} else {
811839
result.model_path = model.path;
812840
}
@@ -946,7 +974,8 @@ std::vector<common_cached_model_info> common_list_cached_models() {
946974
for (const auto & f : files) {
947975
auto split = get_gguf_split_info(f.path);
948976
if (split.index != 1 || split.tag.empty() ||
949-
split.prefix.find("mmproj") != std::string::npos) {
977+
split.prefix.find("mmproj") != std::string::npos ||
978+
split.prefix.find("MTP") != std::string::npos) {
950979
continue;
951980
}
952981
if (seen.insert(f.repo_id + ":" + split.tag).second) {

common/download.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct common_download_opts {
5959
struct common_download_model_result {
6060
std::string model_path;
6161
std::string mmproj_path;
62+
std::string mtp_path;
6263
};
6364

6465
// Download model from HuggingFace repo or URL
@@ -83,12 +84,14 @@ struct common_download_model_result {
8384
// when opts.offline=true, no network requests are made
8485
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
8586
// then with the closest quantization bits
87+
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
8688
//
87-
// returns result with model_path and mmproj_path (empty on failure)
89+
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
8890
common_download_model_result common_download_model(
8991
const common_params_model & model,
9092
const common_download_opts & opts = {},
91-
bool download_mmproj = false
93+
bool download_mmproj = false,
94+
bool download_mtp = false
9295
);
9396

9497
// returns list of cached models

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1198,7 +1198,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
11981198
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
11991199
has_draft = false;
12001200
}
1201-
} else if (has_draft_model) {
1201+
} else if (has_draft_model && !has_mtp && !has_draft_eagle3) {
12021202
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
12031203
has_draft = true;
12041204
}

convert_hf_to_gguf.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ModelBase:
9595
gguf_writer: gguf.GGUFWriter
9696
model_name: str | None
9797
metadata_override: Path | None
98+
metadata: gguf.Metadata
9899
dir_model_card: Path
99100
remote_hf_model_id: str | None
100101

@@ -5564,26 +5565,59 @@ class _Qwen35MtpMixin:
55645565
block_count: int
55655566
tensor_map: gguf.TensorNameMap
55665567

5568+
mtp_only: bool = False
5569+
no_mtp: bool = False
5570+
55675571
def __init__(self, *args, **kwargs):
55685572
super().__init__(*args, **kwargs)
5569-
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
5573+
self.block_count = self.hparams["num_hidden_layers"]
5574+
if not self.no_mtp:
5575+
self.block_count += self.hparams.get("mtp_num_hidden_layers", 0)
55705576
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
55715577

5578+
@classmethod
5579+
def filter_tensors(cls, item):
5580+
name, _ = item
5581+
if name.startswith("mtp."):
5582+
if cls.no_mtp:
5583+
return None
5584+
return item
5585+
if cls.mtp_only:
5586+
# In --mtp mode, drop trunk weights and keep only the shared embeddings/output
5587+
# tensors that the standalone MTP graph references at inference time.
5588+
canonical = name.replace("language_model.", "")
5589+
keep = canonical in (
5590+
"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight",
5591+
"embed_tokens.weight", "norm.weight",
5592+
)
5593+
if not keep:
5594+
return None
5595+
return super().filter_tensors(item) # ty: ignore[unresolved-attribute]
5596+
55725597
def set_gguf_parameters(self):
55735598
super().set_gguf_parameters() # ty: ignore[unresolved-attribute]
5599+
if self.no_mtp:
5600+
return
55745601
if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
55755602
self.gguf_writer.add_nextn_predict_layers(n)
55765603

5577-
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
5578-
# Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
5579-
if name.startswith("model.language_model."):
5580-
name = "model." + name[len("model.language_model."):]
5581-
elif name.startswith("language_model."):
5582-
name = name[len("language_model."):]
5604+
def prepare_metadata(self, vocab_only: bool):
5605+
# TextModel.prepare_metadata resolves a directory fname_out into a concrete
5606+
# file path, so snapshot is_dir() first to decide whether to apply the mtp- prefix.
5607+
from_dir = self.fname_out.is_dir()
5608+
super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute]
5609+
5610+
if not self.mtp_only or not from_dir:
5611+
return
55835612

5613+
output_type: str = self.ftype.name.partition("_")[2] # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
5614+
fname_default: str = gguf.naming_convention(
5615+
self.metadata.name, self.metadata.basename, self.metadata.finetune, # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
5616+
self.metadata.version, size_label=None, output_type=output_type, model_type=None) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
5617+
self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"
5618+
5619+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
55845620
# Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
5585-
# HF: mtp.layers.0.* (transformer block at MTP slot 0)
5586-
# mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
55875621
if name.startswith("mtp."):
55885622
n_layer = self.hparams["num_hidden_layers"]
55895623
if name.find("layers.") != -1:
@@ -14109,6 +14143,14 @@ def parse_args() -> argparse.Namespace:
1410914143
"--mmproj", action="store_true",
1411014144
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
1411114145
)
14146+
parser.add_argument(
14147+
"--mtp", action="store_true",
14148+
help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.",
14149+
)
14150+
parser.add_argument(
14151+
"--no-mtp", action="store_true",
14152+
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
14153+
)
1411214154
parser.add_argument(
1411314155
"--mistral-format", action="store_true",
1411414156
help="Whether the model is stored following the Mistral format.",
@@ -14268,6 +14310,20 @@ def main() -> None:
1426814310
else:
1426914311
model_class = MistralModel
1427014312

14313+
if args.mtp and args.no_mtp:
14314+
logger.error("--mtp and --no-mtp are mutually exclusive")
14315+
sys.exit(1)
14316+
14317+
if (args.mtp or args.no_mtp) and not issubclass(model_class, _Qwen35MtpMixin):
14318+
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today")
14319+
sys.exit(1)
14320+
14321+
# set on the class so __init__ / filter_tensors see the correct mode
14322+
if args.no_mtp:
14323+
model_class.no_mtp = True # ty: ignore[unresolved-attribute]
14324+
if args.mtp:
14325+
model_class.mtp_only = True # ty: ignore[unresolved-attribute]
14326+
1427114327
model_instance = model_class(dir_model, output_type, fname_out,
1427214328
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
1427314329
eager=args.no_lazy,

include/llama.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ extern "C" {
198198
LLAMA_SPLIT_MODE_TENSOR = 3,
199199
};
200200

201+
enum llama_context_type {
202+
LLAMA_CONTEXT_TYPE_DEFAULT = 0,
203+
LLAMA_CONTEXT_TYPE_MTP = 1,
204+
};
205+
201206
// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
202207
typedef struct llama_token_data {
203208
llama_token id; // token id
@@ -339,6 +344,7 @@ extern "C" {
339344
int32_t n_threads; // number of threads to use for generation
340345
int32_t n_threads_batch; // number of threads to use for batch processing
341346

347+
enum llama_context_type ctx_type; // set the context type (e.g. MTP)
342348
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
343349
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
344350
enum llama_attention_type attention_type; // attention type to use for embeddings

src/llama-arch.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
4141
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
4242
{ LLM_ARCH_QWEN35, "qwen35" },
4343
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
44-
{ LLM_ARCH_QWEN35_MTP, "qwen35_mtp" },
45-
{ LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" },
4644
{ LLM_ARCH_PHI2, "phi2" },
4745
{ LLM_ARCH_PHI3, "phi3" },
4846
{ LLM_ARCH_PHIMOE, "phimoe" },

src/llama-arch.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ enum llm_arch {
4545
LLM_ARCH_QWEN3VLMOE,
4646
LLM_ARCH_QWEN35,
4747
LLM_ARCH_QWEN35MOE,
48-
LLM_ARCH_QWEN35_MTP,
49-
LLM_ARCH_QWEN35MOE_MTP,
5048
LLM_ARCH_PHI2,
5149
LLM_ARCH_PHI3,
5250
LLM_ARCH_PHIMOE,

0 commit comments

Comments
 (0)