From 388098cf0ea2e5957e0214d2353a04a2db6873f1 Mon Sep 17 00:00:00 2001
From: dusterbloom <32869278+dusterbloom@users.noreply.github.com>
Date: Thu, 21 May 2026 14:11:46 +0200
Subject: [PATCH] refactor(mtp): MtpSource enum + auto-detect MTP tensors

Per @howard0su's review on #237 (lines 57, 59):
- 57: 'shall we use the unsloth single-file MTP-in-target GGUF?'
- 59: 'why not a enum?'

Replaces:
    const char * mtp_gguf_path = nullptr;
    const char * mtp_draft_source = nullptr; // "chain" | "mtp_topk"

with:
    enum class MtpSource { None, Native, ExternalDrafter, Auto };
    MtpSource mtp_source = MtpSource::None;
    const char * mtp_gguf_path = nullptr; // only for ExternalDrafter
    bool mtp_use_topk = false; // false=chain, true=mtp_topk

Adds gguf_contains_mtp_tensors() probe (keyed on qwen35.nextn_predict_layers
metadata) so --mtp-gguf becomes optional when the primary GGUF embeds MTP
tensors (unsloth single-file case). Stacked on #237.

dflash_server arg parsing updated to:
- --mtp-source [none|native|external|auto] (new explicit flag)
- --mtp-gguf PATH (now optional; only needed for ExternalDrafter)
- Old --mtp-draft-source string flag now warns and is ignored (migration aid)
- --mtp-gamma alone triggers Auto detection

