Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions dflash/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,11 @@ add_library(dflash_common STATIC
src/qwen35/qwen35_layer_split_dflash_target.cpp
src/qwen35/layer_split_daemon_loop.cpp
src/qwen35/qwen35_daemon.cpp
src/qwen35/qwen35_mtp.cpp
src/qwen35/qwen35_mtp_graph.cpp
src/qwen35/qwen35_mtp_loader.cpp
src/common/mtp_chain_runner.cpp
src/common/mtp_orchestrator.cpp
src/common/sampler.cpp
src/common/daemon_loop.cpp
src/common/gguf_inspect.cpp
Expand Down Expand Up @@ -521,6 +526,13 @@ if(DFLASH27B_TESTS)
target_include_directories(test_kv_quant PRIVATE ${DFLASH27B_SRC_INCLUDE_DIRS})
target_link_libraries(test_kv_quant PRIVATE dflash_common)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_common_mtp_orchestrator.cpp")
add_executable(test_common_mtp_orchestrator test/test_common_mtp_orchestrator.cpp)
target_include_directories(test_common_mtp_orchestrator PRIVATE
${DFLASH27B_SRC_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}/deps/llama.cpp/ggml/include)
target_link_libraries(test_common_mtp_orchestrator PRIVATE dflash_common ggml-base)
endif()
if(EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/test/test_draft_vs_reference.cpp")
add_executable(test_draft_vs_reference test/test_draft_vs_reference.cpp)
target_link_libraries(test_draft_vs_reference PRIVATE dflash_common)
Expand Down
50 changes: 48 additions & 2 deletions dflash/scripts/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,11 @@ def build_app(target: Path, draft: Path | None, bin_path: Path, budget: int, max
extra_daemon_args: list[str] | None = None,
lazy_draft: bool = False,
verbose_daemon: bool = False,
force_no_thinking: bool = False) -> FastAPI:
force_no_thinking: bool = False,
mtp_gguf: Path | None = None,
mtp_gamma: int = 3,
mtp_draft_source: str = "chain",
mtp_draft_topk: int = 1) -> FastAPI:
import asyncio
if _extra_daemon_has_target_sharding(extra_daemon_args):
if prefix_cache_slots > 0 or prefill_cache_slots > 0:
Expand Down Expand Up @@ -850,6 +854,19 @@ async def _openai_compat_error_handler(_request: Request, exc: OpenAICompatError
cmd = [bin_abs, str(target), "--daemon",
f"--max-ctx={max_ctx}",
f"--stream-fd={stream_fd_val}"]
elif mtp_gguf is not None:
# MTP mode: no --draft (MTP head lives inside target or mtp_gguf),
# no DFlash flags. Daemon dispatches to MTP code path via --mtp-gguf.
cmd = [bin_abs, str(target), "--daemon",
f"--max-ctx={max_ctx}",
f"--stream-fd={stream_fd_val}",
f"--mtp-gguf={mtp_gguf}",
f"--gamma={mtp_gamma}",
"--draft-source", mtp_draft_source]
if mtp_draft_source == "mtp_topk":
cmd.append(f"--draft-topk={mtp_draft_topk}")
if extra_daemon_args:
cmd.extend(extra_daemon_args)
else:
if draft is None:
raise SystemExit("qwen35 arch requires --draft <draft.gguf|model.safetensors>")
Expand Down Expand Up @@ -2858,6 +2875,20 @@ def main():
help="Server-level guard: prevent any request from enabling thinking mode "
"via chat_template_kwargs. Useful on hardware (e.g. gfx1151/Strix Halo) "
"where thinking chains consume n_gen budget without benefit.")
# ── MTP (Multi-Token Prediction) speculator ──────────────────────────────
# When --mtp-gguf is set, the daemon runs MTP-head speculation instead of
# DFlash+DDTree. --draft is ignored (the MTP head is in the same GGUF as
# target, or a separate fused GGUF). Prefix-cache slots are auto-disabled
# in MTP mode because RESTORE does not snapshot MTP head KV yet.
ap.add_argument("--mtp-gguf", type=Path, default=None,
help="Path to MTP-fused GGUF. When set, daemon runs MTP "
"speculation; --draft and DFlash flags are ignored.")
ap.add_argument("--mtp-gamma", type=int, default=3,
help="MTP chain depth (default 3; recommended D=3 per matrix bench)")
ap.add_argument("--mtp-draft-source", choices=["chain", "mtp_topk"], default="chain",
help="MTP draft generation strategy (default chain)")
ap.add_argument("--mtp-draft-topk", type=int, default=1,
help="Top-K for mtp_topk draft source (default 1, ignored for chain)")
add_cli_flags(ap)
args = ap.parse_args()
prefill_cfg = config_from_args(args)
Expand Down Expand Up @@ -2906,6 +2937,17 @@ def main():
# through the laguna daemon now, so --prefill-compression and
# --prefix-cache-slots behave the same as on the qwen35 path.
draft = None
elif args.mtp_gguf is not None:
# MTP mode: --draft is ignored; MTP head lives in the target (or in --mtp-gguf
# if separate). Force prefix/prefill cache off — RESTORE doesn't snapshot
# MTP head KV yet (planned for a follow-up PR).
if not args.mtp_gguf.is_file():
raise SystemExit(f"--mtp-gguf not found at {args.mtp_gguf}")
draft = None
if args.prefix_cache_slots > 0 or args.prefill_cache_slots > 0:
print(" [cfg] MTP mode: disabling prefix/prefill cache (MTP head KV snapshot not implemented)")
args.prefix_cache_slots = 0
args.prefill_cache_slots = 0
else:
draft = resolve_draft(args.draft) if args.draft.is_dir() else args.draft
if not draft.is_file():
Expand Down Expand Up @@ -2938,7 +2980,11 @@ def main():
extra_daemon_args=placement.daemon_args or None,
lazy_draft=args.lazy_draft,
verbose_daemon=args.verbose_daemon,
force_no_thinking=args.no_thinking)
force_no_thinking=args.no_thinking,
mtp_gguf=args.mtp_gguf,
mtp_gamma=args.mtp_gamma,
mtp_draft_source=args.mtp_draft_source,
mtp_draft_topk=args.mtp_draft_topk)

