Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions dflash/src/common/backend_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include "qwen3_backend.h"
#include "gemma4_backend.h"

#include "gguf.h"

#include <cstdio>

namespace dflash27b {
Expand All @@ -17,6 +19,26 @@ std::string detect_arch(const char * model_path) {
return info.arch;
}

bool gguf_contains_mtp_tensors(const std::string & path) {
gguf_init_params gp{};
gp.no_alloc = true;
gp.ctx = nullptr;
gguf_context * gguf = gguf_init_from_file(path.c_str(), gp);
if (!gguf) return false;

// MTP-capable GGUF files carry `qwen35.nextn_predict_layers` > 0.
// This is the canonical indicator used by qwen35_mtp_loader.cpp.
bool found = false;
int64_t kid = gguf_find_key(gguf, "qwen35.nextn_predict_layers");
if (kid >= 0) {
uint32_t n = gguf_get_val_u32(gguf, kid);
found = (n > 0);
}

gguf_free(gguf);
return found;
}

std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
if (!args.model_path) {
std::fprintf(stderr, "[backend_factory] model_path is null\n");
Expand All @@ -32,6 +54,18 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {

std::fprintf(stderr, "[backend_factory] detected arch=%s\n", arch.c_str());

// Resolve MtpSource::Auto before constructing the backend.
MtpSource resolved_source = args.mtp_source;
if (resolved_source == MtpSource::Auto) {
if (gguf_contains_mtp_tensors(args.model_path)) {
std::fprintf(stderr, "[backend_factory] mtp=auto: nextn_predict_layers found -> Native\n");
resolved_source = MtpSource::Native;
} else {
std::fprintf(stderr, "[backend_factory] mtp=auto: no nextn_predict_layers -> None\n");
resolved_source = MtpSource::None;
}
}

if (arch == "qwen35") {
Qwen35Config cfg;
cfg.target_path = args.model_path;
Expand All @@ -50,11 +84,27 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
cfg.ddtree_temp = args.ddtree_temp;
cfg.ddtree_chain_seed = args.ddtree_chain_seed;
cfg.use_feature_mirror = args.use_feature_mirror;
cfg.mtp_gguf_path = args.mtp_gguf_path;
cfg.mtp_gamma = args.mtp_gamma;
cfg.mtp_draft_source = args.mtp_draft_source;
cfg.mtp_use_topk = args.mtp_use_topk;
cfg.mtp_draft_topk = args.mtp_draft_topk;

// Map resolved MtpSource to the paths Qwen35Backend expects.
// Qwen35Backend uses cfg_.mtp_gguf_path != nullptr as the MTP-active sentinel.
switch (resolved_source) {
case MtpSource::Native:
// MTP tensors live inside the target GGUF itself.
cfg.mtp_gguf_path = args.model_path;
break;
case MtpSource::ExternalDrafter:
cfg.mtp_gguf_path = args.mtp_gguf_path;
break;
case MtpSource::None:
case MtpSource::Auto: // Auto is fully resolved above; this arm is unreachable.
default:
cfg.mtp_gguf_path = nullptr;
break;
}

auto backend = std::make_unique<Qwen35Backend>(cfg);
if (!backend->init()) {
std::fprintf(stderr, "[backend_factory] Qwen35Backend init failed\n");
Expand Down
30 changes: 27 additions & 3 deletions dflash/src/common/backend_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,15 @@

namespace dflash27b {

// ─── MTP source selection ────────────────────────────────────────────────
// Selects where the MTP speculator heads are loaded from. Introduced in place
// of the old free-form mtp_draft_source string (@howard0su #237, line 59).
// Enumerator values are spelled out and match the declaration order.
enum class MtpSource {
    // MTP speculator disabled.
    None = 0,
    // MTP heads co-located in the target GGUF (e.g. unsloth single-file).
    Native = 1,
    // MTP heads live in a separate GGUF supplied via mtp_gguf_path.
    ExternalDrafter = 2,
    // Probe the target GGUF for nextn_predict_layers: Native if found, else None.
    Auto = 3,
};

// ─── Backend creation arguments ─────────────────────────────────────────
// A superset of all per-arch config fields. The factory reads only those
// relevant to the detected arch; unused fields are silently ignored.
Expand Down Expand Up @@ -53,20 +62,35 @@ struct BackendArgs {
bool use_feature_mirror = false;

// MTP (Multi-Token Prediction) speculator — mutually exclusive with --draft.
// When mtp_gguf_path is set, the backend ignores draft_path.
const char * mtp_gguf_path = nullptr;
// mtp_source drives which loading path is taken:
// None → MTP disabled; mtp_gguf_path ignored.
// Native → MTP heads embedded in model_path GGUF (single-file, e.g. unsloth).
//                 mtp_gguf_path may stay nullptr here; the factory points the
//                 backend config's own mtp_gguf_path at model_path instead.
// ExternalDrafter→ Separate MTP-head GGUF at mtp_gguf_path (required).
// Auto → factory calls gguf_contains_mtp_tensors(model_path): if true,
// resolves to Native; otherwise resolves to None.
MtpSource mtp_source = MtpSource::None;
const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter
int mtp_gamma = 0; // 0 = MTP loaded but not active; >0 = chain depth
const char * mtp_draft_source = nullptr; // "chain" (default) | "mtp_topk"
bool mtp_use_topk = false; // false = chain (default), true = mtp_topk strategy
int mtp_draft_topk = 1;
};

// ─── Factory function ───────────────────────────────────────────────────
// Inspects model_path GGUF metadata, constructs the correct backend, and
// calls init(). Returns nullptr on failure (diagnostic printed to stderr).
// When args.mtp_source == Auto, resolves to Native or None before
// constructing; the resolved value is not written back into args.
std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args);

// Returns the detected architecture string without creating a backend.
// Useful for early dispatch (e.g. printing which backend will be used).
std::string detect_arch(const char * model_path);

// Returns true if the GGUF at `path` contains MTP-head tensors.
// Heuristic: presence of `qwen35.nextn_predict_layers` metadata key with
// a value > 0. Pure metadata scan — no tensor allocation, no GPU touch.
// Used by create_backend() when mtp_source == Auto.
bool gguf_contains_mtp_tensors(const std::string & path);

} // namespace dflash27b
4 changes: 2 additions & 2 deletions dflash/src/qwen35/qwen35_backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1004,13 +1004,13 @@ bool Qwen35Backend::init_mtp_() {
return false;
}

if (cfg_.mtp_draft_source && std::strcmp(cfg_.mtp_draft_source, "mtp_topk") == 0) {
if (cfg_.mtp_use_topk) {
mtp_module_->set_draft_topk(std::max(1, cfg_.mtp_draft_topk));
}

std::printf("[mtp] loaded gamma=%d source=%s\n",
cfg_.mtp_gamma,
cfg_.mtp_draft_source ? cfg_.mtp_draft_source : "chain");
cfg_.mtp_use_topk ? "mtp_topk" : "chain");
std::fflush(stdout);
return true;
}
Expand Down
10 changes: 7 additions & 3 deletions dflash/src/qwen35/qwen35_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#pragma once

#include "common/model_backend.h"
#include "common/backend_factory.h"
#include "common/dflash_target.h"
#include "common/device_placement.h"
#include "step_graph.h"
Expand Down Expand Up @@ -58,10 +59,13 @@ struct Qwen35Config {
bool ddtree_chain_seed = true;
bool use_feature_mirror = false;

// MTP (Multi-Token Prediction) speculator — mutually exclusive with draft
const char * mtp_gguf_path = nullptr; // path to fused MTP GGUF (or nullptr = DFlash)
// MTP (Multi-Token Prediction) speculator — mutually exclusive with draft.
// mtp_gguf_path != nullptr is the MTP-active sentinel (set by backend_factory
// based on MtpSource). For MtpSource::Native it is set to target_path;
// for MtpSource::ExternalDrafter it is the external GGUF path.
const char * mtp_gguf_path = nullptr; // path to GGUF containing MTP tensors
int mtp_gamma = 0; // max speculation depth
const char * mtp_draft_source = nullptr; // "chain" | "mtp_topk" | nullptr -> "chain"
bool mtp_use_topk = false; // false = chain (default), true = mtp_topk
int mtp_draft_topk = 1; // top-k for mtp_topk mode
};

Expand Down
22 changes: 20 additions & 2 deletions dflash/src/qwen35/qwen35_daemon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,29 @@ int run_qwen35_daemon(const Qwen35DaemonArgs & args) {
cfg.ddtree_temp = args.ddtree_temp;
cfg.ddtree_chain_seed = args.ddtree_chain_seed;
cfg.use_feature_mirror = args.use_feature_mirror;
cfg.mtp_gguf_path = args.mtp_gguf_path;
cfg.mtp_gamma = args.mtp_gamma;
cfg.mtp_draft_source = args.mtp_draft_source;
cfg.mtp_use_topk = args.mtp_use_topk;
cfg.mtp_draft_topk = args.mtp_draft_topk;

// Resolve MtpSource to the mtp_gguf_path sentinel that Qwen35Backend expects.
switch (args.mtp_source) {
case MtpSource::Native:
cfg.mtp_gguf_path = args.target_path;
break;
case MtpSource::ExternalDrafter:
cfg.mtp_gguf_path = args.mtp_gguf_path;
break;
case MtpSource::Auto:
cfg.mtp_gguf_path = dflash27b::gguf_contains_mtp_tensors(args.target_path)
? args.target_path
: nullptr;
break;
case MtpSource::None:
default:
cfg.mtp_gguf_path = nullptr;
break;
}

Qwen35Backend backend(cfg);
if (!backend.init()) return 1;

Expand Down
9 changes: 6 additions & 3 deletions dflash/src/qwen35/qwen35_daemon.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#pragma once

#include "common/backend_factory.h"
#include "device_placement.h"
#include <string>

Expand Down Expand Up @@ -35,10 +36,12 @@ struct Qwen35DaemonArgs {
bool ddtree_chain_seed = true;
bool use_feature_mirror = false;

// MTP (Multi-Token Prediction) speculator — mutually exclusive with draft
const char * mtp_gguf_path = nullptr; // path to fused MTP GGUF (or nullptr = DFlash)
// MTP (Multi-Token Prediction) speculator — mutually exclusive with draft.
// These fields mirror the MTP fields of BackendArgs; run_qwen35_daemon() maps
// them onto Qwen35Config.
MtpSource mtp_source = MtpSource::None;
const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter
int mtp_gamma = 0; // max speculation depth
const char * mtp_draft_source = nullptr; // "chain" | "mtp_topk" | nullptr -> "chain"
bool mtp_use_topk = false; // false = chain, true = mtp_topk
int mtp_draft_topk = 1; // top-k for mtp_topk mode
};

Expand Down
74 changes: 62 additions & 12 deletions dflash/src/server/server_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ static void print_usage(const char * prog) {
" --prefill-skip-park Skip park/unpark (for >=32GB GPUs)\n"
"\n"
"MTP speculative decoding (mutually exclusive with --draft):\n"
" --mtp-gguf <path> MTP drafter GGUF path\n"
" --mtp-source <none|native|external|auto>\n"
" MTP source (default: auto when --mtp-gamma given)\n"
" none = disable MTP\n"
" native = MTP heads in target GGUF (unsloth single-file)\n"
" external = separate GGUF via --mtp-gguf\n"
" auto = probe target GGUF; native if found, else none\n"
" --mtp-gguf <path> MTP GGUF path (required only for --mtp-source external)\n"
" --mtp-gamma <int> Speculation chain depth (default: 0 = disabled)\n"
" --mtp-draft-source <chain|mtp_topk> Draft strategy (default: chain)\n"
" --mtp-draft-topk <int> Top-k for mtp_topk mode (default: 1)\n"
" --mtp-draft-topk <int> Top-k draft strategy (default: chain; >1 enables mtp_topk)\n"
"\n", prog);
}

Expand Down Expand Up @@ -124,14 +129,34 @@ int main(int argc, char ** argv) {
sconfig.pflash_drafter_path = argv[++i];
} else if (std::strcmp(argv[i], "--prefill-skip-park") == 0) {
sconfig.pflash_skip_park = true;
} else if (std::strcmp(argv[i], "--mtp-source") == 0 && i + 1 < argc) {
const char * src = argv[++i];
if (std::strcmp(src, "none") == 0)
bargs.mtp_source = MtpSource::None;
else if (std::strcmp(src, "native") == 0)
bargs.mtp_source = MtpSource::Native;
else if (std::strcmp(src, "external") == 0)
bargs.mtp_source = MtpSource::ExternalDrafter;
else if (std::strcmp(src, "auto") == 0)
bargs.mtp_source = MtpSource::Auto;
else {
std::fprintf(stderr, "[server] unknown --mtp-source: '%s' (expected: none|native|external|auto)\n", src);
print_usage(argv[0]);
return 1;
}
} else if (std::strcmp(argv[i], "--mtp-gguf") == 0 && i + 1 < argc) {
bargs.mtp_gguf_path = argv[++i];
} else if (std::strcmp(argv[i], "--mtp-gamma") == 0 && i + 1 < argc) {
bargs.mtp_gamma = std::atoi(argv[++i]);
} else if (std::strcmp(argv[i], "--mtp-draft-source") == 0 && i + 1 < argc) {
bargs.mtp_draft_source = argv[++i];
++i; // consume the argument
std::fprintf(stderr,
"[server] WARNING: --mtp-draft-source is deprecated. "
"Use --mtp-source [none|native|external|auto] and "
"--mtp-draft-topk <N> instead.\n");
} else if (std::strcmp(argv[i], "--mtp-draft-topk") == 0 && i + 1 < argc) {
bargs.mtp_draft_topk = std::atoi(argv[++i]);
if (bargs.mtp_draft_topk > 1) bargs.mtp_use_topk = true;
} else if (std::strcmp(argv[i], "--cache-type-k") == 0 && i + 1 < argc) {
cache_type_k = argv[++i];
} else if (std::strcmp(argv[i], "--cache-type-v") == 0 && i + 1 < argc) {
Expand All @@ -150,10 +175,29 @@ int main(int argc, char ** argv) {
sconfig.max_ctx = bargs.device.max_ctx;
}

// --draft and --mtp-gguf are mutually exclusive; MTP wins if both are set.
if (bargs.draft_path && bargs.mtp_gguf_path) {
// Infer MtpSource from legacy flags when --mtp-source is absent.
// --mtp-gguf without --mtp-source → ExternalDrafter (backward compat)
// --mtp-gamma without --mtp-source → Auto (probe the target GGUF)
if (bargs.mtp_source == MtpSource::None) {
if (bargs.mtp_gguf_path) {
bargs.mtp_source = MtpSource::ExternalDrafter;
} else if (bargs.mtp_gamma > 0) {
bargs.mtp_source = MtpSource::Auto;
}
}

// Validate: ExternalDrafter requires --mtp-gguf.
if (bargs.mtp_source == MtpSource::ExternalDrafter && !bargs.mtp_gguf_path) {
std::fprintf(stderr,
"[server] ERROR: --mtp-source external requires --mtp-gguf <path>\n");
return 1;
}

// --draft and MTP are mutually exclusive; MTP wins if both are set.
const bool mtp_active = (bargs.mtp_source != MtpSource::None);
if (bargs.draft_path && mtp_active) {
std::fprintf(stderr,
"[server] WARNING: --draft and --mtp-gguf both set; ignoring --draft.\n"
"[server] WARNING: --draft and MTP both set; ignoring --draft.\n"
"[server] MTP speculation takes precedence over DFlash draft.\n");
bargs.draft_path = nullptr;
}
Expand All @@ -176,7 +220,7 @@ int main(int argc, char ** argv) {
// Default MTP head_kv capacity to backbone max_ctx so prompts up to max_ctx
// never overflow the head_kv buffer (the old hardcoded 8192 caused a silent
// server crash when agentic prompts exceeded that length).
if (bargs.mtp_gguf_path && sconfig.max_ctx > 0) {
if (mtp_active && sconfig.max_ctx > 0) {
char ctx_str[32];
std::snprintf(ctx_str, sizeof(ctx_str), "%d", sconfig.max_ctx);
setenv("DFLASH27B_MTP_CTX", ctx_str, 0); // don't overwrite user env
Expand Down Expand Up @@ -241,11 +285,17 @@ int main(int argc, char ** argv) {
std::fprintf(stderr, "[server] │ fa_window = %d\n", bargs.fa_window);
std::fprintf(stderr, "[server] │ ddtree = %s\n", bargs.ddtree_mode ? "ON" : "off");
std::fprintf(stderr, "[server] │ ddtree_budget = %d\n", bargs.ddtree_budget);
if (bargs.mtp_gguf_path) {
std::fprintf(stderr, "[server] │ mtp_gguf = %s\n", bargs.mtp_gguf_path);
if (mtp_active) {
const char * src_str =
bargs.mtp_source == MtpSource::Native ? "native" :
bargs.mtp_source == MtpSource::ExternalDrafter ? "external" :
bargs.mtp_source == MtpSource::Auto ? "auto" : "none";
std::fprintf(stderr, "[server] │ mtp_source = %s\n", src_str);
if (bargs.mtp_gguf_path)
std::fprintf(stderr, "[server] │ mtp_gguf = %s\n", bargs.mtp_gguf_path);
std::fprintf(stderr, "[server] │ mtp_gamma = %d\n", bargs.mtp_gamma);
std::fprintf(stderr, "[server] │ mtp_draft_src = %s\n",
bargs.mtp_draft_source ? bargs.mtp_draft_source : "chain (default)");
std::fprintf(stderr, "[server] │ mtp_draft_strat = %s\n",
bargs.mtp_use_topk ? "mtp_topk" : "chain (default)");
}
std::fprintf(stderr, "[server] │ cors = %s\n", sconfig.enable_cors ? "ON" : "off");
std::fprintf(stderr, "[server] │ cache_type_k = %s\n",
Expand Down
Loading