All test_common_mtp_orchestrator tests still pass (mock-based, unaffected by
the config-surface change).
--- dflash/src/common/backend_factory.cpp | 54 ++++++++++++++++++- dflash/src/common/backend_factory.h | 30 +++++++++-- dflash/src/qwen35/qwen35_backend.cpp | 4 +- dflash/src/qwen35/qwen35_backend.h | 10 ++-- dflash/src/qwen35/qwen35_daemon.cpp | 22 +++++++- dflash/src/qwen35/qwen35_daemon.h | 9 ++-- dflash/src/server/server_main.cpp | 74 ++++++++++++++++++++++----- dflash/test/test_dflash.cpp | 38 +++++--------- 8 files changed, 188 insertions(+), 53 deletions(-) diff --git a/dflash/src/common/backend_factory.cpp b/dflash/src/common/backend_factory.cpp index 6dac51ef..2621ed05 100644 --- a/dflash/src/common/backend_factory.cpp +++ b/dflash/src/common/backend_factory.cpp @@ -8,6 +8,8 @@ #include "qwen3_backend.h" #include "gemma4_backend.h" +#include "gguf.h" + #include namespace dflash27b { @@ -17,6 +19,26 @@ std::string detect_arch(const char * model_path) { return info.arch; } +bool gguf_contains_mtp_tensors(const std::string & path) { + gguf_init_params gp{}; + gp.no_alloc = true; + gp.ctx = nullptr; + gguf_context * gguf = gguf_init_from_file(path.c_str(), gp); + if (!gguf) return false; + + // MTP-capable GGUF files carry `qwen35.nextn_predict_layers` > 0. + // This is the canonical indicator used by qwen35_mtp_loader.cpp. + bool found = false; + int64_t kid = gguf_find_key(gguf, "qwen35.nextn_predict_layers"); + if (kid >= 0) { + uint32_t n = gguf_get_val_u32(gguf, kid); + found = (n > 0); + } + + gguf_free(gguf); + return found; +} + std::unique_ptr create_backend(const BackendArgs & args) { if (!args.model_path) { std::fprintf(stderr, "[backend_factory] model_path is null\n"); @@ -32,6 +54,18 @@ std::unique_ptr create_backend(const BackendArgs & args) { std::fprintf(stderr, "[backend_factory] detected arch=%s\n", arch.c_str()); + // Resolve MtpSource::Auto before constructing the backend. 
+ MtpSource resolved_source = args.mtp_source; + if (resolved_source == MtpSource::Auto) { + if (gguf_contains_mtp_tensors(args.model_path)) { + std::fprintf(stderr, "[backend_factory] mtp=auto: nextn_predict_layers found -> Native\n"); + resolved_source = MtpSource::Native; + } else { + std::fprintf(stderr, "[backend_factory] mtp=auto: no nextn_predict_layers -> None\n"); + resolved_source = MtpSource::None; + } + } + if (arch == "qwen35") { Qwen35Config cfg; cfg.target_path = args.model_path; @@ -50,11 +84,27 @@ std::unique_ptr create_backend(const BackendArgs & args) { cfg.ddtree_temp = args.ddtree_temp; cfg.ddtree_chain_seed = args.ddtree_chain_seed; cfg.use_feature_mirror = args.use_feature_mirror; - cfg.mtp_gguf_path = args.mtp_gguf_path; cfg.mtp_gamma = args.mtp_gamma; - cfg.mtp_draft_source = args.mtp_draft_source; + cfg.mtp_use_topk = args.mtp_use_topk; cfg.mtp_draft_topk = args.mtp_draft_topk; + // Map resolved MtpSource to the paths Qwen35Backend expects. + // Qwen35Backend uses cfg_.mtp_gguf_path != nullptr as the MTP-active sentinel. + switch (resolved_source) { + case MtpSource::Native: + // MTP tensors live inside the target GGUF itself. + cfg.mtp_gguf_path = args.model_path; + break; + case MtpSource::ExternalDrafter: + cfg.mtp_gguf_path = args.mtp_gguf_path; + break; + case MtpSource::None: + case MtpSource::Auto: // Auto is fully resolved above; this arm is unreachable. 
+ default: + cfg.mtp_gguf_path = nullptr; + break; + } + auto backend = std::make_unique(cfg); if (!backend->init()) { std::fprintf(stderr, "[backend_factory] Qwen35Backend init failed\n"); diff --git a/dflash/src/common/backend_factory.h b/dflash/src/common/backend_factory.h index de86a502..e4ee0eb7 100644 --- a/dflash/src/common/backend_factory.h +++ b/dflash/src/common/backend_factory.h @@ -18,6 +18,15 @@ namespace dflash27b { +// ─── MTP source selection ──────────────────────────────────────────────── +// Replaces the old free-form mtp_draft_source string (@howard0su #237, line 59). +enum class MtpSource { + None, // no MTP speculator + Native, // MTP heads co-located in the target GGUF (e.g. unsloth single-file) + ExternalDrafter, // separate MTP-head GGUF supplied via mtp_gguf_path + Auto, // probe target GGUF for nextn_predict_layers; Native if found, else None +}; + // ─── Backend creation arguments ───────────────────────────────────────── // A superset of all per-arch config fields. The factory reads only those // relevant to the detected arch; unused fields are silently ignored. @@ -53,20 +62,35 @@ struct BackendArgs { bool use_feature_mirror = false; // MTP (Multi-Token Prediction) speculator — mutually exclusive with --draft. - // When mtp_gguf_path is set, the backend ignores draft_path. - const char * mtp_gguf_path = nullptr; + // mtp_source drives which loading path is taken: + // None → MTP disabled; mtp_gguf_path ignored. + // Native → MTP heads embedded in model_path GGUF (single-file, e.g. unsloth). + // mtp_gguf_path is left nullptr; the factory sets it to model_path. + // ExternalDrafter→ Separate MTP-head GGUF at mtp_gguf_path (required). + // Auto → factory calls gguf_contains_mtp_tensors(model_path): if true, + // resolves to Native; otherwise resolves to None. 
+ MtpSource mtp_source = MtpSource::None; + const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter int mtp_gamma = 0; // 0 = MTP loaded but not active; >0 = chain depth - const char * mtp_draft_source = nullptr; // "chain" (default) | "mtp_topk" + bool mtp_use_topk = false; // false = chain (default), true = mtp_topk strategy int mtp_draft_topk = 1; }; // ─── Factory function ─────────────────────────────────────────────────── // Inspects model_path GGUF metadata, constructs the correct backend, and // calls init(). Returns nullptr on failure (diagnostic printed to stderr). +// When args.mtp_source == Auto, resolves to Native or None before +// constructing; the resolved value is not written back into args. std::unique_ptr create_backend(const BackendArgs & args); // Returns the detected architecture string without creating a backend. // Useful for early dispatch (e.g. printing which backend will be used). std::string detect_arch(const char * model_path); +// Returns true if the GGUF at `path` contains MTP-head tensors. +// Heuristic: presence of `qwen35.nextn_predict_layers` metadata key with +// a value > 0. Pure metadata scan — no tensor allocation, no GPU touch. +// Used by create_backend() when mtp_source == Auto. +bool gguf_contains_mtp_tensors(const std::string & path); + } // namespace dflash27b diff --git a/dflash/src/qwen35/qwen35_backend.cpp b/dflash/src/qwen35/qwen35_backend.cpp index f644bfff..9042f27f 100644 --- a/dflash/src/qwen35/qwen35_backend.cpp +++ b/dflash/src/qwen35/qwen35_backend.cpp @@ -1004,13 +1004,13 @@ bool Qwen35Backend::init_mtp_() { return false; } - if (cfg_.mtp_draft_source && std::strcmp(cfg_.mtp_draft_source, "mtp_topk") == 0) { + if (cfg_.mtp_use_topk) { mtp_module_->set_draft_topk(std::max(1, cfg_.mtp_draft_topk)); } std::printf("[mtp] loaded gamma=%d source=%s\n", cfg_.mtp_gamma, - cfg_.mtp_draft_source ? cfg_.mtp_draft_source : "chain"); + cfg_.mtp_use_topk ? 
"mtp_topk" : "chain"); std::fflush(stdout); return true; } diff --git a/dflash/src/qwen35/qwen35_backend.h b/dflash/src/qwen35/qwen35_backend.h index 9652e75a..ef1f4763 100644 --- a/dflash/src/qwen35/qwen35_backend.h +++ b/dflash/src/qwen35/qwen35_backend.h @@ -12,6 +12,7 @@ #pragma once #include "common/model_backend.h" +#include "common/backend_factory.h" #include "common/dflash_target.h" #include "common/device_placement.h" #include "step_graph.h" @@ -58,10 +59,13 @@ struct Qwen35Config { bool ddtree_chain_seed = true; bool use_feature_mirror = false; - // MTP (Multi-Token Prediction) speculator — mutually exclusive with draft - const char * mtp_gguf_path = nullptr; // path to fused MTP GGUF (or nullptr = DFlash) + // MTP (Multi-Token Prediction) speculator — mutually exclusive with draft. + // mtp_gguf_path != nullptr is the MTP-active sentinel (set by backend_factory + // based on MtpSource). For MtpSource::Native it is set to target_path; + // for MtpSource::ExternalDrafter it is the external GGUF path. 
+ const char * mtp_gguf_path = nullptr; // path to GGUF containing MTP tensors int mtp_gamma = 0; // max speculation depth - const char * mtp_draft_source = nullptr; // "chain" | "mtp_topk" | nullptr -> "chain" + bool mtp_use_topk = false; // false = chain (default), true = mtp_topk int mtp_draft_topk = 1; // top-k for mtp_topk mode }; diff --git a/dflash/src/qwen35/qwen35_daemon.cpp b/dflash/src/qwen35/qwen35_daemon.cpp index 06cba99d..75497612 100644 --- a/dflash/src/qwen35/qwen35_daemon.cpp +++ b/dflash/src/qwen35/qwen35_daemon.cpp @@ -30,11 +30,29 @@ int run_qwen35_daemon(const Qwen35DaemonArgs & args) { cfg.ddtree_temp = args.ddtree_temp; cfg.ddtree_chain_seed = args.ddtree_chain_seed; cfg.use_feature_mirror = args.use_feature_mirror; - cfg.mtp_gguf_path = args.mtp_gguf_path; cfg.mtp_gamma = args.mtp_gamma; - cfg.mtp_draft_source = args.mtp_draft_source; + cfg.mtp_use_topk = args.mtp_use_topk; cfg.mtp_draft_topk = args.mtp_draft_topk; + // Resolve MtpSource to the mtp_gguf_path sentinel that Qwen35Backend expects. + switch (args.mtp_source) { + case MtpSource::Native: + cfg.mtp_gguf_path = args.target_path; + break; + case MtpSource::ExternalDrafter: + cfg.mtp_gguf_path = args.mtp_gguf_path; + break; + case MtpSource::Auto: + cfg.mtp_gguf_path = dflash27b::gguf_contains_mtp_tensors(args.target_path) + ? 
args.target_path + : nullptr; + break; + case MtpSource::None: + default: + cfg.mtp_gguf_path = nullptr; + break; + } + Qwen35Backend backend(cfg); if (!backend.init()) return 1; diff --git a/dflash/src/qwen35/qwen35_daemon.h b/dflash/src/qwen35/qwen35_daemon.h index 92eda4ed..abc1b979 100644 --- a/dflash/src/qwen35/qwen35_daemon.h +++ b/dflash/src/qwen35/qwen35_daemon.h @@ -5,6 +5,7 @@ #pragma once +#include "common/backend_factory.h" #include "device_placement.h" #include @@ -35,10 +36,12 @@ struct Qwen35DaemonArgs { bool ddtree_chain_seed = true; bool use_feature_mirror = false; - // MTP (Multi-Token Prediction) speculator — mutually exclusive with draft - const char * mtp_gguf_path = nullptr; // path to fused MTP GGUF (or nullptr = DFlash) + // MTP (Multi-Token Prediction) speculator — mutually exclusive with draft. + // The daemon uses BackendArgs directly; these fields mirror BackendArgs. + MtpSource mtp_source = MtpSource::None; + const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter int mtp_gamma = 0; // max speculation depth - const char * mtp_draft_source = nullptr; // "chain" | "mtp_topk" | nullptr -> "chain" + bool mtp_use_topk = false; // false = chain, true = mtp_topk int mtp_draft_topk = 1; // top-k for mtp_topk mode }; diff --git a/dflash/src/server/server_main.cpp b/dflash/src/server/server_main.cpp index c027761b..049f41d9 100644 --- a/dflash/src/server/server_main.cpp +++ b/dflash/src/server/server_main.cpp @@ -55,10 +55,15 @@ static void print_usage(const char * prog) { " --prefill-skip-park Skip park/unpark (for >=32GB GPUs)\n" "\n" "MTP speculative decoding (mutually exclusive with --draft):\n" - " --mtp-gguf MTP drafter GGUF path\n" + " --mtp-source \n" + " MTP source (default: auto when --mtp-gamma given)\n" + " none = disable MTP\n" + " native = MTP heads in target GGUF (unsloth single-file)\n" + " external = separate GGUF via --mtp-gguf\n" + " auto = probe target GGUF; native if found, else none\n" + " --mtp-gguf MTP 
GGUF path (required only for --mtp-source external)\n" " --mtp-gamma Speculation chain depth (default: 0 = disabled)\n" - " --mtp-draft-source Draft strategy (default: chain)\n" - " --mtp-draft-topk Top-k for mtp_topk mode (default: 1)\n" + " --mtp-draft-topk Top-k draft strategy (default: chain; >1 enables mtp_topk)\n" "\n", prog); } @@ -124,14 +129,34 @@ int main(int argc, char ** argv) { sconfig.pflash_drafter_path = argv[++i]; } else if (std::strcmp(argv[i], "--prefill-skip-park") == 0) { sconfig.pflash_skip_park = true; + } else if (std::strcmp(argv[i], "--mtp-source") == 0 && i + 1 < argc) { + const char * src = argv[++i]; + if (std::strcmp(src, "none") == 0) + bargs.mtp_source = MtpSource::None; + else if (std::strcmp(src, "native") == 0) + bargs.mtp_source = MtpSource::Native; + else if (std::strcmp(src, "external") == 0) + bargs.mtp_source = MtpSource::ExternalDrafter; + else if (std::strcmp(src, "auto") == 0) + bargs.mtp_source = MtpSource::Auto; + else { + std::fprintf(stderr, "[server] unknown --mtp-source: '%s' (expected: none|native|external|auto)\n", src); + print_usage(argv[0]); + return 1; + } } else if (std::strcmp(argv[i], "--mtp-gguf") == 0 && i + 1 < argc) { bargs.mtp_gguf_path = argv[++i]; } else if (std::strcmp(argv[i], "--mtp-gamma") == 0 && i + 1 < argc) { bargs.mtp_gamma = std::atoi(argv[++i]); } else if (std::strcmp(argv[i], "--mtp-draft-source") == 0 && i + 1 < argc) { - bargs.mtp_draft_source = argv[++i]; + ++i; // consume the argument + std::fprintf(stderr, + "[server] WARNING: --mtp-draft-source is deprecated. 
" + "Use --mtp-source [none|native|external|auto] and " + "--mtp-draft-topk instead.\n"); } else if (std::strcmp(argv[i], "--mtp-draft-topk") == 0 && i + 1 < argc) { bargs.mtp_draft_topk = std::atoi(argv[++i]); + if (bargs.mtp_draft_topk > 1) bargs.mtp_use_topk = true; } else if (std::strcmp(argv[i], "--cache-type-k") == 0 && i + 1 < argc) { cache_type_k = argv[++i]; } else if (std::strcmp(argv[i], "--cache-type-v") == 0 && i + 1 < argc) { @@ -150,10 +175,29 @@ int main(int argc, char ** argv) { sconfig.max_ctx = bargs.device.max_ctx; } - // --draft and --mtp-gguf are mutually exclusive; MTP wins if both are set. - if (bargs.draft_path && bargs.mtp_gguf_path) { + // Infer MtpSource from legacy flags when --mtp-source is absent. + // --mtp-gguf without --mtp-source → ExternalDrafter (backward compat) + // --mtp-gamma without --mtp-source → Auto (probe the target GGUF) + if (bargs.mtp_source == MtpSource::None) { + if (bargs.mtp_gguf_path) { + bargs.mtp_source = MtpSource::ExternalDrafter; + } else if (bargs.mtp_gamma > 0) { + bargs.mtp_source = MtpSource::Auto; + } + } + + // Validate: ExternalDrafter requires --mtp-gguf. + if (bargs.mtp_source == MtpSource::ExternalDrafter && !bargs.mtp_gguf_path) { + std::fprintf(stderr, + "[server] ERROR: --mtp-source external requires --mtp-gguf \n"); + return 1; + } + + // --draft and MTP are mutually exclusive; MTP wins if both are set. 
+ const bool mtp_active = (bargs.mtp_source != MtpSource::None); + if (bargs.draft_path && mtp_active) { std::fprintf(stderr, - "[server] WARNING: --draft and --mtp-gguf both set; ignoring --draft.\n" + "[server] WARNING: --draft and MTP both set; ignoring --draft.\n" "[server] MTP speculation takes precedence over DFlash draft.\n"); bargs.draft_path = nullptr; } @@ -176,7 +220,7 @@ int main(int argc, char ** argv) { // Default MTP head_kv capacity to backbone max_ctx so prompts up to max_ctx // never overflow the head_kv buffer (the old hardcoded 8192 caused a silent // server crash when agentic prompts exceeded that length). - if (bargs.mtp_gguf_path && sconfig.max_ctx > 0) { + if (mtp_active && sconfig.max_ctx > 0) { char ctx_str[32]; std::snprintf(ctx_str, sizeof(ctx_str), "%d", sconfig.max_ctx); setenv("DFLASH27B_MTP_CTX", ctx_str, 0); // don't overwrite user env @@ -241,11 +285,17 @@ int main(int argc, char ** argv) { std::fprintf(stderr, "[server] │ fa_window = %d\n", bargs.fa_window); std::fprintf(stderr, "[server] │ ddtree = %s\n", bargs.ddtree_mode ? "ON" : "off"); std::fprintf(stderr, "[server] │ ddtree_budget = %d\n", bargs.ddtree_budget); - if (bargs.mtp_gguf_path) { - std::fprintf(stderr, "[server] │ mtp_gguf = %s\n", bargs.mtp_gguf_path); + if (mtp_active) { + const char * src_str = + bargs.mtp_source == MtpSource::Native ? "native" : + bargs.mtp_source == MtpSource::ExternalDrafter ? "external" : + bargs.mtp_source == MtpSource::Auto ? "auto" : "none"; + std::fprintf(stderr, "[server] │ mtp_source = %s\n", src_str); + if (bargs.mtp_gguf_path) + std::fprintf(stderr, "[server] │ mtp_gguf = %s\n", bargs.mtp_gguf_path); std::fprintf(stderr, "[server] │ mtp_gamma = %d\n", bargs.mtp_gamma); - std::fprintf(stderr, "[server] │ mtp_draft_src = %s\n", - bargs.mtp_draft_source ? bargs.mtp_draft_source : "chain (default)"); + std::fprintf(stderr, "[server] │ mtp_draft_strat = %s\n", + bargs.mtp_use_topk ? 
"mtp_topk" : "chain (default)"); } std::fprintf(stderr, "[server] │ cors = %s\n", sconfig.enable_cors ? "ON" : "off"); std::fprintf(stderr, "[server] │ cache_type_k = %s\n", diff --git a/dflash/test/test_dflash.cpp b/dflash/test/test_dflash.cpp index ce0bed13..3af38463 100644 --- a/dflash/test/test_dflash.cpp +++ b/dflash/test/test_dflash.cpp @@ -650,16 +650,7 @@ static int run_target_layer_split_harness( return 0; } -// draft_source values: -// "chain" — existing MtpChainRunner path (MTP argmax chain, verify_batch -// sequential verify). K=1, no DDTree. -// "mtp_topk" — experiment C: configure set_draft_topk(K), call step_batch, -// build DDTree from per-head top-K, then verify the DDTree's -// top-1 chain through verify_batch (target chain verify is the -// only verify surface available on DFlashTarget today; a true -// tree-mask verify would require lifting test_dflash.cpp's -// spec-decode loop out of the qwen35 graph builder — see the -// BLOCKER note in qwen35-mtp experiment-C wiring docs). +// use_topk: false = MtpChainRunner chain path (default), true = mtp_topk DDTree. 
static int run_qwen35_mtp_harness(const char * target_path, const char * mtp_gguf_path, const char * prompt_path, @@ -669,7 +660,7 @@ static int run_qwen35_mtp_harness(const char * target_path, int prompt_id, int target_gpu, int max_ctx, - const char * draft_source, + bool use_topk, int draft_topk, int ddtree_budget, bool ddtree_chain_seed, @@ -825,7 +816,7 @@ static int run_qwen35_mtp_harness(const char * target_path, base_pos++; if (target->is_eos(next)) break; } - } else if (!draft_source || std::strcmp(draft_source, "chain") == 0) { + } else if (!use_topk) { mtp_module->reset_chain(); GenerateRequest req; req.n_gen = n_gen - 1; @@ -843,7 +834,7 @@ static int run_qwen35_mtp_harness(const char * target_path, generated.insert(generated.end(), res.tokens.begin(), res.tokens.end()); accepted = runner.stats().total_accepted; proposed = runner.stats().total_proposed; - } else if (std::strcmp(draft_source, "mtp_topk") == 0) { + } else if (use_topk) { // ── experiment C: MTP top-K → DDTree → chain-verify ───────── // 1. step_batch with K>1 populates StepOutput.topk_logprobs/ids // on every emitted head (length K, sorted DESCENDING). @@ -1093,9 +1084,6 @@ static int run_qwen35_mtp_harness(const char * target_path, "mean_tree_size=%.2f mean_gamma=%.2f\n", K, ddtree_budget, (int)ddtree_chain_seed, n_steps, mean_tree_size, mean_gamma); - } else { - std::fprintf(stderr, "unknown --draft-source: %s (expected chain|mtp_topk)\n", draft_source); - return 2; } } auto t_decode1 = std::chrono::steady_clock::now(); @@ -1116,7 +1104,7 @@ static int run_qwen35_mtp_harness(const char * target_path, { const double accept_rate = proposed > 0 ? (double)accepted / (double)proposed : 0.0; - const char * src = (draft_source && *draft_source) ? draft_source : "chain"; + const char * src = use_topk ? 
"mtp_topk" : "chain"; std::printf("RESULT_JSON {" "\"draft_source\":\"%s\"," "\"gamma\":%d," @@ -1256,11 +1244,8 @@ int main(int argc, char ** argv) { int mtp_gamma = 2; int mtp_n_gen = 0; int mtp_prompt_id = 0; - // Experiment-C draft source for the MTP harness. "chain" preserves - // the existing MtpChainRunner path; "mtp_topk" wires set_draft_topk + - // build_ddtree (see run_qwen35_mtp_harness for the BLOCKER on true - // tree-mask verify). - const char * mtp_draft_source = "chain"; + // MTP draft strategy: false = chain (default), true = mtp_topk. + bool mtp_use_topk = false; int mtp_draft_topk = 4; int target_gpu = 0; int draft_gpu = 0; @@ -1333,10 +1318,10 @@ int main(int argc, char ** argv) { target_split_load_draft = true; } else if (std::strncmp(argv[i], "--draft-source=", 15) == 0) { - mtp_draft_source = argv[i] + 15; + mtp_use_topk = (std::strcmp(argv[i] + 15, "mtp_topk") == 0); } else if (std::strcmp(argv[i], "--draft-source") == 0) { - if (i + 1 < argc) mtp_draft_source = argv[++i]; + if (i + 1 < argc) mtp_use_topk = (std::strcmp(argv[++i], "mtp_topk") == 0); } else if (std::strncmp(argv[i], "--draft-topk=", 13) == 0) { mtp_draft_topk = std::max(1, std::atoi(argv[i] + 13)); @@ -1659,9 +1644,10 @@ int main(int argc, char ** argv) { qargs.ddtree_temp = ddtree_temp; qargs.ddtree_chain_seed = ddtree_chain_seed; qargs.use_feature_mirror = false; + qargs.mtp_source = dflash27b::MtpSource::ExternalDrafter; qargs.mtp_gguf_path = mtp_gguf_path; qargs.mtp_gamma = mtp_gamma; - qargs.mtp_draft_source = mtp_draft_source; + qargs.mtp_use_topk = mtp_use_topk; qargs.mtp_draft_topk = mtp_draft_topk; return dflash27b::run_qwen35_daemon(qargs); } @@ -1673,7 +1659,7 @@ int main(int argc, char ** argv) { prompt_path, n_gen, out_path, mtp_gamma, mtp_prompt_id, target_gpu, max_ctx_eff, - mtp_draft_source, + mtp_use_topk, mtp_draft_topk, ddtree_budget, ddtree_chain_seed,