import uvicorn
logging.basicConfig(
Expand Down
15 changes: 13 additions & 2 deletions dflash/src/common/attn_masks.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,24 @@ inline void build_causal_mask(std::vector<uint16_t> & out,
// Build an ancestor-only attention mask for DDTree tree-structured verify.
// Each query position i can attend to its ancestors in the tree (including
// itself) plus all past KV positions.
//
// kv_pad_override: when nonzero, pin the kv (column) stride to this value
// instead of the helper's natural `align_up(past_length + N - win_start,
// kq_stride_pad)`. Needed when the consumer tensor was allocated with a
// fixed kv extent (e.g. build_target_step_tree sizes sg.attn_mask at
// align_up(cache.max_ctx + N, kq_stride_pad)) and the helper-computed
// stride would not match the tensor's actual row pitch. Default 0 keeps
// existing behavior.
inline void build_tree_mask(const DDTree & tree, int past_length,
std::vector<uint16_t> & out_mask,
int kq_stride_pad,
int win_start = 0) {
int win_start = 0,
int kv_pad_override = 0) {
const int N = 1 + tree.n_nodes;
const int win_len = past_length + N - win_start;
const int kv_pad = align_up(win_len, kq_stride_pad);
const int kv_pad = kv_pad_override > 0
? align_up(kv_pad_override, kq_stride_pad)
: align_up(win_len, kq_stride_pad);
const int q_pad = align_up(N, KQ_MASK_PAD);
out_mask.assign((size_t)kv_pad * q_pad, F16_NEG_INF);
for (int q = 0; q < N; q++) {
Expand Down
60 changes: 60 additions & 0 deletions dflash/src/common/backend_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include "qwen3_backend.h"
#include "gemma4_backend.h"

#include "gguf.h"

#include <cassert>
#include <cstdio>

namespace dflash::common {
Expand All @@ -17,6 +20,26 @@ std::string detect_arch(const char * model_path) {
return info.arch;
}

bool gguf_contains_mtp_tensors(const std::string & path) {
gguf_init_params gp{};
gp.no_alloc = true;
gp.ctx = nullptr;
gguf_context * gguf = gguf_init_from_file(path.c_str(), gp);
if (!gguf) return false;

// MTP-capable GGUF files carry `qwen35.nextn_predict_layers` > 0.
// This is the canonical indicator used by qwen35_mtp_loader.cpp.
bool found = false;
int64_t kid = gguf_find_key(gguf, "qwen35.nextn_predict_layers");
if (kid >= 0) {
uint32_t n = gguf_get_val_u32(gguf, kid);
found = (n > 0);
}

gguf_free(gguf);
return found;
}

std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
if (!args.model_path) {
std::fprintf(stderr, "[backend_factory] model_path is null\n");
Expand All @@ -32,6 +55,22 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {

std::fprintf(stderr, "[backend_factory] detected arch=%s\n", arch.c_str());

// Unset must have been resolved to None by arg parsing before reaching here.
assert(args.mtp_source != MtpSource::Unset &&
"MtpSource::Unset must be resolved by arg parsing before reaching the backend factory");

// Resolve MtpSource::Auto before constructing the backend.
MtpSource resolved_source = args.mtp_source;
if (resolved_source == MtpSource::Auto) {
if (gguf_contains_mtp_tensors(args.model_path)) {
std::fprintf(stderr, "[backend_factory] mtp=auto: nextn_predict_layers found -> Native\n");
resolved_source = MtpSource::Native;
} else {
std::fprintf(stderr, "[backend_factory] mtp=auto: no nextn_predict_layers -> None\n");
resolved_source = MtpSource::None;
}
}

if (arch == "qwen35") {
Qwen35Config cfg;
cfg.target_path = args.model_path;
Expand All @@ -50,6 +89,27 @@ std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args) {
cfg.ddtree_temp = args.ddtree_temp;
cfg.ddtree_chain_seed = args.ddtree_chain_seed;
cfg.use_feature_mirror = args.use_feature_mirror;
cfg.mtp_gamma = args.mtp_gamma;
cfg.mtp_use_topk = args.mtp_use_topk;
cfg.mtp_draft_topk = args.mtp_draft_topk;

// Map resolved MtpSource to the paths Qwen35Backend expects.
// Qwen35Backend uses cfg_.mtp_gguf_path != nullptr as the MTP-active sentinel.
switch (resolved_source) {
case MtpSource::Native:
// MTP tensors live inside the target GGUF itself.
cfg.mtp_gguf_path = args.model_path;
break;
case MtpSource::ExternalDrafter:
cfg.mtp_gguf_path = args.mtp_gguf_path;
break;
case MtpSource::None:
case MtpSource::Auto: // fully resolved above; arm is unreachable.
case MtpSource::Unset: // guarded by assert above; arm is unreachable.
default:
cfg.mtp_gguf_path = nullptr;
break;
}

auto backend = std::make_unique<Qwen35Backend>(cfg);
if (!backend->init()) {
Expand Down
34 changes: 34 additions & 0 deletions dflash/src/common/backend_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@

namespace dflash::common {

// ─── MTP source selection ────────────────────────────────────────────────
// Replaces the old free-form mtp_draft_source string (@howard0su #237, line 59).
// Selects where (if anywhere) the MTP speculator's weights come from.
// BackendArgs defaults to Unset; create_backend() asserts that it never
// receives Unset and resolves Auto itself before configuring the backend.
enum class MtpSource {
Unset, // internal sentinel: --mtp-source not provided (never escapes arg parsing)
None, // no MTP speculator; create_backend() leaves the backend's mtp_gguf_path null
Native, // MTP heads co-located in the target GGUF (e.g. unsloth single-file);
        // create_backend() passes model_path as the MTP GGUF path
ExternalDrafter, // separate MTP-head GGUF supplied via mtp_gguf_path
Auto, // probe target GGUF for nextn_predict_layers; Native if found, else None
};

// ─── Backend creation arguments ─────────────────────────────────────────
// A superset of all per-arch config fields. The factory reads only those
// relevant to the detected arch; unused fields are silently ignored.
Expand Down Expand Up @@ -51,15 +61,39 @@ struct BackendArgs {
float ddtree_temp = 1.0f;
bool ddtree_chain_seed = true;
bool use_feature_mirror = false;

// MTP (Multi-Token Prediction) speculator — mutually exclusive with --draft.
// mtp_source drives which loading path is taken:
// Unset → internal default; --mtp-source not provided; resolved to None after
// legacy-flag inference (never reaches the backend factory as Unset).
// None → MTP disabled; mtp_gguf_path ignored.
// Native → MTP heads embedded in model_path GGUF (single-file, e.g. unsloth).
// mtp_gguf_path is left nullptr; the factory sets it to model_path.
// ExternalDrafter→ Separate MTP-head GGUF at mtp_gguf_path (required).
// Auto → factory calls gguf_contains_mtp_tensors(model_path): if true,
// resolves to Native; otherwise resolves to None.
MtpSource mtp_source = MtpSource::Unset;
const char * mtp_gguf_path = nullptr; // required only for ExternalDrafter
int mtp_gamma = 0; // 0 = MTP loaded but not active; >0 = chain depth
bool mtp_use_topk = false; // false = chain (default), true = mtp_topk strategy
int mtp_draft_topk = 1;
};

// ─── Factory function ───────────────────────────────────────────────────
// Inspects model_path GGUF metadata, constructs the correct backend, and
// calls init(). Returns nullptr on failure (diagnostic printed to stderr).
// When args.mtp_source == Auto, resolves to Native or None before
// constructing; the resolved value is not written back into args.
std::unique_ptr<ModelBackend> create_backend(const BackendArgs & args);

// Returns the detected architecture string without creating a backend.
// Useful for early dispatch (e.g. printing which backend will be used).
std::string detect_arch(const char * model_path);

// Returns true if the GGUF at `path` contains MTP-head tensors.
// Heuristic: presence of `qwen35.nextn_predict_layers` metadata key with
// a value > 0. Pure metadata scan — no tensor allocation, no GPU touch.
// Used by create_backend() when mtp_source == Auto.
bool gguf_contains_mtp_tensors(const std::string & path);

} // namespace dflash::common
Loading
Loading