Skip to content

Commit a421d66

Browse files
authored
MTP: clean-up (#9)
* MTP: clean-up
* review: use llama_context_type instead of llama_graph_type
* review: remove llama_model_has_mtp
* review: fix convert issues
* convert: fix pycheck
* review: formatting
* use `mtp-` for identifying mtp models
* convert: fix mtp conversion
1 parent ebe4fca commit a421d66

20 files changed

Lines changed: 689 additions & 648 deletions

common/arg.cpp

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -335,11 +335,15 @@ static bool common_params_handle_remote_preset(common_params & params, llama_exa
335335
struct handle_model_result {
336336
bool found_mmproj = false;
337337
common_params_model mmproj;
338+
339+
bool found_mtp = false;
340+
common_params_model mtp;
338341
};
339342

340343
static handle_model_result common_params_handle_model(struct common_params_model & model,
341344
const std::string & bearer_token,
342-
bool offline) {
345+
bool offline,
346+
bool search_mtp = false) {
343347
handle_model_result result;
344348

345349
if (!model.docker_repo.empty()) {
@@ -354,7 +358,7 @@ static handle_model_result common_params_handle_model(struct common_params_model
354358
common_download_opts opts;
355359
opts.bearer_token = bearer_token;
356360
opts.offline = offline;
357-
auto download_result = common_download_model(model, opts, true);
361+
auto download_result = common_download_model(model, opts, true, search_mtp);
358362

359363
if (download_result.model_path.empty()) {
360364
LOG_ERR("error: failed to download model from Hugging Face\n");
@@ -368,6 +372,11 @@ static handle_model_result common_params_handle_model(struct common_params_model
368372
result.found_mmproj = true;
369373
result.mmproj.path = download_result.mmproj_path;
370374
}
375+
376+
if (!download_result.mtp_path.empty()) {
377+
result.found_mtp = true;
378+
result.mtp.path = download_result.mtp_path;
379+
}
371380
} else if (!model.url.empty()) {
372381
if (model.path.empty()) {
373382
auto f = string_split<std::string>(model.url, '#').front();
@@ -588,7 +597,11 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
588597

589598
// handle model and download
590599
if (!skip_model_download) {
591-
auto res = common_params_handle_model(params.model, params.hf_token, params.offline);
600+
const bool spec_type_mtp = std::find(params.speculative.types.begin(),
601+
params.speculative.types.end(),
602+
COMMON_SPECULATIVE_TYPE_MTP) != params.speculative.types.end();
603+
604+
auto res = common_params_handle_model(params.model, params.hf_token, params.offline, spec_type_mtp);
592605
if (params.no_mmproj) {
593606
params.mmproj = {};
594607
} else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) {
@@ -602,6 +615,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context
602615
break;
603616
}
604617
}
618+
// when --spec-type mtp is set and no draft model was provided explicitly,
619+
// fall back to the MTP head discovered alongside the -hf model
620+
if (spec_type_mtp && res.found_mtp &&
621+
params.speculative.draft.mparams.path.empty() &&
622+
params.speculative.draft.mparams.hf_repo.empty() &&
623+
params.speculative.draft.mparams.url.empty()) {
624+
params.speculative.draft.mparams.path = res.mtp.path;
625+
}
605626
common_params_handle_model(params.speculative.draft.mparams, params.hf_token, params.offline);
606627
common_params_handle_model(params.vocoder.model, params.hf_token, params.offline);
607628
}

common/download.cpp

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -566,8 +566,11 @@ static hf_cache::hf_files get_split_files(const hf_cache::hf_files & files,
566566
return result;
567567
}
568568

569-
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
570-
const std::string & model) {
569+
// pick the best sibling GGUF whose filename contains `keyword` (e.g. "mmproj" / "mtp"),
570+
// preferring deeper shared directory prefix with the model, then closest quantization
571+
static hf_cache::hf_file find_best_sibling(const hf_cache::hf_files & files,
572+
const std::string & model,
573+
const std::string & keyword) {
571574
hf_cache::hf_file best;
572575
size_t best_depth = 0;
573576
int best_diff = 0;
@@ -579,20 +582,20 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
579582

580583
for (const auto & f : files) {
581584
if (!string_ends_with(f.path, ".gguf") ||
582-
f.path.find("mmproj") == std::string::npos) {
585+
f.path.find(keyword) == std::string::npos) {
583586
continue;
584587
}
585588

586-
auto mmproj_parts = string_split<std::string>(f.path, '/');
587-
auto mmproj_dir = mmproj_parts.end() - 1;
589+
auto sib_parts = string_split<std::string>(f.path, '/');
590+
auto sib_dir = sib_parts.end() - 1;
588591

589592
auto [_, dir] = std::mismatch(model_parts.begin(), model_dir,
590-
mmproj_parts.begin(), mmproj_dir);
591-
if (dir != mmproj_dir) {
593+
sib_parts.begin(), sib_dir);
594+
if (dir != sib_dir) {
592595
continue;
593596
}
594597

595-
size_t depth = dir - mmproj_parts.begin();
598+
size_t depth = dir - sib_parts.begin();
596599
auto bits = extract_quant_bits(f.path);
597600
auto diff = std::abs(bits - model_bits);
598601

@@ -606,6 +609,16 @@ static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
606609
return best;
607610
}
608611

612+
static hf_cache::hf_file find_best_mmproj(const hf_cache::hf_files & files,
613+
const std::string & model) {
614+
return find_best_sibling(files, model, "mmproj");
615+
}
616+
617+
static hf_cache::hf_file find_best_mtp(const hf_cache::hf_files & files,
618+
const std::string & model) {
619+
return find_best_sibling(files, model, "mtp-");
620+
}
621+
609622
static bool gguf_filename_is_model(const std::string & filepath) {
610623
if (!string_ends_with(filepath, ".gguf")) {
611624
return false;
@@ -617,7 +630,8 @@ static bool gguf_filename_is_model(const std::string & filepath) {
617630
}
618631

619632
return filename.find("mmproj") == std::string::npos &&
620-
filename.find("imatrix") == std::string::npos;
633+
filename.find("imatrix") == std::string::npos &&
634+
filename.find("mtp-") == std::string::npos;
621635
}
622636

623637
static hf_cache::hf_file find_best_model(const hf_cache::hf_files & files,
@@ -673,11 +687,13 @@ struct hf_plan {
673687
hf_cache::hf_file primary;
674688
hf_cache::hf_files model_files;
675689
hf_cache::hf_file mmproj;
690+
hf_cache::hf_file mtp;
676691
};
677692

678693
static hf_plan get_hf_plan(const common_params_model & model,
679694
const common_download_opts & opts,
680-
bool download_mmproj) {
695+
bool download_mmproj,
696+
bool download_mtp) {
681697
hf_plan plan;
682698
hf_cache::hf_files all;
683699

@@ -723,6 +739,10 @@ static hf_plan get_hf_plan(const common_params_model & model,
723739
plan.mmproj = find_best_mmproj(all, primary.path);
724740
}
725741

742+
if (download_mtp) {
743+
plan.mtp = find_best_mtp(all, primary.path);
744+
}
745+
726746
return plan;
727747
}
728748

@@ -756,21 +776,25 @@ static std::vector<download_task> get_url_tasks(const common_params_model & mode
756776

757777
common_download_model_result common_download_model(const common_params_model & model,
758778
const common_download_opts & opts,
759-
bool download_mmproj) {
779+
bool download_mmproj,
780+
bool download_mtp) {
760781
common_download_model_result result;
761782
std::vector<download_task> tasks;
762783
hf_plan hf;
763784

764785
bool is_hf = !model.hf_repo.empty();
765786

766787
if (is_hf) {
767-
hf = get_hf_plan(model, opts, download_mmproj);
788+
hf = get_hf_plan(model, opts, download_mmproj, download_mtp);
768789
for (const auto & f : hf.model_files) {
769790
tasks.push_back({f.url, f.local_path});
770791
}
771792
if (!hf.mmproj.path.empty()) {
772793
tasks.push_back({hf.mmproj.url, hf.mmproj.local_path});
773794
}
795+
if (!hf.mtp.path.empty()) {
796+
tasks.push_back({hf.mtp.url, hf.mtp.local_path});
797+
}
774798
} else if (!model.url.empty()) {
775799
tasks = get_url_tasks(model);
776800
} else {
@@ -807,6 +831,10 @@ common_download_model_result common_download_model(const common_params_model &
807831
if (!hf.mmproj.path.empty()) {
808832
result.mmproj_path = hf_cache::finalize_file(hf.mmproj);
809833
}
834+
835+
if (!hf.mtp.path.empty()) {
836+
result.mtp_path = hf_cache::finalize_file(hf.mtp);
837+
}
810838
} else {
811839
result.model_path = model.path;
812840
}
@@ -946,7 +974,8 @@ std::vector<common_cached_model_info> common_list_cached_models() {
946974
for (const auto & f : files) {
947975
auto split = get_gguf_split_info(f.path);
948976
if (split.index != 1 || split.tag.empty() ||
949-
split.prefix.find("mmproj") != std::string::npos) {
977+
split.prefix.find("mmproj") != std::string::npos ||
978+
split.prefix.find("MTP") != std::string::npos) {
950979
continue;
951980
}
952981
if (seen.insert(f.repo_id + ":" + split.tag).second) {

common/download.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ struct common_download_opts {
5959
struct common_download_model_result {
6060
std::string model_path;
6161
std::string mmproj_path;
62+
std::string mtp_path;
6263
};
6364

6465
// Download model from HuggingFace repo or URL
@@ -83,12 +84,14 @@ struct common_download_model_result {
8384
// when opts.offline=true, no network requests are made
8485
// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory
8586
// then with the closest quantization bits
87+
// when download_mtp=true, applies the same sibling search for an MTP-head GGUF
8688
//
87-
// returns result with model_path and mmproj_path (empty on failure)
89+
// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure)
8890
common_download_model_result common_download_model(
8991
const common_params_model & model,
9092
const common_download_opts & opts = {},
91-
bool download_mmproj = false
93+
bool download_mmproj = false,
94+
bool download_mtp = false
9295
);
9396

9497
// returns list of cached models

common/speculative.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1198,7 +1198,7 @@ common_speculative * common_speculative_init(common_params_speculative & params,
11981198
LOG_WRN("%s: draft model is not specified - cannot use 'draft' type\n", __func__);
11991199
has_draft = false;
12001200
}
1201-
} else if (has_draft_model) {
1201+
} else if (has_draft_model && !has_mtp && !has_draft_eagle3) {
12021202
LOG_WRN("%s: draft model is specified but 'draft' speculative type is not explicitly enabled - enabling it\n", __func__);
12031203
has_draft = true;
12041204
}

convert_hf_to_gguf.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class ModelBase:
9595
gguf_writer: gguf.GGUFWriter
9696
model_name: str | None
9797
metadata_override: Path | None
98+
metadata: gguf.Metadata
9899
dir_model_card: Path
99100
remote_hf_model_id: str | None
100101

@@ -5560,26 +5561,59 @@ class _Qwen35MtpMixin:
55605561
block_count: int
55615562
tensor_map: gguf.TensorNameMap
55625563

5564+
mtp_only: bool = False
5565+
no_mtp: bool = False
5566+
55635567
def __init__(self, *args, **kwargs):
55645568
super().__init__(*args, **kwargs)
5565-
self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("mtp_num_hidden_layers", 0)
5569+
self.block_count = self.hparams["num_hidden_layers"]
5570+
if not self.no_mtp:
5571+
self.block_count += self.hparams.get("mtp_num_hidden_layers", 0)
55665572
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
55675573

5574+
@classmethod
5575+
def filter_tensors(cls, item):
5576+
name, _ = item
5577+
if name.startswith("mtp."):
5578+
if cls.no_mtp:
5579+
return None
5580+
return item
5581+
if cls.mtp_only:
5582+
# In --mtp mode, drop trunk weights and keep only the shared embeddings/output
5583+
# tensors that the standalone MTP graph references at inference time.
5584+
canonical = name.replace("language_model.", "")
5585+
keep = canonical in (
5586+
"model.embed_tokens.weight", "model.norm.weight", "lm_head.weight",
5587+
"embed_tokens.weight", "norm.weight",
5588+
)
5589+
if not keep:
5590+
return None
5591+
return super().filter_tensors(item) # ty: ignore[unresolved-attribute]
5592+
55685593
def set_gguf_parameters(self):
55695594
super().set_gguf_parameters() # ty: ignore[unresolved-attribute]
5595+
if self.no_mtp:
5596+
return
55705597
if (n := self.hparams.get("mtp_num_hidden_layers", 0)) > 0:
55715598
self.gguf_writer.add_nextn_predict_layers(n)
55725599

5573-
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
5574-
# Multimodal Qwen3.5/3.6 wrap the text model under `model.language_model.*`.
5575-
if name.startswith("model.language_model."):
5576-
name = "model." + name[len("model.language_model."):]
5577-
elif name.startswith("language_model."):
5578-
name = name[len("language_model."):]
5600+
def prepare_metadata(self, vocab_only: bool):
5601+
# TextModel.prepare_metadata resolves a directory fname_out into a concrete
5602+
# file path, so snapshot is_dir() first to decide whether to apply the mtp- prefix.
5603+
from_dir = self.fname_out.is_dir()
5604+
super().prepare_metadata(vocab_only=vocab_only) # ty: ignore[unresolved-attribute]
5605+
5606+
if not self.mtp_only or not from_dir:
5607+
return
5608+
5609+
output_type: str = self.ftype.name.partition("_")[2] # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
5610+
fname_default: str = gguf.naming_convention(
5611+
self.metadata.name, self.metadata.basename, self.metadata.finetune, # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
5612+
self.metadata.version, size_label=None, output_type=output_type, model_type=None) # pyright: ignore[reportAttributeAccessIssue] # ty: ignore[unresolved-attribute]
5613+
self.fname_out = self.fname_out.parent / f"mtp-{fname_default}.gguf"
55795614

5615+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
55805616
# Remap MTP block tensors to llama.cpp's layer-indexed nextn naming.
5581-
# HF: mtp.layers.0.* (transformer block at MTP slot 0)
5582-
# mtp.fc / mtp.pre_fc_norm_embedding / mtp.pre_fc_norm_hidden / mtp.norm
55835617
if name.startswith("mtp."):
55845618
n_layer = self.hparams["num_hidden_layers"]
55855619
if name.find("layers.") != -1:
@@ -14034,6 +14068,14 @@ def parse_args() -> argparse.Namespace:
1403414068
"--mmproj", action="store_true",
1403514069
help="(Experimental) Export multimodal projector (mmproj) for vision models. This will only work on some vision models. A prefix 'mmproj-' will be added to the output file name.",
1403614070
)
14071+
parser.add_argument(
14072+
"--mtp", action="store_true",
14073+
help="(Experimental) Export only the multi-token prediction (MTP) head as a separate GGUF, suitable for use as a speculative draft. Output file name will get a '-MTP' suffix.",
14074+
)
14075+
parser.add_argument(
14076+
"--no-mtp", action="store_true",
14077+
help="(Experimental) Exclude the multi-token prediction (MTP) head from the converted GGUF. Pair with --mtp on a second run to publish trunk and MTP as two files. Note: the split form duplicates embeddings, so the bundled default is more space-efficient overall.",
14078+
)
1403714079
parser.add_argument(
1403814080
"--mistral-format", action="store_true",
1403914081
help="Whether the model is stored following the Mistral format.",
@@ -14193,6 +14235,20 @@ def main() -> None:
1419314235
else:
1419414236
model_class = MistralModel
1419514237

14238+
if args.mtp and args.no_mtp:
14239+
logger.error("--mtp and --no-mtp are mutually exclusive")
14240+
sys.exit(1)
14241+
14242+
if (args.mtp or args.no_mtp) and not issubclass(model_class, _Qwen35MtpMixin):
14243+
logger.error("--mtp / --no-mtp are only supported for Qwen3.5/3.6 text variants today")
14244+
sys.exit(1)
14245+
14246+
# set on the class so __init__ / filter_tensors see the correct mode
14247+
if args.no_mtp:
14248+
model_class.no_mtp = True # ty: ignore[unresolved-attribute]
14249+
if args.mtp:
14250+
model_class.mtp_only = True # ty: ignore[unresolved-attribute]
14251+
1419614252
model_instance = model_class(dir_model, output_type, fname_out,
1419714253
is_big_endian=args.bigendian, use_temp_file=args.use_temp_file,
1419814254
eager=args.no_lazy,

include/llama.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,11 @@ extern "C" {
198198
LLAMA_SPLIT_MODE_TENSOR = 3,
199199
};
200200

201+
enum llama_context_type {
202+
LLAMA_CONTEXT_TYPE_DEFAULT = 0,
203+
LLAMA_CONTEXT_TYPE_MTP = 1,
204+
};
205+
201206
// TODO: simplify (https://github.com/ggml-org/llama.cpp/pull/9294#pullrequestreview-2286561979)
202207
typedef struct llama_token_data {
203208
llama_token id; // token id
@@ -339,6 +344,7 @@ extern "C" {
339344
int32_t n_threads; // number of threads to use for generation
340345
int32_t n_threads_batch; // number of threads to use for batch processing
341346

347+
enum llama_context_type ctx_type; // set the context type (e.g. MTP)
342348
enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
343349
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
344350
enum llama_attention_type attention_type; // attention type to use for embeddings

src/llama-arch.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
4141
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
4242
{ LLM_ARCH_QWEN35, "qwen35" },
4343
{ LLM_ARCH_QWEN35MOE, "qwen35moe" },
44-
{ LLM_ARCH_QWEN35_MTP, "qwen35_mtp" },
45-
{ LLM_ARCH_QWEN35MOE_MTP, "qwen35moe_mtp" },
4644
{ LLM_ARCH_PHI2, "phi2" },
4745
{ LLM_ARCH_PHI3, "phi3" },
4846
{ LLM_ARCH_PHIMOE, "phimoe" },

src/llama-arch.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ enum llm_arch {
4545
LLM_ARCH_QWEN3VLMOE,
4646
LLM_ARCH_QWEN35,
4747
LLM_ARCH_QWEN35MOE,
48-
LLM_ARCH_QWEN35_MTP,
49-
LLM_ARCH_QWEN35MOE_MTP,
5048
LLM_ARCH_PHI2,
5149
LLM_ARCH_PHI3,
5250
LLM_ARCH_PHIMOE,

0 commit comments

Comments
 (0)